package edu.stanford.nlp.kbp.slotfilling.classify;

import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.WeightedDataset;
import edu.stanford.nlp.kbp.common.Maybe;
import edu.stanford.nlp.kbp.common.Props;
import edu.stanford.nlp.kbp.common.RelationType;
import edu.stanford.nlp.kbp.common.SentenceGroup;
import edu.stanford.nlp.kbp.slotfilling.ir.KBPRelationProvenance;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.text.DecimalFormat;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;

/* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/classify/SupervisedExtractor.class */
/**
 * A {@link RelationClassifier} backed by a single {@link LinearClassifier} trained on
 * individually annotated sentences.
 *
 * <p>During training, every sentence of every datum group contributes one training datum per
 * annotated label, including the {@code _NR} label. At classification time each sentence of a
 * {@link SentenceGroup} is classified independently. Its most probable label is reported unless
 * that label is {@code _NR} or its probability does not exceed
 * {@code Props.TRAIN_SUPERVISED_THRESHOLD}.
 */
public class SupervisedExtractor extends RelationClassifier {
    /** Label that is never reported as a relation and is excluded from relation statistics. */
    private static final String NO_RELATION = "_NR";

    /** The underlying model; {@code null} until {@link #train} or {@link #load} succeeds. */
    private LinearClassifier<String, String> classifier;
    /** True once {@link #classifier} holds a usable model. */
    private boolean trained = false;
    /** Optional label index used to seed the training dataset; see {@link #setIndices}. */
    private Index<String> labelIndexOrNull = null;
    /** Optional feature index used to seed the training dataset; see {@link #setIndices}. */
    private Index<String> featureIndexOrNull = null;

    /**
     * Creates an untrained extractor.
     *
     * @param properties currently unused
     */
    public SupervisedExtractor(Properties properties) {
    }

    /**
     * Trains the underlying linear classifier on every annotated sentence in the dataset.
     *
     * <p>When {@code Props.TRAIN_SUPERVISED_REWEIGHT} is set, each relation-labelled datum is
     * reweighted by (prior / empirical frequency) of its relation.
     *
     * @param kBPDataset grouped sentences together with their per-sentence annotated labels
     * @return empty training statistics
     */
    @Override
    public TrainingStatistics train(KBPDataset<String, String> kBPDataset) {
        ClassicCounter<RelationType> relationCounts = new ClassicCounter<>();
        WeightedDataset<String, String> dataset = buildDataset(kBPDataset, relationCounts);
        Redwood.Util.log(Redwood.Util.BLUE, "training on " + dataset.size() + " annotated datums");
        if (Props.TRAIN_SUPERVISED_REWEIGHT) {
            reweightToPriors(dataset, relationCounts);
        }
        this.classifier = newClassifierFactory().trainClassifier(dataset);
        // Mark as trained only after a model exists, so a failed training run cannot leave
        // asClassifier() handing out a null classifier.
        this.trained = true;
        return TrainingStatistics.empty();
    }

    /**
     * Converts every annotated sentence into a training datum (one per annotated label).
     *
     * @param data source dataset
     * @param relationCounts receives one count per non-{@code _NR} label occurrence
     * @return a dataset seeded with the configured label/feature indices, if any
     */
    private WeightedDataset<String, String> buildDataset(KBPDataset<String, String> data,
                                                         ClassicCounter<RelationType> relationCounts) {
        WeightedDataset<String, String> dataset = new WeightedDataset<>();
        if (this.labelIndexOrNull != null) {
            dataset.labelIndex = this.labelIndexOrNull;
        }
        if (this.featureIndexOrNull != null) {
            dataset.featureIndex = this.featureIndexOrNull;
        }
        for (int group = 0; group < data.size(); group++) {
            List<Datum<String, String>> sentences = data.getDatumGroup(group);
            for (int s = 0; s < data.getNumSentencesInGroup(group); s++) {
                for (String label : data.getAnnotatedLabels(group)[s]) {
                    if (!NO_RELATION.equals(label)) {
                        relationCounts.incrementCount(RelationType.fromString(label).orCrash());
                    }
                    dataset.add(new BasicDatum<>(sentences.get(s).asFeatures(), label));
                }
            }
        }
        return dataset;
    }

    /**
     * Reweights each datum whose label parses as a {@link RelationType}. The weight is
     * {@code RelationType.priors[rel] / empirical[rel]}, where the empirical distribution is the
     * normalized {@code relationCounts}. Datums whose label does not parse keep their weight.
     *
     * @param dataset dataset to reweight in place
     * @param relationCounts raw relation label counts; normalized in place by this method
     */
    private static void reweightToPriors(WeightedDataset<String, String> dataset,
                                         ClassicCounter<RelationType> relationCounts) {
        Counters.normalize(relationCounts);
        Counter<RelationType> scale = Counters.division(RelationType.priors, relationCounts);
        Redwood.Util.startTrack("Scaling factors");
        DecimalFormat format = new DecimalFormat("0.00000");
        for (Pair<RelationType, Double> entry : Counters.toDescendingMagnitudeSortedListWithCounts(scale)) {
            Redwood.Util.log(format.format(entry.second) + "\t" + entry.first);
        }
        Redwood.Util.endTrack("Scaling factors");
        for (int i = 0; i < dataset.size(); i++) {
            Iterator<RelationType> relations = RelationType.fromString(dataset.getDatum(i).label()).iterator();
            while (relations.hasNext()) {
                dataset.setWeight(i, (float) scale.getCount(relations.next()));
            }
        }
    }

