/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.text_embedding;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.ZipUtils;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.algorithms.text_embedding.HuggingfaceTextEmbeddingTranslatorFactory;
import org.opensearch.ml.engine.algorithms.text_embedding.ONNXSentenceTransformerTextEmbeddingTranslator;
import org.opensearch.ml.engine.algorithms.text_embedding.SentenceTransformerTextEmbeddingTranslator;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.utils.FileUtils;

/**
 * Text-embedding {@link Predictable} backed by DJL.
 *
 * <p>The model archive is unzipped into the engine's model cache and loaded once per device reported
 * by the selected DJL engine ({@code PyTorch} for TorchScript models, {@code OnnxRuntime} otherwise).
 * Prediction requests are spread over the loaded devices in round-robin order.
 */
@Function(value=FunctionName.TEXT_EMBEDDING)
public class TextEmbeddingModel
implements Predictable {
    @Generated
    private static final Logger log = LogManager.getLogger(TextEmbeddingModel.class);
    public static final String SENTENCE_EMBEDDING = "sentence_embedding";
    public static final String MODEL_ZIP_FILE = "model_zip_file";
    public static final String MODEL_HELPER = "model_helper";
    public static final String ML_ENGINE = "ml_engine";
    /** DJL engine name used for TorchScript models. */
    private static final String PYTORCH_ENGINE = "PyTorch";
    /** DJL engine name used for every other supported model format. */
    private static final String ONNX_ENGINE = "OnnxRuntime";
    private ModelHelper modelHelper;
    private MLEngine mlEngine;
    private String modelId;
    /** One predictor per loaded device, in the same order as {@link #devices}; null when not loaded. */
    private Predictor<Input, Output>[] predictors;
    /** The zoo models backing {@link #predictors}; null when not loaded. */
    private ZooModel<Input, Output>[] models;
    /** Devices the model was loaded on, as reported by the DJL engine. */
    private Device[] devices;
    /** Round-robin cursor into {@link #predictors}; always kept in {@code [0, deviceCount)}. */
    private final AtomicInteger nextDevice = new AtomicInteger(0);

    /**
     * Prediction with an explicit model object is not supported by this implementation.
     *
     * @throws MLException always
     */
    @Override
    public MLOutput predict(MLInput mlInput, MLModel model) {
        throw new MLException("model not loaded");
    }

    /**
     * Embeds every document of the input dataset with the loaded model.
     *
     * @param mlInput input whose dataset is expected to be a {@link TextDocsInputDataSet}
     * @return one {@link ModelTensors} entry per input document
     * @throws MLException if the model has not been initialized or inference fails
     */
    @Override
    public MLOutput predict(MLInput mlInput) {
        if (this.modelHelper == null || this.modelId == null) {
            throw new MLException("model not loaded");
        }
        return this.predictTextEmbedding(this.modelId, mlInput.getInputDataset());
    }

    /**
     * Reads the zip file, model helper and ML engine from {@code params} and loads the model.
     *
     * @param model model metadata (format, id, name, algorithm, version, config)
     * @param params must contain {@link #MODEL_ZIP_FILE}, {@link #MODEL_HELPER} and {@link #ML_ENGINE}
     * @throws IllegalArgumentException if a required parameter or the model id is missing
     * @throws MLException if loading the model fails
     */
    @Override
    public void initModel(MLModel model, Map<String, Object> params) {
        String engine = model.getModelFormat() == MLModelFormat.TORCH_SCRIPT ? PYTORCH_ENGINE : ONNX_ENGINE;
        File modelZipFile = (File)params.get(MODEL_ZIP_FILE);
        this.modelHelper = (ModelHelper)params.get(MODEL_HELPER);
        this.mlEngine = (MLEngine)params.get(ML_ENGINE);
        if (modelZipFile == null) {
            throw new IllegalArgumentException("model file is null");
        }
        if (this.modelHelper == null) {
            throw new IllegalArgumentException("model helper is null");
        }
        if (this.mlEngine == null) {
            throw new IllegalArgumentException("ML engine is null");
        }
        this.modelId = model.getModelId();
        if (this.modelId == null) {
            throw new IllegalArgumentException("model id is null");
        }
        this.loadTextEmbeddingModel(modelZipFile, this.modelId, model.getName(), model.getAlgorithm(), model.getVersion(), model.getModelConfig(), engine);
    }

    /**
     * Deletes the model's file cache and releases every predictor and zoo model.
     *
     * <p>Predictors and models are released even if deleting the file cache throws. The fields are
     * cleared before closing so concurrent requests fail with "model not loaded" instead of using a
     * closed predictor.
     */
    @Override
    public void close() {
        try {
            if (this.modelHelper != null && this.modelId != null) {
                this.modelHelper.deleteFileCache(this.modelId);
            }
        } finally {
            Predictor<Input, Output>[] loadedPredictors = this.predictors;
            ZooModel<Input, Output>[] loadedModels = this.models;
            this.predictors = null;
            this.models = null;
            if (loadedPredictors != null) {
                this.closePredictors(loadedPredictors);
            }
            if (loadedModels != null) {
                this.closeModels(loadedModels);
            }
        }
    }

    /**
     * Unzips the model archive and loads one predictor per device of the requested DJL engine.
     *
     * <p>Each predictor is warmed up with a single sentence. On any failure, everything loaded so far
     * is closed and the error is rethrown as an {@link MLException}. The load path returned by
     * {@code mlEngine.getLoadModelPath(modelId)} is always deleted afterwards.
     *
     * @throws IllegalArgumentException if {@code functionName} is not TEXT_EMBEDDING or the engine is unsupported
     * @throws MLException if unzipping or loading the model fails
     */
    protected void loadTextEmbeddingModel(File modelZipFile, String modelId, String modelName, FunctionName functionName, String version, MLModelConfig modelConfig, String engine) {
        if (FunctionName.TEXT_EMBEDDING != functionName) {
            throw new IllegalArgumentException("wrong function name");
        }
        if (!PYTORCH_ENGINE.equals(engine) && !ONNX_ENGINE.equals(engine)) {
            throw new IllegalArgumentException("unsupported engine");
        }
        try {
            // The explicit cast is required: an uncast lambda is ambiguous between
            // PrivilegedAction and PrivilegedExceptionAction.
            AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                List<Predictor<Input, Output>> predictorList = new ArrayList<>();
                List<ZooModel<Input, Output>> modelList = new ArrayList<>();
                ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
                try {
                    String djlCachePath = this.mlEngine.getDjlCachePath().toAbsolutePath().toString();
                    System.setProperty("PYTORCH_PRECXX11", "true");
                    System.setProperty("DJL_CACHE_DIR", djlCachePath);
                    System.setProperty("java.library.path", djlCachePath);
                    System.setProperty("ai.djl.pytorch.num_interop_threads", "1");
                    System.setProperty("ai.djl.pytorch.num_threads", "1");
                    // DJL locates its engine providers through the context class loader.
                    Thread.currentThread().setContextClassLoader(Model.class.getClassLoader());
                    Path modelPath = this.mlEngine.getModelCachePath(modelId, modelName, version);
                    extractModel(modelZipFile, modelPath, modelName);

                    // The config is identical for every device, so resolve it once.
                    TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig)modelConfig;
                    TextEmbeddingModelConfig.FrameworkType transformersType = textEmbeddingModelConfig.getFrameworkType();

                    this.devices = Engine.getEngine(engine).getDevices();
                    for (int i = 0; i < this.devices.length; ++i) {
                        log.debug("load model {} on device {}: {}", modelId, i, this.devices[i]);
                        Criteria.Builder<Input, Output> criteriaBuilder = Criteria.builder()
                                .setTypes(Input.class, Output.class)
                                .optApplication(Application.UNDEFINED)
                                .optArguments(new HashMap<String, Object>())
                                .optEngine(engine)
                                .optDevice(this.devices[i])
                                .optModelPath(modelPath);
                        if (ONNX_ENGINE.equals(engine)) {
                            criteriaBuilder.optTranslator(new ONNXSentenceTransformerTextEmbeddingTranslator());
                        } else if (transformersType == TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) {
                            criteriaBuilder.optTranslator(new SentenceTransformerTextEmbeddingTranslator());
                        } else {
                            criteriaBuilder.optTranslatorFactory(new HuggingfaceTextEmbeddingTranslatorFactory());
                        }
                        Criteria<Input, Output> criteria = criteriaBuilder.build();
                        ZooModel<Input, Output> model = criteria.loadModel();
                        // Track the model before creating the predictor so it is closed if newPredictor() fails.
                        modelList.add(model);
                        Predictor<Input, Output> predictor = model.newPredictor();
                        predictorList.add(predictor);
                        Input input = new Input();
                        input.add("warm up sentence");
                        predictor.predict(input);
                    }
                    if (!predictorList.isEmpty()) {
                        @SuppressWarnings("unchecked") // generic array creation is impossible; elements are Predictor<Input, Output>
                        Predictor<Input, Output>[] loadedPredictors = predictorList.toArray(new Predictor[0]);
                        this.predictors = loadedPredictors;
                        predictorList.clear();
                    }
                    if (!modelList.isEmpty()) {
                        @SuppressWarnings("unchecked") // generic array creation is impossible; elements are ZooModel<Input, Output>
                        ZooModel<Input, Output>[] loadedModels = modelList.toArray(new ZooModel[0]);
                        this.models = loadedModels;
                        modelList.clear();
                    }
                    log.info("Load model {} successfully on {} devices", modelId, this.devices.length);
                    return null;
                }
                catch (Exception e) {
                    String errorMessage = "Failed to load model " + modelId;
                    log.error(errorMessage, e);
                    try {
                        this.close();
                    } catch (RuntimeException closeError) {
                        // Keep the original failure as the primary cause.
                        e.addSuppressed(closeError);
                    }
                    if (!predictorList.isEmpty()) {
                        this.closePredictors(predictorList.toArray(new Predictor[0]));
                        predictorList.clear();
                    }
                    if (!modelList.isEmpty()) {
                        this.closeModels(modelList.toArray(new ZooModel[0]));
                        modelList.clear();
                    }
                    throw new MLException(errorMessage, e);
                }
                finally {
                    FileUtils.deleteFileQuietly(this.mlEngine.getLoadModelPath(modelId));
                    Thread.currentThread().setContextClassLoader(contextClassLoader);
                }
            });
        }
        catch (PrivilegedActionException e) {
            String errorMsg = "Failed to load model";
            log.error(errorMsg, e);
            throw new MLException(errorMsg, e);
        }
    }

    /**
     * Replaces {@code modelPath} with the contents of {@code modelZipFile}.
     *
     * <p>The single {@code .pt}/{@code .onnx} file in the archive's top level is renamed to
     * {@code modelName} plus its original suffix.
     *
     * @throws IOException if the old directory cannot be removed or the archive cannot be unzipped
     * @throws IllegalArgumentException if the archive contains more than one model file or the
     *     unzipped directory cannot be listed
     */
    private static void extractModel(File modelZipFile, Path modelPath, String modelName) throws IOException {
        File pathFile = new File(modelPath.toUri());
        if (pathFile.exists()) {
            org.apache.commons.io.FileUtils.deleteDirectory(pathFile);
        }
        try (FileInputStream fileInputStream = new FileInputStream(modelZipFile)) {
            ZipUtils.unzip(fileInputStream, modelPath);
        }
        File[] extractedFiles = pathFile.listFiles();
        if (extractedFiles == null) {
            throw new IllegalArgumentException("cannot list files of unzipped model at " + modelPath);
        }
        boolean findModelFile = false;
        for (File file : extractedFiles) {
            String name = file.getName();
            if (!name.endsWith(".pt") && !name.endsWith(".onnx")) {
                continue;
            }
            if (findModelFile) {
                throw new IllegalArgumentException("found multiple models");
            }
            findModelFile = true;
            int dotIndex = name.lastIndexOf(".");
            String suffix = name.substring(dotIndex);
            if (modelName.equals(name.substring(0, dotIndex))) {
                continue;
            }
            File target = new File(modelPath.resolve(modelName + suffix).toUri());
            if (!file.renameTo(target)) {
                log.warn("failed to rename model file {} to {}", file, target);
            }
        }
    }

    /** Closes every predictor, logging and continuing past individual failures so none leak. */
    private void closePredictors(Predictor<?, ?>[] predictorsToClose) {
        log.debug("will close {} predictor for model {}", predictorsToClose.length, this.modelId);
        for (Predictor<?, ?> predictor : predictorsToClose) {
            try {
                predictor.close();
            } catch (RuntimeException e) {
                log.warn("failed to close predictor for model {}", this.modelId, e);
            }
        }
    }

    /** Closes every zoo model, logging and continuing past individual failures so none leak. */
    private void closeModels(ZooModel<?, ?>[] modelsToClose) {
        log.debug("will close {} zoo model for model {}", modelsToClose.length, this.modelId);
        for (ZooModel<?, ?> model : modelsToClose) {
            try {
                model.close();
            } catch (RuntimeException e) {
                log.warn("failed to close zoo model for model {}", this.modelId, e);
            }
        }
    }

    /**
     * Runs every document of {@code inputDataSet} through one predictor, chosen round-robin per request.
     *
     * @param modelId id used in log messages
     * @param inputDataSet expected to be a {@link TextDocsInputDataSet}
     * @return one {@link ModelTensors} per document, in input order
     * @throws MLException if the model is not loaded or inference fails
     */
    protected ModelTensorOutput predictTextEmbedding(String modelId, MLInputDataset inputDataSet) {
        try {
            // The cast resolves the doPrivileged overload. predict() throws the checked TranslateException.
            return AccessController.doPrivileged((PrivilegedExceptionAction<ModelTensorOutput>) () -> {
                Thread currentThread = Thread.currentThread();
                ClassLoader contextClassLoader = currentThread.getContextClassLoader();
                currentThread.setContextClassLoader(this.getClass().getClassLoader());
                try {
                    // Snapshot the fields so a concurrent close() cannot null them mid-request.
                    Predictor<Input, Output>[] loadedPredictors = this.predictors;
                    Device[] loadedDevices = this.devices;
                    if (loadedPredictors == null || loadedDevices == null || loadedPredictors.length == 0) {
                        throw new MLException("model not loaded.");
                    }
                    int deviceCount = Math.min(loadedPredictors.length, loadedDevices.length);
                    // An atomic update keeps the cursor in range without races or int overflow.
                    // The trailing modulo guards against a cursor left over from a load with more devices.
                    int deviceIndex = this.nextDevice.getAndUpdate(i -> (i + 1) % deviceCount) % deviceCount;
                    ArrayList<ModelTensors> tensorOutputs = new ArrayList<>();
                    TextDocsInputDataSet textDocsInput = (TextDocsInputDataSet)inputDataSet;
                    ModelResultFilter resultFilter = textDocsInput.getResultFilter();
                    for (String doc : textDocsInput.getDocs()) {
                        Input input = new Input();
                        input.add(doc);
                        log.debug("run text embedding predict for model {} on device {}", modelId, loadedDevices[deviceIndex]);
                        Output output = loadedPredictors[deviceIndex].predict(input);
                        tensorOutputs.add(this.parseModelTensorOutput(output, resultFilter));
                    }
                    return new ModelTensorOutput(tensorOutputs);
                } finally {
                    // Restore the caller's loader; pooled threads must not keep ours.
                    currentThread.setContextClassLoader(contextClassLoader);
                }
            });
        }
        catch (PrivilegedActionException e) {
            String errorMsg = "Failed to inference text embedding";
            log.error(errorMsg, e);
            throw new MLException(errorMsg, e);
        }
    }

    /**
     * Decodes a predictor {@link Output} into {@link ModelTensors} and applies the optional filter.
     *
     * @param output raw predictor output; its payload must be serialized {@link ModelTensors} bytes
     * @param resultFilter optional filter; ignored when null
     * @throws MLException if {@code output} or its payload is null
     */
    protected ModelTensors parseModelTensorOutput(Output output, ModelResultFilter resultFilter) {
        if (output == null || output.getData() == null) {
            throw new MLException("No output generated");
        }
        byte[] bytes = output.getData().getAsBytes();
        ModelTensors tensorOutput = ModelTensors.fromBytes(bytes);
        if (resultFilter != null) {
            tensorOutput.filter(resultFilter);
        }
        return tensorOutput;
    }
}

