package eus.ixa.ixa.pipe.ml.document;

import eus.ixa.ixa.pipe.ml.document.features.BagOfWordsFeatureGenerator;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import opennlp.tools.ml.TrainerFactory;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;

/* loaded from: input_file:eus/ixa/ixa/pipe/ml/document/DocumentClassifierME.class */
/**
 * Document classifier backed by an OpenNLP {@link MaxentModel}.
 *
 * <p>An input document is turned into a feature context by a
 * {@link DocumentClassifierContextGenerator} obtained from the model's
 * {@link DocumentClassifierFactory}. The model then scores every label (outcome) it knows.
 */
public class DocumentClassifierME implements DocumentClassifier {

    /** Trained model used to score feature contexts. */
    private final MaxentModel model;

    /** Converts an input document into the feature context evaluated by {@link #model}. */
    private final DocumentClassifierContextGenerator contextGenerator;

    /**
     * Creates a classifier from a trained model.
     *
     * <p>The context generator comes from the model's factory. A
     * {@link BagOfWordsFeatureGenerator} is always appended to it.
     *
     * @param documentClassifierModel the trained model together with its factory
     */
    public DocumentClassifierME(DocumentClassifierModel documentClassifierModel) {
        DocumentClassifierFactory factory = documentClassifierModel.getFactory();
        this.model = documentClassifierModel.getDocumentClassifierModel();
        this.contextGenerator = factory.createContextGenerator();
        // NOTE(review): train() builds its context generator via the factory WITHOUT adding
        // this bag-of-words generator; confirm that train-time and run-time features match.
        this.contextGenerator.addFeatureGenerator(new BagOfWordsFeatureGenerator());
    }

    /**
     * Classifies a document and returns the best-scoring label.
     *
     * @param strArr the document to classify (presumably its tokens; TODO confirm against
     *     the context generator)
     * @return the outcome chosen by {@link MaxentModel#getBestOutcome(double[])} for the
     *     scores returned by {@link #classifyProb(String[])}
     */
    @Override
    public String classify(String[] strArr) {
        return this.model.getBestOutcome(classifyProb(strArr));
    }

    /**
     * Scores a document against every label of the model.
     *
     * @param strArr the document to classify
     * @return the model's score for each outcome, as returned by {@link MaxentModel#eval}
     */
    @Override
    public double[] classifyProb(String[] strArr) {
        return this.model.eval(this.contextGenerator.getContext(strArr));
    }

    /**
     * Returns the best label for a previously computed score array.
     *
     * @param dArr scores, typically from {@link #classifyProb(String[])}
     * @return the outcome chosen by {@link MaxentModel#getBestOutcome(double[])}
     */
    @Override
    public String getBestLabel(double[] dArr) {
        return this.model.getBestOutcome(dArr);
    }

    /**
     * Scores a document and returns each label's score keyed by label.
     *
     * @param strArr the document to classify
     * @return a new mutable map from each model label to its score
     */
    @Override
    public Map<String, Double> scoreMap(String[] strArr) {
        Map<String, Double> scoresByLabel = new HashMap<String, Double>();
        double[] probs = classifyProb(strArr);
        int numberOfLabels = getNumberOfLabels();
        for (int i = 0; i < numberOfLabels; i++) {
            String label = getLabel(i);
            scoresByLabel.put(label, Double.valueOf(probs[getIndex(label)]));
        }
        return scoresByLabel;
    }

    /**
     * Scores a document and groups labels by their score.
     *
     * <p>Keys are ordered by natural {@code Double} order (ascending). Labels that share
     * exactly the same score are collected into one set.
     *
     * @param strArr the document to classify
     * @return a new mutable sorted map from score to the labels with that score
     */
    @Override
    public SortedMap<Double, Set<String>> sortedScoreMap(String[] strArr) {
        SortedMap<Double, Set<String>> labelsByScore = new TreeMap<Double, Set<String>>();
        double[] probs = classifyProb(strArr);
        int numberOfLabels = getNumberOfLabels();
        for (int i = 0; i < numberOfLabels; i++) {
            String label = getLabel(i);
            Double score = Double.valueOf(probs[getIndex(label)]);
            // Single lookup instead of containsKey + get.
            Set<String> labelsWithScore = labelsByScore.get(score);
            if (labelsWithScore == null) {
                labelsWithScore = new HashSet<String>();
                labelsByScore.put(score, labelsWithScore);
            }
            labelsWithScore.add(label);
        }
        return labelsByScore;
    }

    /**
     * Returns the model's index for a label.
     *
     * @param str the label (outcome name)
     * @return the index reported by {@link MaxentModel#getIndex(String)}
     */
    @Override
    public int getIndex(String str) {
        return this.model.getIndex(str);
    }

    /**
     * Returns the label at a model index.
     *
     * @param i the outcome index, in {@code [0, getNumberOfLabels())}
     * @return the outcome name reported by {@link MaxentModel#getOutcome(int)}
     */
    @Override
    public String getLabel(int i) {
        return this.model.getOutcome(i);
    }

    /**
     * Returns how many labels the model knows.
     *
     * @return {@link MaxentModel#getNumOutcomes()}
     */
    @Override
    public int getNumberOfLabels() {
        return this.model.getNumOutcomes();
    }

    /**
     * Renders all labels together with the given scores as one string.
     *
     * @param dArr scores, typically from {@link #classifyProb(String[])}
     * @return the string produced by {@link MaxentModel#getAllOutcomes(double[])}
     */
    @Override
    public String getAllLabels(double[] dArr) {
        return this.model.getAllOutcomes(dArr);
    }

    /** Delegates to the context generator's {@code clearFeatureData()}. */
    @Override
    public void clearFeatureData() {
        this.contextGenerator.clearFeatureData();
    }

    /**
     * Trains a new document classifier model.
     *
     * @param str the first argument stored in the model (presumably the language code;
     *     TODO confirm against DocumentClassifierModel)
     * @param objectStream the training samples
     * @param trainingParameters trainer settings; passed to the trainer via
     *     {@link TrainingParameters#getSettings()}
     * @param documentClassifierFactory supplies the context generator, feature generator
     *     descriptor and resources
     * @return the trained model
     * @throws IOException if the trainer or the sample stream throws one
     */
    public static DocumentClassifierModel train(String str, ObjectStream<DocSample> objectStream,
            TrainingParameters trainingParameters, DocumentClassifierFactory documentClassifierFactory)
            throws IOException {
        // Handed to the trainer as its report map, then stored with the model as manifest entries.
        Map<String, String> manifestInfoEntries = new HashMap<String, String>();
        MaxentModel docClassModel = TrainerFactory
                .getEventTrainer(trainingParameters.getSettings(), manifestInfoEntries)
                .train(new DocumentClassifierEventStream(objectStream,
                        documentClassifierFactory.createContextGenerator()));
        return new DocumentClassifierModel(str, docClassModel,
                documentClassifierFactory.getFeatureGenerator(),
                documentClassifierFactory.getResources(), manifestInfoEntries,
                documentClassifierFactory);
    }
}
