/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.text_embedding;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.ServingTranslator;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Map;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;

public class ONNXSentenceTransformerTextEmbeddingTranslator
implements ServingTranslator {
    private static final int[] AXIS = new int[]{0};
    private HuggingFaceTokenizer tokenizer;

    public Batchifier getBatchifier() {
        return null;
    }

    public void prepare(TranslatorContext ctx) throws IOException {
        Path path = ctx.getModel().getModelPath();
        this.tokenizer = HuggingFaceTokenizer.builder().optPadding(true).optTokenizerPath(path.resolve("tokenizer.json")).build();
    }

    public NDList processInput(TranslatorContext ctx, Input input) {
        NDManager manager = ctx.getNDManager();
        String sentence = input.getAsString(0);
        NDList ndList = new NDList();
        Encoding encode = this.tokenizer.encode(sentence);
        ctx.setAttachment("encoding", (Object)encode);
        long[] indices = encode.getIds();
        long[] attentionMask = encode.getAttentionMask();
        long[] tokenTypeIds = encode.getTypeIds();
        NDArray indicesArray = manager.create(indices);
        indicesArray.setName("input_ids");
        NDArray attentionMaskArray = manager.create(attentionMask);
        attentionMaskArray.setName("attention_mask");
        NDArray tokenTypeIdsArray = manager.create(tokenTypeIds);
        tokenTypeIdsArray.setName("token_type_ids");
        ndList.add((Object)indicesArray.expandDims(0));
        ndList.add((Object)tokenTypeIdsArray.expandDims(0));
        ndList.add((Object)attentionMaskArray.expandDims(0));
        return ndList;
    }

    public Output processOutput(TranslatorContext ctx, NDList list) {
        NDArray embeddings = (NDArray)list.get(0);
        int shapeLength = embeddings.getShape().getShape().length;
        if (shapeLength == 3) {
            embeddings = embeddings.get(new long[]{0L});
        }
        Encoding encoding = (Encoding)ctx.getAttachment("encoding");
        long[] attentionMask = encoding.getAttentionMask();
        NDManager manager = ctx.getNDManager();
        NDArray inputAttentionMask = manager.create(attentionMask);
        long[] shape = embeddings.getShape().getShape();
        inputAttentionMask = inputAttentionMask.expandDims(-1).broadcast(shape);
        NDArray inputAttentionMaskSum = inputAttentionMask.sum(AXIS);
        NDArray clamp = inputAttentionMaskSum.clip((Number)1.0E-9, (Number)1.0E12);
        NDArray prod = embeddings.mul(inputAttentionMask);
        NDArray sum = prod.sum(AXIS);
        embeddings = sum.div(clamp).normalize(2.0, 0L);
        ArrayList<ModelTensor> outputs = new ArrayList<ModelTensor>();
        Number[] data = embeddings.toArray();
        outputs.add(new ModelTensor("sentence_embedding", data, shape, MLResultDataType.FLOAT32, null));
        Output output = new Output();
        ModelTensors modelTensorOutput = new ModelTensors(outputs);
        output.add(modelTensorOutput.toBytes());
        return output;
    }

    public void setArguments(Map<String, ?> arguments) {
    }
}

