package edu.stanford.nlp.kbp.slotfilling.classify;

import edu.stanford.nlp.classify.Dataset;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.classify.ProbabilisticClassifier;
import edu.stanford.nlp.classify.ShiftParamsLogisticClassifierFactory;
import edu.stanford.nlp.kbp.common.Maybe;
import edu.stanford.nlp.kbp.common.Props;
import edu.stanford.nlp.kbp.common.RelationType;
import edu.stanford.nlp.kbp.common.SentenceGroup;
import edu.stanford.nlp.kbp.slotfilling.ir.KBPRelationProvenance;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/classify/OneVsAllRelationExtractor.class */
/**
 * Relation extractor that trains one probabilistic classifier per relation label.
 *
 * <p>Each per-relation classifier is trained on data labelled either with that relation or with
 * {@code "_NR"} (no relation). At prediction time every classifier votes with its best
 * non-{@code "_NR"} label. Confident votes (probability above 0.5) are re-scored with
 * {@link RelationClassifier}'s {@code softmax}, and the highest-scoring one wins.
 */
public class OneVsAllRelationExtractor extends RelationClassifier {
    /** Label used for "no relation", both in training data and in predictions. */
    private static final String NO_RELATION = "_NR";

    /** Relation name -> classifier trained for that relation; null until trained or loaded. */
    private Map<String, ProbabilisticClassifier<String, String>> classifiers;
    /** Regularisation sigma passed to {@link LinearClassifierFactory}; persisted by {@link #save}. */
    private double sigma;
    /** Parameter forwarded to {@code softmax} when re-scoring votes; NOT persisted by {@link #save}. */
    public final double gamma;
    /** If true, train with {@link ShiftParamsLogisticClassifierFactory} instead of plain LR. */
    public final boolean useRobustLR;

    /**
     * Creates a non-robust extractor with sigma = gamma = 1.0.
     *
     * @param properties currently ignored
     */
    public OneVsAllRelationExtractor(Properties properties) {
        this(false);
    }

    /**
     * Creates an extractor with sigma = gamma = 1.0.
     *
     * @param properties currently ignored
     * @param useRobustLR whether to use the shift-params (robust) logistic regression trainer
     */
    public OneVsAllRelationExtractor(Properties properties, boolean useRobustLR) {
        this(useRobustLR);
    }

    /**
     * Creates an extractor with sigma = gamma = 1.0.
     *
     * @param useRobustLR whether to use the shift-params (robust) logistic regression trainer
     */
    public OneVsAllRelationExtractor(boolean useRobustLR) {
        this(useRobustLR, 1.0d, 1.0d);
    }

    /**
     * @param useRobustLR whether to use the shift-params (robust) logistic regression trainer
     * @param sigma regularisation sigma for the non-robust {@link LinearClassifierFactory}
     * @param gamma parameter forwarded to {@code softmax} when combining per-relation votes
     */
    public OneVsAllRelationExtractor(boolean useRobustLR, double sigma, double gamma) {
        this.classifiers = null;
        this.useRobustLR = useRobustLR;
        this.sigma = sigma;
        this.gamma = gamma;
    }

    /**
     * Classifies every mention in {@code sentenceGroup} and attaches provenance through
     * {@link RelationClassifier#firstProvenance}. The {@code maybe} (sentences) argument is unused.
     */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public Counter<Pair<String, Maybe<KBPRelationProvenance>>> classifyRelations(SentenceGroup sentenceGroup, Maybe<CoreMap[]> maybe) {
        return RelationClassifier.firstProvenance(classifyMentions(RelationClassifier.tupleToFeatureList(sentenceGroup)), sentenceGroup);
    }

