package edu.stanford.nlp.classify;

import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Index;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import java.util.Collection;

/* loaded from: input_file:edu/stanford/nlp/classify/MultinomialLogisticClassifier.class */
/**
 * Probabilistic classifier that turns a weight matrix, a feature index and a label index
 * into per-label scores.
 *
 * <p>The numeric work is delegated to {@code LogisticUtils}; this class maps features to
 * integer indices, collects the feature values, and maps the resulting score array back
 * onto labels by position in the label index.
 *
 * <p>NOTE(review): the exact layout of {@code weights} (rows per class, any bias column)
 * is defined by {@code LogisticUtils.calculateSigmoids}, which is not visible here —
 * confirm against that helper before constructing instances by hand.
 *
 * @param <L> label type
 * @param <F> feature type
 */
public class MultinomialLogisticClassifier<L, F> implements ProbabilisticClassifier<L, F>, RVFClassifier<L, F> {
    private static final long serialVersionUID = 1;

    /** Weight matrix handed to {@code LogisticUtils.calculateSigmoids}. */
    private double[][] weights;

    /** Maps feature objects to the integer indices used to address {@link #weights}. */
    private Index<F> featureIndex;

    /** Maps positions in the computed score array to label objects. */
    private Index<L> labelIndex;

    /**
     * Creates a classifier from already-trained parameters. The arguments are stored by
     * reference; no defensive copies are made.
     *
     * @param dArr   the weight matrix
     * @param index  the feature index
     * @param index2 the label index
     */
    public MultinomialLogisticClassifier(double[][] dArr, Index<F> index, Index<L> index2) {
        this.featureIndex = index;
        this.labelIndex = index2;
        this.weights = dArr;
    }

    /**
     * Returns the labels known to this classifier.
     *
     * @return the contents of the label index, as returned by {@code Index.objectsList()}
     */
    public Collection<L> labels() {
        return this.labelIndex.objectsList();
    }

    /**
     * Returns the label with the highest score for the given datum.
     *
     * @param datum the datum to classify
     * @return the argmax of {@link #scoresOf(Datum)}
     */
    public L classOf(Datum<L, F> datum) {
        return (L) Counters.argmax(scoresOf(datum));
    }

    /**
     * Returns the per-label scores for the given datum; these are exactly the values of
     * {@link #logProbabilityOf(Datum)}.
     *
     * @param datum the datum to score
     * @return a counter from label to score
     */
    public Counter<L> scoresOf(Datum<L, F> datum) {
        return logProbabilityOf(datum);
    }

    /**
     * Real-valued-feature overload; delegates to {@link #classOf(Datum)}.
     *
     * @param rVFDatum the datum to classify
     * @return the highest-scoring label
     */
    public L classOf(RVFDatum<L, F> rVFDatum) {
        // The upcast selects the Datum overload; without it this method would call itself.
        return classOf((Datum<L, F>) rVFDatum);
    }

    /**
     * Real-valued-feature overload; delegates to {@link #scoresOf(Datum)}.
     *
     * @param rVFDatum the datum to score
     * @return a counter from label to score
     */
    public Counter<L> scoresOf(RVFDatum<L, F> rVFDatum) {
        // The upcast selects the Datum overload; without it this method would call itself.
        return scoresOf((Datum<L, F>) rVFDatum);
    }

