/*
 * Decompiled with CFR 0.152.
 */
package chat.octet.model;

import chat.octet.model.Generator;
import chat.octet.model.LlamaService;
import chat.octet.model.beans.CompletionResult;
import chat.octet.model.beans.LlamaContextParams;
import chat.octet.model.beans.LlamaModelParams;
import chat.octet.model.beans.Status;
import chat.octet.model.enums.ModelType;
import chat.octet.model.exceptions.ModelException;
import chat.octet.model.parameters.GenerateParameter;
import chat.octet.model.parameters.ModelParameter;
import chat.octet.model.utils.PromptBuilder;
import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import java.io.File;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Entry point for loading a model through {@link LlamaService} and running
 * text completion or chat generation against it.
 *
 * <p>Constructing an instance loads the model file, creates a context and
 * optionally applies a LoRA adapter. Chat sessions are tracked per user id in
 * an in-memory concurrent map of {@link Status} objects. Call {@link #close()}
 * (or use try-with-resources) to drop the sessions and release the resources
 * held by {@link LlamaService}.
 */
public class Model implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(Model.class);
    private final ModelParameter modelParams;
    private final String modelName;
    private final String modelType;
    private final int lastTokensSize;
    /** Chat session state keyed by user id. */
    private final Map<String, Status> chatStatus = Maps.newConcurrentMap();

    /**
     * Creates a model from a file path, using builder defaults for every other parameter.
     *
     * @param modelPath path of the model file
     */
    public Model(String modelPath) {
        this(ModelParameter.builder().modelPath(modelPath).build());
    }

    /**
     * Creates a model from the given parameters: loads the model file, creates
     * the context and, when a LoRA path is configured, applies the adapter.
     *
     * @param modelParams model parameters; the model path and model type must not be null
     * @throws NullPointerException if {@code modelParams}, its model path or its model type is null
     * @throws ModelException       if the model file or the configured LoRA file does not exist,
     *                              or if applying the LoRA adapter reports a non-zero status
     */
    public Model(ModelParameter modelParams) {
        Preconditions.checkNotNull(modelParams, "Model parameters cannot be null");
        Preconditions.checkNotNull(modelParams.getModelPath(), "Model file path cannot be null");
        if (!Files.exists(new File(modelParams.getModelPath()).toPath())) {
            throw new ModelException("Model file is not exists, please check the file path");
        }
        this.modelParams = modelParams;
        this.modelName = modelParams.getModelName();
        this.modelType = modelParams.getModelType();
        Preconditions.checkNotNull(this.modelType, "Model type cannot be null");

        // Validate the LoRA file before anything is loaded, so a wrong path fails fast
        // instead of failing after the base model and context have already been created.
        boolean loraConfigured = StringUtils.isNotBlank(modelParams.getLoraPath());
        if (loraConfigured && !Files.exists(new File(modelParams.getLoraPath()).toPath())) {
            throw new ModelException("Lora model file is not exists, please check the file path");
        }

        LlamaModelParams llamaModelParams = this.getLlamaModelParameters(modelParams);
        LlamaService.loadLlamaModelFromFile(modelParams.getModelPath(), llamaModelParams);
        LlamaContextParams llamaContextParams = this.getLlamaContextParameters(modelParams);
        LlamaService.createNewContextWithModel(llamaContextParams);

        // The context-size fallback is read only after the context has been created;
        // querying it earlier would ask for the size of a context that does not exist yet.
        this.lastTokensSize = modelParams.getLastNTokensSize() < 0
                ? LlamaService.getContextSize()
                : modelParams.getLastNTokensSize();

        if (loraConfigured) {
            int status = LlamaService.loadLoraModelFromFile(modelParams.getLoraPath(), modelParams.getLoraScale(), modelParams.getLoraBase(), modelParams.getThreads());
            if (status != 0) {
                throw new ModelException(String.format("Failed to apply LoRA from lora path: %s to base path: %s", modelParams.getLoraPath(), modelParams.getLoraBase()));
            }
        }
        if (modelParams.isVerbose()) {
            log.info("system info: {}", LlamaService.getSystemInfo());
        }
        log.info("model parameters: {}", modelParams);
    }

    /**
     * Builds the model-level parameters, starting from the defaults supplied by
     * {@link LlamaService#getLlamaModelDefaultParams()}.
     *
     * @param modelParams user supplied model parameters
     * @return populated model parameters
     */
    private LlamaModelParams getLlamaModelParameters(ModelParameter modelParams) {
        LlamaModelParams llamaModelParams = LlamaService.getLlamaModelDefaultParams();
        llamaModelParams.gpuLayers = modelParams.getGpuLayers();
        llamaModelParams.vocabOnly = modelParams.isVocabOnly();
        // mmap is requested only when no LoRA adapter is configured. The computed value is
        // assigned in both directions so that a disabled setting is honoured whatever the
        // default parameters contain.
        boolean mmapRequested = StringUtils.isBlank(modelParams.getLoraPath()) && modelParams.isMmap();
        llamaModelParams.mmap = mmapRequested && LlamaService.isMmapSupported();
        llamaModelParams.mlock = modelParams.isMlock() && LlamaService.isMlockSupported();
        // Main GPU and tensor split are optional; the defaults are kept when they are not set.
        if (modelParams.getMainGpu() != null) {
            llamaModelParams.mainGpu = modelParams.getMainGpu();
        }
        if (modelParams.getTensorSplit() != null) {
            llamaModelParams.tensorSplit = modelParams.getTensorSplit();
        }
        return llamaModelParams;
    }

    /**
     * Builds the context-level parameters, starting from the defaults supplied by
     * {@link LlamaService#getLlamaContextDefaultParams()}.
     *
     * @param modelParams user supplied model parameters
     * @return populated context parameters
     */
    private LlamaContextParams getLlamaContextParameters(ModelParameter modelParams) {
        LlamaContextParams llamaContextParams = LlamaService.getLlamaContextDefaultParams();
        llamaContextParams.seed = modelParams.getSeed();
        llamaContextParams.ctx = modelParams.getContextSize();
        llamaContextParams.batch = modelParams.getBatchSize();
        llamaContextParams.threads = modelParams.getThreads();
        // A batch thread count of -1 means "use the same value as threads".
        llamaContextParams.threadsBatch = modelParams.getThreadsBatch() == -1 ? modelParams.getThreads() : modelParams.getThreadsBatch();
        llamaContextParams.ropeScalingType = modelParams.getRopeScalingType();
        llamaContextParams.yarnExtFactor = modelParams.getYarnExtFactor();
        llamaContextParams.yarnAttnFactor = modelParams.getYarnAttnFactor();
        llamaContextParams.yarnBetaFast = modelParams.getYarnBetaFast();
        llamaContextParams.yarnBetaSlow = modelParams.getYarnBetaSlow();
        llamaContextParams.yarnOrigCtx = modelParams.getYarnOrigCtx();
        llamaContextParams.ropeFreqBase = modelParams.getRopeFreqBase();
        llamaContextParams.ropeFreqScale = modelParams.getRopeFreqScale();
        llamaContextParams.mulMatQ = modelParams.isMulMatQ();
        llamaContextParams.f16KV = modelParams.isF16KV();
        llamaContextParams.logitsAll = modelParams.isLogitsAll();
        llamaContextParams.embedding = modelParams.isEmbedding();
        return llamaContextParams;
    }

    /**
     * Removes and resets the chat session stored for the given user.
     * Does nothing when no session is stored for that user.
     *
     * @param user user id
     */
    public void removeChatStatus(String user) {
        // A single remove() is atomic on the concurrent map: only the caller that actually
        // took the entry out resets it and logs the removal.
        Status status = this.chatStatus.remove(user);
        if (status != null) {
            status.reset();
            log.info("Removed chat session, User: {}.", user);
        }
    }

    /**
     * Removes and resets every stored chat session.
     */
    public void removeAllChatStatus() {
        int size = this.chatStatus.size();
        if (size > 0) {
            this.chatStatus.keySet().forEach(this::removeChatStatus);
            log.info("Removed all chat sessions, size: {}.", size);
        }
    }

    /**
     * Runs a completion for the given text with default generation parameters.
     *
     * @param text input text
     * @return completion result
     */
    public CompletionResult completions(String text) {
        return this.completions(GenerateParameter.builder().build(), text);
    }

    /**
     * Runs a completion for the given text.
     *
     * @param generateParams generation parameters
     * @param text           input text
     * @return completion result
     */
    public CompletionResult completions(GenerateParameter generateParams, String text) {
        return this.generate(generateParams, text).result();
    }

    /**
     * Creates a generator for the given text with default generation parameters.
     *
     * @param text input text
     * @return generator for the text
     */
    public Generator generate(String text) {
        return this.generate(GenerateParameter.builder().build(), text);
    }

    /**
     * Creates a generator for the given text. The last-tokens size resolved at
     * construction time is written into {@code generateParams} before use.
     *
     * @param generateParams generation parameters, must not be null
     * @param text           input text, must not be null
     * @return generator for the text
     * @throws NullPointerException if either argument is null
     */
    public Generator generate(GenerateParameter generateParams, String text) {
        Preconditions.checkNotNull(generateParams, "Generate parameter cannot be null");
        Preconditions.checkNotNull(text, "Text cannot be null");
        generateParams.setLastTokensSize(this.lastTokensSize);
        return new Generator(generateParams, text);
    }

    /**
     * Runs a chat completion with default generation parameters and no system prompt.
     *
     * @param question user question
     * @return completion result
     */
    public CompletionResult chatCompletions(String question) {
        return this.chatCompletions(GenerateParameter.builder().build(), null, question);
    }

    /**
     * Runs a chat completion without a system prompt.
     *
     * @param generateParams generation parameters
     * @param question       user question
     * @return completion result
     */
    public CompletionResult chatCompletions(GenerateParameter generateParams, String question) {
        return this.chatCompletions(generateParams, null, question);
    }

    /**
     * Runs a chat completion.
     *
     * @param generateParams generation parameters
     * @param system         system prompt, may be null
     * @param question       user question
     * @return completion result
     */
    public CompletionResult chatCompletions(GenerateParameter generateParams, String system, String question) {
        return this.chat(generateParams, system, question).result();
    }

    /**
     * Creates a chat generator with default generation parameters and no system prompt.
     *
     * @param question user question
     * @return chat generator
     */
    public Generator chat(String question) {
        return this.chat(GenerateParameter.builder().build(), null, question);
    }

    /**
     * Creates a chat generator with default generation parameters.
     *
     * @param system   system prompt, may be null
     * @param question user question
     * @return chat generator
     */
    public Generator chat(String system, String question) {
        return this.chat(GenerateParameter.builder().build(), system, question);
    }

    /**
     * Creates a chat generator without a system prompt.
     *
     * @param generateParams generation parameters
     * @param question       user question
     * @return chat generator
     */
    public Generator chat(GenerateParameter generateParams, String question) {
        return this.chat(generateParams, null, question);
    }

    /**
     * Creates a chat generator bound to the session of the user named in
     * {@code generateParams}, creating that session when none is stored yet.
     *
     * <p>A system prompt equal to the one already recorded for the session is not
     * passed to the prompt builder again. The first non-blank system prompt seen
     * for a session is recorded as its initial system prompt.
     *
     * @param generateParams generation parameters; must not be null and must carry a user id
     * @param system         system prompt, may be null
     * @param question       user question, must not be null
     * @return chat generator
     * @throws NullPointerException     if {@code generateParams}, its user id or {@code question} is null
     * @throws IllegalArgumentException if the model type does not name a {@link ModelType} constant
     */
    public Generator chat(GenerateParameter generateParams, String system, String question) {
        Preconditions.checkNotNull(generateParams, "Generate parameter cannot be null");
        Preconditions.checkNotNull(question, "Question cannot be null");
        Preconditions.checkNotNull(generateParams.getUser(), "User id cannot be null");
        generateParams.setLastTokensSize(this.lastTokensSize);

        String user = generateParams.getUser();
        Status userStatus = this.chatStatus.get(user);
        if (userStatus == null) {
            // putIfAbsent keeps session creation atomic: when two threads race for the same
            // user, exactly one Status is stored and both threads end up using it.
            Status created = new Status();
            Status raced = this.chatStatus.putIfAbsent(user, created);
            if (raced == null) {
                userStatus = created;
                log.debug("Create new chat session, User: {} id: {}, chat session cache size: {}.", user, userStatus.getId(), this.chatStatus.size());
            } else {
                userStatus = raced;
            }
        }

        String effectiveSystem = system;
        if (StringUtils.isNotBlank(effectiveSystem)) {
            String initialSystemPrompt = userStatus.getInitialSystemPrompt();
            if (effectiveSystem.equals(initialSystemPrompt)) {
                // Same system prompt as the one already recorded for this session: do not repeat it.
                effectiveSystem = null;
            } else if (StringUtils.isBlank(initialSystemPrompt)) {
                userStatus.setInitialSystemPrompt(effectiveSystem);
            }
        }

        // Locale.ROOT keeps the enum lookup independent of the JVM default locale
        // (for example, Turkish upper-casing maps 'i' to a dotted capital I).
        ModelType type = ModelType.valueOf(this.modelType.toUpperCase(Locale.ROOT));
        String prompt = PromptBuilder.format(type, effectiveSystem, question);
        return new Generator(generateParams, prompt, userStatus);
    }

    /**
     * Logs the sampling metrics reported by {@link LlamaService} when verbose mode is enabled.
     */
    public void metrics() {
        if (this.modelParams.isVerbose()) {
            log.info("Metrics: {}", LlamaService.getSamplingMetrics(true).toString());
        }
    }

    /**
     * Removes all chat sessions and releases the resources held through {@link LlamaService}.
     */
    @Override
    public void close() {
        try {
            this.removeAllChatStatus();
        } finally {
            // Release even when resetting a session fails, so the resources are not left behind.
            LlamaService.release();
            LlamaService.llamaBackendFree();
        }
    }

    @Override
    public String toString() {
        return "LlamaModel (modelParams=" + this.modelParams + ')';
    }

    /**
     * @return the parameters this model was created with
     */
    public ModelParameter getModelParams() {
        return this.modelParams;
    }

    /**
     * @return the model name taken from the model parameters
     */
    public String getModelName() {
        return this.modelName;
    }

    /**
     * @return the model type taken from the model parameters
     */
    public String getModelType() {
        return this.modelType;
    }
}