    /**
     * Aggregates per-mention predictions into a relation distribution.
     *
     * @param list one feature collection per mention
     * @return a normalised counter of relation label to summed score. Mentions predicted as
     *     {@code "_NR"} contribute nothing, so the counter may be empty.
     */
    public Counter<String> classifyMentions(List<Collection<String>> list) {
        assert this.classifiers != null;
        Counter<String> relationScores = new ClassicCounter<>();
        for (Collection<String> mentionFeatures : list) {
            Pair<String, Double> prediction = annotateDatum(new BasicDatum<>(mentionFeatures));
            if (!NO_RELATION.equals(prediction.first)) {
                relationScores.incrementCount(prediction.first, prediction.second);
            }
        }
        Counters.normalize(relationScores);
        return relationScores;
    }

    /**
     * Predicts a single (label, score) pair for one datum by combining all per-relation classifiers.
     *
     * <p>Each classifier contributes its best non-{@code "_NR"} label. Votes with probability above 0.5
     * are candidates. Each candidate's score is replaced by {@code softmax(score, allVotes, gamma)},
     * and the first candidate with the highest new score is returned. If there are no candidates,
     * {@code ("_NR", 1.0)} is returned.
     */
    private Pair<String, Double> annotateDatum(Datum<String, String> datum) {
        List<Pair<String, Double>> confidentVotes = new ArrayList<>();
        List<Double> allVoteScores = new ArrayList<>();
        for (ProbabilisticClassifier<String, String> classifier : this.classifiers.values()) {
            Pair<String, Double> vote = classOf(datum, classifier);
            if (vote == null) {
                continue;
            }
            if (vote.second > 0.5d) {
                confidentVotes.add(vote);
            }
            allVoteScores.add(vote.second);
        }
        // Linear max scan. The strict '>' keeps the first of equal maxima, which matches the
        // stable sort used previously. The vote pairs are fresh objects from classOf, so
        // overwriting .second is safe.
        Pair<String, Double> winner = null;
        for (Pair<String, Double> vote : confidentVotes) {
            vote.second = softmax(vote.second, allVoteScores, this.gamma);
            if (winner == null || vote.second > winner.second) {
                winner = vote;
            }
        }
        return winner != null ? winner : new Pair<>(NO_RELATION, 1.0d);
    }

    /**
     * Returns the most probable non-{@code "_NR"} (label, probability) pair from one classifier,
     * or null if it only knows {@code "_NR"}.
     */
    private Pair<String, Double> classOf(Datum<String, String> datum, ProbabilisticClassifier<String, String> probabilisticClassifier) {
        for (Pair<String, Double> labelAndProb : Counters.toDescendingMagnitudeSortedListWithCounts(probabilisticClassifier.probabilityOf(datum))) {
            if (!NO_RELATION.equals(labelAndProb.first)) {
                return labelAndProb;
            }
        }
        return null;
    }

    /** Writes the classifier map followed by sigma. {@code gamma} and {@code useRobustLR} are not written. */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public void save(ObjectOutputStream objectOutputStream) throws IOException {
        assert this.classifiers != null;
        objectOutputStream.writeObject(this.classifiers);
        objectOutputStream.writeDouble(this.sigma);
    }

    /**
     * Reads the state written by {@link #save(ObjectOutputStream)} and then closes the stream.
     *
     * <p>NOTE(review): closing a caller-supplied stream is unusual. Presumably
     * {@code RelationClassifier.load(String, ...)} relies on it; confirm before removing the close.
     */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public void load(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        Map<String, ProbabilisticClassifier<String, String>> loaded = ErasureUtils.uncheckedCast(objectInputStream.readObject());
        double loadedSigma = objectInputStream.readDouble();
        this.classifiers = loaded;
        this.sigma = loadedSigma;
        objectInputStream.close();
    }

    /** Loads a serialized extractor from {@code str} using default (empty) properties. */
    public static OneVsAllRelationExtractor load(String str) throws IOException, ClassNotFoundException {
        return (OneVsAllRelationExtractor) RelationClassifier.load(str, new Properties(), OneVsAllRelationExtractor.class);
    }