    /**
     * Computes a value for every label in the label index from the datum's features.
     *
     * <p>For an {@link RVFDatum} the feature values come from its feature counter; for any
     * other datum every feature present is given the value {@code 1.0}.
     *
     * @param datum the datum to score
     * @return a counter holding, for each label at position {@code i} of the label index,
     *         element {@code i} of the array returned by
     *         {@code LogisticUtils.calculateSigmoids}
     */
    public Counter<L> probabilityOf(Datum<L, F> datum) {
        Collection<F> features = datum.asFeatures();
        int[] featureIndices = LogisticUtils.indicesOf(features, this.featureIndex);

        double[] featureValues;
        if (datum instanceof RVFDatum) {
            // NOTE(review): relies on asFeatures() and asFeaturesCounter().values() iterating
            // in the same order so that indices and values line up — confirm in RVFDatum.
            featureValues = LogisticUtils.convertToArray(((RVFDatum<L, F>) datum).asFeaturesCounter().values());
        } else {
            featureValues = new double[features.size()];
            Arrays.fill(featureValues, 1.0d);
        }

        double[] sigmoids = LogisticUtils.calculateSigmoids(this.weights, featureIndices, featureValues);
        ClassicCounter<L> result = new ClassicCounter<>();
        int numLabels = this.labelIndex.size();
        for (int i = 0; i < numLabels; i++) {
            result.incrementCount(this.labelIndex.get(i), sigmoids[i]);
        }
        return result;
    }

    /**
     * Returns the result of {@link #probabilityOf(Datum)} after applying
     * {@code Counters.logInPlace} to it.
     *
     * @param datum the datum to score
     * @return a counter from label to the log-transformed value
     */
    public Counter<L> logProbabilityOf(Datum<L, F> datum) {
        Counter<L> logProbabilities = probabilityOf(datum);
        Counters.logInPlace(logProbabilities);
        return logProbabilities;
    }

    /**
     * Replaces this classifier's weights, feature index and label index with the three
     * objects serialized (in that order) in the given file.
     *
     * <p>The fields are only updated once all three objects have been read, so a failed
     * load leaves the classifier unchanged.
     *
     * <p>NOTE(review): this uses Java native deserialization; only load files from a
     * trusted source.
     *
     * @param str path of the file to read
     * @throws IOException            if the file cannot be opened or read
     * @throws ClassNotFoundException if a serialized class is not on the classpath
     */
    private void load(String str) throws IOException, ClassNotFoundException {
        System.out.print("Loading classifier from " + str + "... ");
        // Two resources so the file handle is released even if the ObjectInputStream
        // constructor fails while reading the stream header.
        try (FileInputStream fileIn = new FileInputStream(str);
             ObjectInputStream objectIn = new ObjectInputStream(fileIn)) {
            double[][] loadedWeights = (double[][]) ErasureUtils.uncheckedCast(objectIn.readObject());
            @SuppressWarnings("unchecked") // the serialized form carries no generic type information
            Index<F> loadedFeatureIndex = (Index<F>) ErasureUtils.uncheckedCast(objectIn.readObject());
            @SuppressWarnings("unchecked") // the serialized form carries no generic type information
            Index<L> loadedLabelIndex = (Index<L>) ErasureUtils.uncheckedCast(objectIn.readObject());

            this.weights = loadedWeights;
            this.featureIndex = loadedFeatureIndex;
            this.labelIndex = loadedLabelIndex;
        }
        System.out.println("done.");
    }

    /**
     * Serializes the weights, feature index and label index (in that order) to the given
     * file, creating the parent directory first if it does not exist.
     *
     * @param str path of the file to write
     * @throws IOException if the parent directory cannot be created or the file cannot be
     *                     written
     */
    private void save(String str) throws IOException {
        System.out.print("Saving classifier to " + str + "... ");
        File parentDir = new File(str).getParentFile();
        // Re-check isDirectory() after a failed mkdirs() in case the directory was created
        // concurrently.
        if (parentDir != null && !parentDir.exists() && !parentDir.mkdirs() && !parentDir.isDirectory()) {
            throw new IOException("Could not create directory " + parentDir);
        }
        // Two resources so the file handle is released even if the ObjectOutputStream
        // constructor fails while writing the stream header.
        try (FileOutputStream fileOut = new FileOutputStream(str);
             ObjectOutputStream objectOut = new ObjectOutputStream(fileOut)) {
            objectOut.writeObject(this.weights);
            objectOut.writeObject(this.featureIndex);
            objectOut.writeObject(this.labelIndex);
        }
        System.out.println("done.");
    }
}
