/*
 * Decompiled with CFR 0.152.
 */
package chat.octet.model;

import chat.octet.model.LlamaService;
import chat.octet.model.TokenDecoder;
import chat.octet.model.beans.CompletionResult;
import chat.octet.model.beans.Status;
import chat.octet.model.beans.Token;
import chat.octet.model.enums.FinishReason;
import chat.octet.model.exceptions.DecodeException;
import chat.octet.model.exceptions.GenerationException;
import chat.octet.model.parameters.GenerateParameter;
import java.nio.charset.StandardCharsets;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Spliterator;
import javax.annotation.Nonnull;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Token generator that exposes model inference as an {@link Iterable} of {@link Token}s.
 *
 * <p>Each instance owns a single {@link Inference} iterator, so the generated sequence can
 * only be consumed once: {@link #iterator()} always returns the same underlying iterator.
 */
public class Generator
implements Iterable<Token> {
    private static final Logger log = LoggerFactory.getLogger(Generator.class);
    private final Inference inference;
    private final Status chatStatus;

    /**
     * Creates a generator and immediately decodes the prompt.
     *
     * @param generateParams generation parameters used for sampling and stopping.
     * @param prompt         prompt text; when blank, a single BOS token is used instead.
     * @param chatStatus     optional status to continue from; may be {@code null}. The inference
     *                       works on a copy and {@link #output()} copies the result back into it.
     */
    public Generator(GenerateParameter generateParams, String prompt, Status chatStatus) {
        this.chatStatus = chatStatus;
        this.inference = new Inference(generateParams, prompt, chatStatus);
    }

    /**
     * Creates a generator without a previous status.
     *
     * @param generateParams generation parameters used for sampling and stopping.
     * @param prompt         prompt text; when blank, a single BOS token is used instead.
     */
    public Generator(GenerateParameter generateParams, String prompt) {
        // Delegate to the main constructor. Instantiating a second Generator here would
        // leave this instance's fields unset.
        this(generateParams, prompt, null);
    }

    /**
     * Returns the single inference iterator owned by this generator.
     *
     * @return the token iterator; the same instance on every call.
     */
    @Override
    @Nonnull
    public Iterator<Token> iterator() {
        return this.inference;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always.
     */
    @Override
    public Spliterator<Token> spliterator() {
        throw new UnsupportedOperationException("Unsupported operation.");
    }

    /**
     * Generates all tokens and prints their text to standard output.
     *
     * <p>When finished (or on failure) the inference status is copied back into the status
     * supplied at construction time, or, if none was supplied, the inference status is reset.
     *
     * @throws GenerationException if any exception is raised while generating tokens.
     */
    public void output() {
        try {
            for (Token token : this) {
                System.out.print(token.getText());
            }
        }
        catch (Exception e) {
            throw new GenerationException("Generate next token error ", e);
        }
        finally {
            if (this.chatStatus != null) {
                this.chatStatus.copyToStatus(this.inference.getStatus());
            } else {
                this.inference.clearCache();
            }
        }
    }

    /**
     * Generates all remaining tokens and collects them into a single result.
     *
     * <p>Unlike {@link #output()}, this method neither copies the inference status back nor
     * clears it afterwards.
     *
     * @return the concatenated token text together with the finish reason of the last token,
     *         or {@link FinishReason#UNKNOWN} if no token was produced.
     */
    public CompletionResult result() {
        StringBuilder content = new StringBuilder();
        FinishReason finishReason = FinishReason.UNKNOWN;
        while (this.inference.hasNext()) {
            Token token = this.inference.next();
            content.append(token.getText());
            finishReason = token.getFinishReason();
        }
        return CompletionResult.builder().content(content.toString()).finishReason(finishReason).build();
    }

    /**
     * Iterator that samples one token per {@link #next()} call until a stop condition is met.
     */
    private static class Inference
    implements Iterator<Token> {
        private final GenerateParameter generateParams;
        private final Status status;
        /** Accumulates single-byte pieces that belong to one multi-byte UTF-8 character. */
        private final byte[] multiByteTokenBuffer;
        /** Expected byte count of the character being accumulated; 0 when nothing is pending. */
        private int multiByteTokenLength;
        /** Number of bytes accumulated so far in {@link #multiByteTokenBuffer}. */
        private int multiByteTokenIndex;
        private boolean finished = false;
        private final int maxNewTokenSize;
        private final int contextSize;

        /**
         * Returns the working status of this inference.
         *
         * @return the status this inference reads from and appends tokens to.
         */
        protected Status getStatus() {
            return this.status;
        }

        /**
         * Tokenizes the prompt, appends it to the working status and batch-decodes it.
         *
         * @param generateParams generation parameters.
         * @param prompt         prompt text; when blank, a single BOS token is used instead.
         * @param srcStatus      status to copy as the starting point, or {@code null} for a new one.
         * @throws IllegalArgumentException if the prompt token count is not below the context size.
         * @throws DecodeException          if batch decoding the prompt fails.
         */
        protected Inference(GenerateParameter generateParams, String prompt, Status srcStatus) {
            this.generateParams = generateParams;
            this.multiByteTokenBuffer = new byte[8];
            this.contextSize = LlamaService.getContextSize();
            // Work on a private copy so the caller's status is untouched until copied back.
            this.status = srcStatus == null ? new Status() : new Status(srcStatus);

            int[] tokens = StringUtils.isNotBlank(prompt)
                    ? LlamaService.tokenize(prompt, true, true)
                    : new int[]{LlamaService.getTokenBOS()};
            if (tokens.length >= this.contextSize) {
                throw new IllegalArgumentException(MessageFormat.format("Requested tokens ({0}) exceed context window of {1}.", tokens.length, this.contextSize));
            }
            if (generateParams.isVerbosePrompt()) {
                log.info("Print prompt text:\n{}", prompt);
            }
            this.status.appendTokens(tokens);

            // A non-positive limit means "use whatever room is left in the context window".
            int requestedMax = generateParams.getMaxNewTokenSize();
            this.maxNewTokenSize = requestedMax <= 0 ? this.contextSize - this.status.getInputLength() : requestedMax;

            // A grammar that fails to load is only logged; generation continues regardless.
            if (StringUtils.isNotBlank(generateParams.getGrammarRules())
                    && !LlamaService.loadLlamaGrammar(generateParams.getGrammarRules())) {
                log.error("Grammar rule parsing failed, Please check the grammar rule format.");
            }
            log.debug("Generate starting, input token size: {}, past token size: {}.", tokens.length, this.status.getPastTokenSize());
            this.decodePrompt();
        }

        /**
         * Creates an inference without a previous status.
         *
         * @param generateParams generation parameters.
         * @param text           prompt text.
         */
        protected Inference(GenerateParameter generateParams, String text) {
            this(generateParams, text, null);
        }

        /**
         * Batch-decodes the input tokens and advances the past-token counter by the
         * difference between the input length and the previous past-token size.
         *
         * @throws DecodeException if the decode call returns a non-zero code.
         */
        private void decodePrompt() {
            int decodeStatus = LlamaService.batchDecode(this.status.getId(), this.status.getInputIds(), this.status.getInputLength(), this.status.getPastTokenSize());
            if (decodeStatus != 0) {
                throw new DecodeException(MessageFormat.format("Failed to decode, return code: {0}.", decodeStatus));
            }
            int decodedSize = this.status.getInputLength() - this.status.getPastTokenSize();
            this.status.addPastTokensSize(decodedSize);
            log.debug("Batch decode prompt completed, decode token size: {}, sequence id: {}.", decodedSize, this.status.getId());
        }

        /**
         * Decides whether generation must stop after the given token and, if so, records the
         * finish reason on the token. Conditions are checked in this order: EOS token
         * (FINISHED), stopping criteria (STOP), context size reached (TRUNCATED), maximum
         * number of new tokens reached (LENGTH).
         *
         * @param token  the token that was just generated.
         * @param logits the logits the token was sampled from, passed to the stopping criteria.
         * @return {@code true} if generation is finished, {@code false} to continue.
         */
        private boolean breakOrContinue(Token token, float[] logits) {
            if (token.getId() == LlamaService.getTokenEOS()) {
                token.updateFinishReason(FinishReason.FINISHED);
                return true;
            }
            if (this.generateParams.getStoppingCriteriaList() != null
                    && this.generateParams.getStoppingCriteriaList().criteria(this.status.getInputIds(), logits, new Object[0])) {
                token.updateFinishReason(FinishReason.STOP);
                return true;
            }
            if (this.status.getInputLength() >= this.contextSize) {
                token.updateFinishReason(FinishReason.TRUNCATED);
                log.warn("Context size has been exceeded. Truncate and reset the context cache, sequence id: {}.", this.status.getId());
                return true;
            }
            if (this.status.getGenerateTokens().size() >= this.maxNewTokenSize) {
                token.updateFinishReason(FinishReason.LENGTH);
                return true;
            }
            return false;
        }

        /**
         * Converts a token id to text.
         *
         * <p>A piece consisting of exactly one byte with the high bit set (negative as a Java
         * {@code byte}, hence not a valid code point) is treated as part of a multi-byte UTF-8
         * character: bytes are buffered until the expected count is reached and only then
         * decoded. While bytes are still pending, an empty string is returned.
         *
         * @param token the token id.
         * @return the decoded text, or an empty string while a multi-byte character is incomplete.
         */
        private String tokenToText(int token) {
            byte[] buffer = new byte[64];
            // NOTE(review): length is used below as the number of bytes written into buffer;
            // confirm against LlamaService.tokenToPiece for pieces longer than the buffer.
            int length = LlamaService.tokenToPiece(token, buffer, buffer.length);
            byte code = buffer[0];
            if (length == 1 && !Character.isValidCodePoint(code)) {
                if (this.multiByteTokenLength == 0) {
                    // First byte of a new character determines how many bytes to collect.
                    this.multiByteTokenLength = TokenDecoder.getUtf8ByteLength(code);
                }
                this.multiByteTokenBuffer[this.multiByteTokenIndex] = code;
                ++this.multiByteTokenIndex;
                if (this.multiByteTokenIndex == this.multiByteTokenLength) {
                    String text = new String(this.multiByteTokenBuffer, 0, this.multiByteTokenLength, StandardCharsets.UTF_8);
                    // Reset the accumulator for the next multi-byte character.
                    this.multiByteTokenIndex = 0;
                    this.multiByteTokenLength = 0;
                    Arrays.fill(this.multiByteTokenBuffer, (byte)0);
                    return text;
                }
                return "";
            }
            return new String(buffer, 0, length, StandardCharsets.UTF_8);
        }

        /**
         * Reports whether more tokens can be generated.
         *
         * @return {@code true} until a stop condition has been reached.
         */
        @Override
        public boolean hasNext() {
            return !this.finished;
        }

        /**
         * Samples the next token, appends it to the status and evaluates the stop conditions.
         *
         * @return the generated token; its finish reason is updated when generation stops.
         * @throws NoSuchElementException if generation has already finished.
         */
        @Override
        public Token next() {
            if (this.finished) {
                throw new NoSuchElementException("Generation has already finished.");
            }
            float[] logits = LlamaService.getLogits(this.status.getLogitsIndex());
            if (this.generateParams.getLogitsProcessorList() != null) {
                logits = this.generateParams.getLogitsProcessorList().processor(this.status.getInputIds(), logits, new Object[0]);
            }
            // Only the most recent tokens are handed to the sampler for the penalty terms.
            int[] lastTokens = null;
            if (this.generateParams.getLastTokensSize() != 0) {
                int startIndex = Math.max(0, this.status.getInputLength() - this.generateParams.getLastTokensSize());
                lastTokens = this.status.subInputIds(startIndex);
            }
            int tokenId = LlamaService.sampling(
                    logits,
                    lastTokens,
                    this.generateParams.getLastTokensSize(),
                    this.generateParams.getRepeatPenalty(),
                    this.generateParams.getFrequencyPenalty(),
                    this.generateParams.getPresencePenalty(),
                    this.generateParams.isPenalizeNl(),
                    this.generateParams.getMirostatMode().ordinal(),
                    this.generateParams.getMirostatTAU(),
                    this.generateParams.getMirostatETA(),
                    this.generateParams.getTemperature(),
                    this.generateParams.getTopK(),
                    this.generateParams.getTopP(),
                    this.generateParams.getTsf(),
                    this.generateParams.getTypical(),
                    this.status.getId(),
                    this.status.getPastTokenSize());
            Token token = new Token(tokenId, LlamaService.getLlamaTokenType(tokenId), this.tokenToText(tokenId));
            this.status.appendNextToken(token);
            this.finished = this.breakOrContinue(token, logits);
            return token;
        }

        /**
         * Resets the working status of this inference.
         */
        public void clearCache() {
            this.status.reset();
            log.debug("Cache clear completed, sequence id: {}.", this.status.getId());
        }
    }
}