    /**
     * Trains one classifier per entry of {@code map}, replacing any previously trained classifiers.
     *
     * @param map relation label -> binary (relation vs. {@code "_NR"}) training set
     * @return {@link TrainingStatistics#undefined()}
     */
    public TrainingStatistics train(Map<String, GeneralDataset<String, String>> map) {
        Set<String> relations = map.keySet();
        Redwood.Util.startTrack(new Object[]{Redwood.Util.BOLD, Redwood.Util.BLUE, "Training " + relations.size() + " models"});
        this.classifiers = new HashMap<>();
        for (String relation : relations) {
            GeneralDataset<String, String> relationData = map.get(relation);
            Redwood.Util.startTrack(new Object[]{"Training classifier for label: " + relation});
            Redwood.log(new Object[]{"Train set size: " + relationData.size()});
            this.classifiers.put(relation, trainOne(relationData));
            Redwood.Util.endTrack("Training classifier for label: " + relation);
        }
        Redwood.Util.endTrack("Training " + relations.size() + " models");
        this.statistics = Maybe.Just(TrainingStatistics.undefined());
        return TrainingStatistics.undefined();
    }

    /**
     * Splits a multi-label KBP dataset into one binary dataset per relation and trains on them.
     *
     * <p>For relation R, every datum in group i is labelled R if R is among the group's positive
     * labels. Otherwise it is labelled {@code "_NR"} if {@code Props.TRAIN_LR_ALLNEGATIVES} is set or
     * R is a listed negative. Otherwise the group is skipped for R. Relations with no resulting
     * data get no classifier.
     */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public TrainingStatistics train(KBPDataset<String, String> kBPDataset) {
        Map<String, GeneralDataset<String, String>> perRelationData = new HashMap<>();
        for (String relation : kBPDataset.labelIndex) {
            assert RelationType.fromString(relation).orCrash().canonicalName.equals(relation);
            Dataset<String, String> relationData = new Dataset<>();
            for (int i = 0; i < kBPDataset.size(); i++) {
                String label;
                if (kBPDataset.getPositiveLabels(i).contains(relation)) {
                    label = relation;
                } else if (Props.TRAIN_LR_ALLNEGATIVES || kBPDataset.getNegativeLabels(i).contains(relation)) {
                    label = NO_RELATION;
                } else {
                    // Unlabelled w.r.t. this relation: skip rather than reuse a stale label.
                    continue;
                }
                for (Datum<String, String> datum : kBPDataset.getDatumGroup(i)) {
                    relationData.add(new BasicDatum<>(datum.asFeatures(), label));
                }
            }
            if (relationData.size() > 0) {
                perRelationData.put(relation, relationData);
            }
        }
        return train(perRelationData);
    }

    /**
     * Trains a single binary classifier.
     *
     * <p>Robust mode uses {@link ShiftParamsLogisticClassifierFactory} with a quadratic
     * {@link LogPrior}(1.0, 0.1) and 0.01 as its second argument; {@code sigma} is not used there.
     * Otherwise {@link LinearClassifierFactory}(1e-4, false, sigma) is used, with verbose output off.
     */
    private ProbabilisticClassifier<String, String> trainOne(GeneralDataset<String, String> generalDataset) {
        if (this.useRobustLR) {
            // NOTE(review): m0trainClassifier looks like a decompiler-renamed trainClassifier; confirm
            // against ShiftParamsLogisticClassifierFactory before renaming.
            return new ShiftParamsLogisticClassifierFactory(new LogPrior(LogPrior.LogPriorType.QUADRATIC, 1.0d, 0.1d), 0.01d).m0trainClassifier((GeneralDataset) generalDataset);
        }
        LinearClassifierFactory<String, String> linearFactory = new LinearClassifierFactory<>(1.0E-4d, false, this.sigma);
        linearFactory.setVerbose(false);
        return linearFactory.trainClassifier(generalDataset);
    }
}
