package ai.djl.pytorch.zoo.nlp.sentimentanalysis;

import ai.djl.Model;
import ai.djl.modality.Classifications;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.bert.BertTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.StackBatchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;

/* loaded from: input_file:ai/djl/pytorch/zoo/nlp/sentimentanalysis/PtDistilBertTranslator.class */
public class PtDistilBertTranslator implements Translator<String, Classifications> {
    private Vocabulary vocabulary;
    private BertTokenizer tokenizer;

    public Batchifier getBatchifier() {
        return new StackBatchifier();
    }

    public void prepare(NDManager nDManager, Model model) throws IOException {
        this.vocabulary = SimpleVocabulary.builder().optMinFrequency(1).addFromTextFile(model.getArtifact("distilbert-base-uncased-finetuned-sst-2-english-vocab.txt").getPath()).optUnknownToken("[UNK]").build();
        this.tokenizer = new BertTokenizer();
    }

    /* renamed from: processOutput, reason: merged with bridge method [inline-methods] */
    public Classifications m7processOutput(TranslatorContext translatorContext, NDList nDList) {
        NDArray singletonOrThrow = nDList.singletonOrThrow();
        return new Classifications(Arrays.asList("Negative", "Positive"), singletonOrThrow.exp().div(singletonOrThrow.exp().sum(new int[]{0}, true)));
    }

    public NDList processInput(TranslatorContext translatorContext, String str) {
        List list = this.tokenizer.tokenize(str);
        Stream stream = list.stream();
        Vocabulary vocabulary = this.vocabulary;
        vocabulary.getClass();
        long[] array = stream.mapToLong(vocabulary::getIndex).toArray();
        long[] jArr = new long[list.size()];
        Arrays.fill(jArr, 1L);
        NDManager nDManager = translatorContext.getNDManager();
        return new NDList(new NDArray[]{nDManager.create(array), nDManager.create(jArr)});
    }
}