    /** Builds the classifier factory configured from the {@code TRAIN_JOINTBAYES_*} properties. */
    private static LinearClassifierFactory<String, String> newClassifierFactory() {
        LinearClassifierFactory<String, String> factory =
            new LinearClassifierFactory<>(1.0E-4d, false, Props.TRAIN_JOINTBAYES_ZSIGMA);
        // NOTE(review): the second argument to both minimizer calls below is a sample-size/count
        // parameter (verify against the LinearClassifierFactory Javadoc). Reading
        // MAX_DISTANCE_BETWEEN_ENTITY_AND_SLOT here looks unrelated, possibly a decompiler
        // constant substitution for a literal. Confirm the intended value.
        switch (Props.TRAIN_JOINTBAYES_ZMINIMIZER) {
            case SGD:
                factory.useInPlaceStochasticGradientDescent(75, Props.MAX_DISTANCE_BETWEEN_ENTITY_AND_SLOT, Props.TRAIN_JOINTBAYES_ZSIGMA);
                break;
            case SGDTOQN:
                factory.useHybridMinimizerWithInPlaceSGD(10, Props.MAX_DISTANCE_BETWEEN_ENTITY_AND_SLOT, Props.TRAIN_JOINTBAYES_ZSIGMA);
                break;
            default:
                // Any other setting keeps the factory's default minimizer.
                break;
        }
        return factory;
    }

    /**
     * Classifies each sentence of the group independently.
     *
     * @param sentenceGroup sentences to classify, with per-sentence provenance
     * @param maybe unused by this extractor
     * @return for each sentence whose best label is not {@code _NR} and whose probability exceeds
     *     {@code Props.TRAIN_SUPERVISED_THRESHOLD}, that probability accumulated under the key
     *     (label, provenance of the sentence)
     * @throws IllegalStateException if the extractor has been neither trained nor loaded
     */
    @Override
    public Counter<Pair<String, Maybe<KBPRelationProvenance>>> classifyRelations(SentenceGroup sentenceGroup, Maybe<CoreMap[]> maybe) {
        if (this.classifier == null) {
            throw new IllegalStateException("SupervisedExtractor has not been trained or loaded");
        }
        ClassicCounter<Pair<String, Maybe<KBPRelationProvenance>>> relations = new ClassicCounter<>();
        for (int i = 0; i < sentenceGroup.size(); i++) {
            Counter<String> probabilities = this.classifier.probabilityOf(sentenceGroup.get(i));
            String best = Counters.argmax(probabilities);
            if (best == null || NO_RELATION.equals(best)) {
                continue;
            }
            double confidence = probabilities.getCount(best);
            if (confidence > Props.TRAIN_SUPERVISED_THRESHOLD) {
                relations.incrementCount(Pair.makePair(best, Maybe.Just(sentenceGroup.getProvenance(i))), confidence);
            }
        }
        return relations;
    }

    /**
     * Returns the underlying classifier. If the extractor is untrained, it first trains on the
     * supplied dataset.
     *
     * @param maybe dataset to train on if no model is available yet
     * @return the trained classifier
     * @throws IllegalArgumentException if untrained and no dataset was provided
     */
    public LinearClassifier<String, String> asClassifier(Maybe<KBPDataset<String, String>> maybe) {
        if (!this.trained) {
            if (!maybe.isDefined()) {
                throw new IllegalArgumentException("This classifier is not trained, and no dataset was provided");
            }
            train(maybe.get());
        }
        return this.classifier;
    }

    /**
     * Sets the label and feature indices used to seed the dataset built by later calls to
     * {@link #train}. Either may be {@code null} to leave that index at the dataset's default.
     *
     * @param index label index, or {@code null}
     * @param index2 feature index, or {@code null}
     * @return this extractor, for chaining
     */
    public SupervisedExtractor setIndices(Index<String> index, Index<String> index2) {
        this.labelIndexOrNull = index;
        this.featureIndexOrNull = index2;
        return this;
    }

    /**
     * Reads a classifier previously written by {@link #save}. The extractor counts as trained
     * only if a non-null model was read.
     *
     * <p>This uses Java native deserialization, so only load models from trusted sources.
     */
    @Override
    @SuppressWarnings("unchecked") // save() writes a LinearClassifier<String, String>
    public void load(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        this.classifier = (LinearClassifier<String, String>) objectInputStream.readObject();
        this.trained = this.classifier != null;
    }

    /** Writes the underlying classifier (possibly {@code null} if untrained) to the stream. */
    @Override
    public void save(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.writeObject(this.classifier);
    }
}
