/*
 * Decompiled with CFR 0.152.
 */
package chat.octet.model;

import chat.octet.model.Generator;
import chat.octet.model.LlamaService;
import chat.octet.model.TokenDecoder;
import chat.octet.model.beans.CompletionResult;
import chat.octet.model.beans.LlamaContextParams;
import chat.octet.model.beans.LlamaModelParams;
import chat.octet.model.beans.Status;
import chat.octet.model.components.criteria.impl.StoppingWordCriteria;
import chat.octet.model.components.processor.impl.CustomBiasLogitsProcessor;
import chat.octet.model.exceptions.ModelException;
import chat.octet.model.parameters.GenerateParameter;
import chat.octet.model.parameters.ModelParameter;
import chat.octet.model.utils.ChatFormatter;
import com.google.common.base.Charsets;
import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import com.google.common.io.Resources;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Facade over the llama.cpp model and context managed through the static {@link LlamaService} calls.
 *
 * <p>Construction initializes the native backend, loads the model (and an optional LoRA adapter),
 * creates the inference context and resolves a chat template. Per-conversation {@link Status}
 * objects are cached under the key {@code user} or {@code user:session}. Call {@link #close()} to
 * release the native resources.
 */
public class Model implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(Model.class);
    /** Separator between the user id and the session id in chat-status keys. */
    private static final String SESSION_SEPARATOR = ":";

    private final ModelParameter modelParams;
    private final String modelName;
    private final String modelType;
    private final ChatFormatter chatFormatter;
    private final Map<String, Status> chatStatus = Maps.newConcurrentMap();

    /**
     * Loads a model from {@code modelPath} using default {@link ModelParameter} settings.
     *
     * @param modelPath path of the model file
     */
    public Model(String modelPath) {
        this(ModelParameter.builder().modelPath(modelPath).build());
    }

    /**
     * Loads the model described by {@code modelParams}.
     *
     * @param modelParams model, context and LoRA settings; must not be null and must name an existing model file
     * @throws NullPointerException if {@code modelParams} or its model path is null
     * @throws ModelException if the model or LoRA file does not exist, or the LoRA adapter fails to apply
     */
    public Model(ModelParameter modelParams) {
        Preconditions.checkNotNull(modelParams, "Model parameters cannot be null");
        Preconditions.checkNotNull(modelParams.getModelPath(), "Model file path cannot be null");
        requireFileExists(modelParams.getModelPath(), "Model file is not exists, please check the file path");
        boolean useLora = StringUtils.isNotBlank(modelParams.getLoraPath());
        // Validate the adapter path before any native allocation so a bad path cannot strand a loaded model.
        if (useLora) {
            requireFileExists(modelParams.getLoraPath(), "Lora model file is not exists, please check the file path");
        }
        this.modelParams = modelParams;

        LlamaService.llamaBackendInit();
        LlamaService.llamaNumaInit(modelParams.getNumaStrategy());
        LlamaService.loadLlamaModelFromFile(modelParams.getModelPath(), getLlamaModelParameters(modelParams));
        LlamaService.createNewContextWithModel(getLlamaContextParameters(modelParams));
        if (useLora) {
            applyLora(modelParams);
        }

        this.modelName = LlamaService.llamaModelMeta("general.name");
        this.modelType = LlamaService.llamaModelMeta("general.architecture");
        this.chatFormatter = createChatFormatter(this.modelType);

        if (modelParams.isVerbose()) {
            log.info("system info: {}", LlamaService.getSystemInfo());
        }
        log.info("model parameters: {}", modelParams);
    }

    /** Throws {@link ModelException} with {@code message} when {@code path} does not exist. */
    private static void requireFileExists(String path, String message) {
        if (!Files.exists(new File(path).toPath())) {
            throw new ModelException(message);
        }
    }

    /**
     * Applies the configured LoRA adapter to the already-created model and context.
     * On a non-zero status the native model/context are released (no {@link Model} instance will
     * exist to {@link #close()}) and a {@link ModelException} is thrown.
     */
    private static void applyLora(ModelParameter modelParams) {
        int status = LlamaService.loadLoraModelFromFile(modelParams.getLoraPath(), modelParams.getLoraScale(), modelParams.getLoraBase(), modelParams.getThreads());
        if (status == 0) {
            return;
        }
        ModelException failure = new ModelException(String.format("Failed to apply LoRA from lora path: %s to base path: %s", modelParams.getLoraPath(), modelParams.getLoraBase()));
        try {
            LlamaService.release();
            LlamaService.llamaBackendFree();
        } catch (RuntimeException cleanupError) {
            // Keep the LoRA failure as the primary error.
            failure.addSuppressed(cleanupError);
        }
        throw failure;
    }

    /**
     * Builds the chat formatter for {@code modelType}. Tries the bundled
     * {@code chat-templates/<modelType>.tmpl} resource first, then the model's
     * {@code tokenizer.chat_template} metadata, and falls back to the default formatter when both are blank.
     */
    private static ChatFormatter createChatFormatter(String modelType) {
        String template;
        try {
            template = Resources.toString(Resources.getResource("chat-templates/" + modelType + ".tmpl"), Charsets.UTF_8);
        } catch (IllegalArgumentException | IOException e) {
            // A missing bundled template is an expected case with a fallback, not an error.
            log.warn("Failed to load local chat template, attempting to load chat template from the model.", e);
            template = LlamaService.llamaModelMeta("tokenizer.chat_template");
        }
        if (StringUtils.isBlank(template)) {
            log.warn("Chat template is not found, use default template.");
            return new ChatFormatter();
        }
        return new ChatFormatter(template, TokenDecoder.decodeToken(true, LlamaService.getTokenBOS()), TokenDecoder.decodeToken(true, LlamaService.getTokenEOS()));
    }

    /** Copies the model-loading settings from {@code modelParams} onto the native default model params. */
    private LlamaModelParams getLlamaModelParameters(ModelParameter modelParams) {
        LlamaModelParams llamaModelParams = LlamaService.getLlamaModelDefaultParams();
        llamaModelParams.gpuLayers = modelParams.getGpuLayers();
        llamaModelParams.splitMode = modelParams.getSplitMode();
        llamaModelParams.mainGpu = modelParams.getMainGpu();
        if (modelParams.getTensorSplit() != null) {
            llamaModelParams.tensorSplit = modelParams.getTensorSplit();
        }
        llamaModelParams.vocabOnly = modelParams.isVocabOnly();
        // mmap is disabled whenever a LoRA adapter is configured. Both flags are assigned explicitly
        // (not only when true) so a disabled setting is not replaced by the native default value.
        boolean mmap = StringUtils.isBlank(modelParams.getLoraPath()) && modelParams.isMmap();
        llamaModelParams.mmap = mmap && LlamaService.isMmapSupported();
        llamaModelParams.mlock = modelParams.isMlock() && LlamaService.isMlockSupported();
        llamaModelParams.checkTensors = modelParams.isCheckTensors();
        return llamaModelParams;
    }

    /**
     * Copies the context settings from {@code modelParams} onto the native default context params.
     * A batch thread count of {@code -1} means "use the generation thread count".
     */
    private LlamaContextParams getLlamaContextParameters(ModelParameter modelParams) {
        LlamaContextParams llamaContextParams = LlamaService.getLlamaContextDefaultParams();
        llamaContextParams.seed = modelParams.getSeed();
        llamaContextParams.ctx = modelParams.getContextSize();
        llamaContextParams.batch = modelParams.getBatchSize();
        llamaContextParams.ubatch = modelParams.getUbatch();
        llamaContextParams.seqMax = modelParams.getSeqMax();
        llamaContextParams.threads = modelParams.getThreads();
        llamaContextParams.threadsBatch = modelParams.getThreadsBatch() == -1 ? modelParams.getThreads() : modelParams.getThreadsBatch();
        llamaContextParams.ropeScalingType = modelParams.getRopeScalingType();
        llamaContextParams.poolingType = modelParams.getPoolingType();
        llamaContextParams.yarnExtFactor = modelParams.getYarnExtFactor();
        llamaContextParams.yarnAttnFactor = modelParams.getYarnAttnFactor();
        llamaContextParams.yarnBetaFast = modelParams.getYarnBetaFast();
        llamaContextParams.yarnBetaSlow = modelParams.getYarnBetaSlow();
        llamaContextParams.yarnOrigCtx = modelParams.getYarnOrigCtx();
        llamaContextParams.defragThold = modelParams.getDefragThold();
        llamaContextParams.ropeFreqBase = modelParams.getRopeFreqBase();
        llamaContextParams.ropeFreqScale = modelParams.getRopeFreqScale();
        llamaContextParams.logitsAll = modelParams.isLogitsAll();
        llamaContextParams.embedding = modelParams.isEmbedding();
        llamaContextParams.offloadKqv = modelParams.isOffloadKqv();
        llamaContextParams.flashAttn = modelParams.isFlashAttn();
        return llamaContextParams;
    }

    /**
     * Removes cached chat state and resets every removed {@link Status}.
     *
     * <p>The key equal to {@code session} is always removed. When {@code session} has the form
     * {@code user:sessionId} with a non-empty {@code sessionId}, every key ending in
     * {@code :sessionId} is removed as well.
     *
     * @param session a chat key, either {@code user} or {@code user:sessionId}
     * @throws NullPointerException if {@code session} is null
     */
    public void removeChatStatus(String session) {
        Preconditions.checkNotNull(session, "Session cannot be null");
        int separator = session.indexOf(SESSION_SEPARATOR);
        String sessionId = separator < 0 ? "" : session.substring(separator + 1);
        // An empty id is skipped because endsWith("") matches every key; the suffix is anchored on
        // the separator so id "s1" does not also match "alice:xs1".
        String suffix = sessionId.isEmpty() ? null : SESSION_SEPARATOR + sessionId;
        for (String key : chatStatus.keySet()) {
            if (key.equals(session) || (suffix != null && key.endsWith(suffix))) {
                removeChatStatusEntry(key);
            }
        }
    }

    /** Removes all cached chat state, resetting each {@link Status}. */
    public void removeAllChatStatus() {
        int size = chatStatus.size();
        if (size > 0) {
            // Removing while iterating is safe: the map is a concurrent map (weakly consistent iteration).
            chatStatus.keySet().forEach(this::removeChatStatusEntry);
            log.info("Removed all chat sessions, size: {}.", size);
        }
    }

    /** Removes exactly {@code key} from the cache and resets its status if it was still present. */
    private void removeChatStatusEntry(String key) {
        Status status = chatStatus.remove(key);
        if (status != null) {
            status.reset();
            log.info("Removed chat session, session: {}.", key);
        }
    }

    /** Runs a plain completion of {@code text} with default generation parameters. */
    public CompletionResult completions(String text) {
        return this.completions(GenerateParameter.builder().build(), text);
    }

    /** Runs a plain completion of {@code text} and returns the collected result. */
    public CompletionResult completions(GenerateParameter generateParams, String text) {
        return this.generate(generateParams, text).result();
    }

    /** Creates a generator for {@code text} with default generation parameters. */
    public Generator generate(String text) {
        return this.generate(GenerateParameter.builder().build(), text);
    }

    /**
     * Creates a generator that continues {@code text} without chat formatting.
     *
     * @throws NullPointerException if {@code generateParams} or {@code text} is null
     */
    public Generator generate(GenerateParameter generateParams, String text) {
        Preconditions.checkNotNull(generateParams, "Generate parameter cannot be null");
        Preconditions.checkNotNull(text, "Text cannot be null");
        registerGenerateComponents(generateParams);
        return new Generator(generateParams, text);
    }

    /**
     * Appends the logit-bias processor and stopping-word criteria implied by {@code generateParams}
     * to that object's own processor and criteria lists.
     *
     * <p>NOTE(review): this mutates the caller's parameter object, so reusing one GenerateParameter
     * across calls appends another processor/criteria each time — confirm this is intended.
     */
    private static void registerGenerateComponents(GenerateParameter generateParams) {
        if (generateParams.getLogitBias() != null && !generateParams.getLogitBias().isEmpty()) {
            generateParams.getLogitsProcessorList().add(new CustomBiasLogitsProcessor(generateParams.getLogitBias(), LlamaService.getVocabSize()));
        }
        if (generateParams.getStoppingWord() != null) {
            generateParams.getStoppingCriteriaList().add(new StoppingWordCriteria(generateParams.getStoppingWord()));
        }
    }

    /** Runs a chat completion for {@code question} with default parameters and no system prompt. */
    public CompletionResult chatCompletions(String question) {
        return this.chatCompletions(GenerateParameter.builder().build(), null, question);
    }

    /** Runs a chat completion for {@code question} without a system prompt. */
    public CompletionResult chatCompletions(GenerateParameter generateParams, String question) {
        return this.chatCompletions(generateParams, null, question);
    }

    /** Runs a chat completion and returns the collected result. */
    public CompletionResult chatCompletions(GenerateParameter generateParams, String system, String question) {
        return this.chat(generateParams, system, question).result();
    }

    /** Creates a chat generator with default parameters and no system prompt. */
    public Generator chat(String question) {
        return this.chat(GenerateParameter.builder().build(), null, question);
    }

    /** Creates a chat generator with default parameters. */
    public Generator chat(String system, String question) {
        return this.chat(GenerateParameter.builder().build(), system, question);
    }

    /** Creates a chat generator without a system prompt. */
    public Generator chat(GenerateParameter generateParams, String question) {
        return this.chat(generateParams, null, question);
    }

    /**
     * Creates a chat generator bound to the cached {@link Status} of the caller's conversation
     * (keyed by {@code user} or {@code user:session}), creating that status on first use.
     *
     * <p>The first non-blank system prompt is remembered on the status. When a later call passes
     * the same prompt, {@code null} is given to the formatter instead of repeating it.
     *
     * @throws NullPointerException if {@code generateParams}, {@code question} or the user id is null
     */
    public Generator chat(GenerateParameter generateParams, String system, String question) {
        Preconditions.checkNotNull(generateParams, "Generate parameter cannot be null");
        Preconditions.checkNotNull(question, "Question cannot be null");
        Preconditions.checkNotNull(generateParams.getUser(), "User id cannot be null");
        registerGenerateComponents(generateParams);

        Status userStatus = getOrCreateChatStatus(chatSessionKey(generateParams));
        String systemPrompt = system;
        if (StringUtils.isNotBlank(systemPrompt)) {
            if (systemPrompt.equals(userStatus.getInitialSystemPrompt())) {
                systemPrompt = null;
            } else if (StringUtils.isBlank(userStatus.getInitialSystemPrompt())) {
                userStatus.setInitialSystemPrompt(systemPrompt);
            }
        }
        String prompt = chatFormatter.format(systemPrompt, question);
        return new Generator(generateParams, prompt, userStatus);
    }

    /** Builds the chat-status key: the user id, plus {@code :session} when a session is set. */
    private static String chatSessionKey(GenerateParameter generateParams) {
        String user = generateParams.getUser();
        return StringUtils.isBlank(generateParams.getSession()) ? user : user + SESSION_SEPARATOR + generateParams.getSession();
    }

    /**
     * Returns the cached status for {@code key}, creating it atomically if absent. Holding the
     * returned reference means a concurrent removal cannot turn the lookup into a null.
     */
    private Status getOrCreateChatStatus(String key) {
        Status existing = chatStatus.get(key);
        if (existing != null) {
            return existing;
        }
        Status created = new Status();
        Status raced = chatStatus.putIfAbsent(key, created);
        if (raced != null) {
            return raced;
        }
        log.debug("Create new chat session, session: {} id: {}, chat session cache size: {}.", key, created.getId(), chatStatus.size());
        return created;
    }

    /** Logs sampling metrics when the model was created in verbose mode. */
    public void metrics() {
        if (this.modelParams.isVerbose()) {
            log.info("Metrics: {}", LlamaService.getSamplingMetrics(true).toString());
        }
    }

    /** Clears all chat state, then releases the native context/model and frees the backend. */
    @Override
    public void close() {
        this.removeAllChatStatus();
        LlamaService.release();
        LlamaService.llamaBackendFree();
        log.info("Closed model and context resources.");
    }

    @Override
    public String toString() {
        return "LlamaModel (modelParams=" + this.modelParams + ')';
    }

    public ModelParameter getModelParams() {
        return this.modelParams;
    }

    public String getModelName() {
        return this.modelName;
    }

    public String getModelType() {
        return this.modelType;
    }
}

