package edu.stanford.nlp.kbp.slotfilling.classify;

import edu.stanford.nlp.kbp.common.Maybe;
import edu.stanford.nlp.kbp.common.Props;
import edu.stanford.nlp.kbp.common.Utils;
import edu.stanford.nlp.kbp.slotfilling.process.KBPProcess;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.IterableIterator;
import edu.stanford.nlp.util.MapFactory;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.Serializable;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.Set;

/**
 * Per-sentence statistics gathered from an ensemble of relation classifiers during training, used to
 * choose sentences for active learning.
 *
 * <p>Each sentence, identified by a {@link SentenceKey}, maps to an {@link EnsembleStatistics} that holds
 * one {@link SentenceStatistics} per classifier that scored it. The backing map is optional: instances
 * built with {@link #undefined()} hold no map, and {@link #relationPredictionsForKey(String)} rejects them.
 */
public class TrainingStatistics implements Serializable {
    private static final long serialVersionUID = 1;

    /**
     * Number of samples drawn between progress log lines (and re-summing of the remaining weight) in
     * {@link #selectWeightedKeysWithSampling}. The log reports progress in thousands ("k keys").
     */
    private static final int PROGRESS_LOG_INTERVAL = 1000;

    /** Sentence key to ensemble statistics, or {@code Maybe.Nothing()} if undefined. */
    private final Maybe<? extends Map<SentenceKey, EnsembleStatistics>> impl;

    /** Strategy used to score sentences for active learning (see {@link #uncertainty}). */
    public enum ActiveLearningSelectionCriterion {
        /** Score = average KL divergence of each ensemble member from the ensemble mean. */
        HIGH_KL_FROM_MEAN,
        /** Score = {@code 1 - mean confidence}; sentences without any confidence are not scored. */
        LOW_AVERAGE_CONFIDENCE,
        /** Score = 1.0 for every sentence. */
        RANDOM_UNIFORM
    }

    /** The statistics reported by every ensemble member for a single sentence. */
    public static class EnsembleStatistics implements Serializable {
        private static final long serialVersionUID = 1;

        /** One entry per classifier that scored the sentence; {@link #addInPlace} appends to it. */
        public final Collection<SentenceStatistics> statisticsForClassifiers;

        /**
         * Wraps {@code collection} directly (no copy). It must support {@code add}, since
         * {@link #addInPlace} mutates it.
         */
        public EnsembleStatistics(Collection<SentenceStatistics> collection) {
            this.statisticsForClassifiers = collection;
        }

        /** Shallow copy: a new list holding the same {@link SentenceStatistics} objects as {@code other}. */
        public EnsembleStatistics(EnsembleStatistics other) {
            this.statisticsForClassifiers = new LinkedList<>(other.statisticsForClassifiers);
        }

        /** Appends one classifier's statistics. */
        public void addInPlace(SentenceStatistics sentenceStatistics) {
            this.statisticsForClassifiers.add(sentenceStatistics);
        }

        /** Appends every member of {@code other}; safe even when {@code other == this}. */
        public void addInPlace(EnsembleStatistics other) {
            // Iterate a snapshot so that merging an ensemble into itself does not hit a
            // ConcurrentModificationException.
            for (SentenceStatistics stats : new ArrayList<>(other.statisticsForClassifiers)) {
                addInPlace(stats);
            }
        }

        /**
         * Averages the ensemble.
         *
         * <p>The relation distributions are averaged element-wise. The confidence is the mean of those
         * members that report one; it is absent if none do. Each member distribution is expected to sum
         * to 1, which is checked only when assertions are enabled.
         *
         * @return the mean distribution, with the mean confidence when any member has one
         * @throws IllegalStateException if the mean does not sum to 1 (within 1e-5). This includes the
         *     empty ensemble.
         */
        public SentenceStatistics mean() {
            double confidenceSum = 0.0;
            int confidenceCount = 0;
            ClassicCounter<String> meanDistribution = newInsertionOrderedCounter();
            for (SentenceStatistics stats : this.statisticsForClassifiers) {
                for (Double confidence : stats.confidence) {
                    confidenceSum += confidence;
                    confidenceCount++;
                }
                assert Math.abs(stats.relationDistribution.totalCount() - 1.0) < 1.0E-5
                    : "Member relation distribution does not sum to 1";
                for (Map.Entry<String, Double> entry : stats.relationDistribution.entrySet()) {
                    double probability = entry.getValue();
                    assert probability >= 0.0 : "Invalid probability " + probability + " for " + entry.getKey();
                    assert probability == stats.relationDistribution.getCount(entry.getKey());
                    meanDistribution.incrementCount(entry.getKey(), probability);
                }
            }
            int ensembleSize = this.statisticsForClassifiers.size();
            if (ensembleSize > 1) {
                Counters.divideInPlace(meanDistribution, ensembleSize);
            }
            if (Math.abs(meanDistribution.totalCount() - 1.0) > 1.0E-5) {
                throw new IllegalStateException("Mean relation distribution is not a distribution!");
            }
            // The mean of a single-member ensemble is that member's own distribution.
            assert ensembleSize != 1
                || Counters.equals(meanDistribution, this.statisticsForClassifiers.iterator().next().relationDistribution, 1.0E-5);
            return confidenceCount > 0
                ? new SentenceStatistics(meanDistribution, confidenceSum / confidenceCount)
                : new SentenceStatistics(meanDistribution);
        }

        /**
         * Measures ensemble disagreement.
         *
         * <p>The result is the average over members of {@code KL(member || mean)}. Values below 1e-10 are
         * reported as exactly 0.
         *
         * @return the average KL divergence of each member's relation distribution from {@link #mean()}
         * @throws IllegalStateException propagated from {@link #mean()} (for example, an empty ensemble)
         * @throws AssertionError if the average is infinite, NaN or negative
         */
        public double averageKLFromMean() {
            Counter<String> meanDistribution = mean().relationDistribution;
            double klSum = 0.0;
            for (SentenceStatistics stats : this.statisticsForClassifiers) {
                double kl = Counters.klDivergence(stats.relationDistribution, meanDistribution);
                // Tiny negative values are floating-point noise; clamp them to zero.
                if (kl < 0.0 && kl > -1.0E-12) {
                    kl = 0.0;
                }
                // NaN is deliberately let through here so the explicit check below reports it.
                assert !(kl < 0.0) : "Negative KL divergence: " + kl;
                klSum += kl;
            }
            double averageKL = klSum / this.statisticsForClassifiers.size();
            if (Double.isInfinite(averageKL) || Double.isNaN(averageKL) || averageKL < 0.0) {
                throw new AssertionError("Invalid average KL value: " + averageKL);
            }
            // A single classifier cannot disagree with itself.
            assert this.statisticsForClassifiers.size() > 1 || averageKL < 1.0E-5;
            return averageKL < 1.0E-10 ? 0.0 : averageKL;
        }
    }

    /** Identifies a sentence by its hash; equality and hashing delegate to {@link #sentenceHash}. */
    public static class SentenceKey implements Serializable {
        private static final long serialVersionUID = 1;
        public final String sentenceHash;

        public SentenceKey(String str) {
            this.sentenceHash = str;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            return obj instanceof SentenceKey && this.sentenceHash.equals(((SentenceKey) obj).sentenceHash);
        }

        @Override
        public int hashCode() {
            return this.sentenceHash.hashCode();
        }
    }

    /** One classifier's output for one sentence. */
    public static class SentenceStatistics implements Serializable {
        private static final long serialVersionUID = 1;
        /** The classifier's confidence, if it reported one. */
        public final Maybe<Double> confidence;
        /** Relation name to probability; expected to sum to 1 (see {@link EnsembleStatistics#mean()}). */
        public final Counter<String> relationDistribution;

        public SentenceStatistics(Counter<String> counter, double d) {
            this.confidence = Maybe.Just(d);
            this.relationDistribution = counter;
        }

        public SentenceStatistics(Counter<String> counter) {
            this.confidence = Maybe.Nothing();
            this.relationDistribution = counter;
        }
    }

    /** @param maybe the backing map, or {@code Maybe.Nothing()} for undefined statistics */
    public TrainingStatistics(Maybe<? extends Map<SentenceKey, EnsembleStatistics>> maybe) {
        this.impl = maybe;
    }

    /** Records one classifier's statistics for a sentence; does nothing if the statistics are undefined. */
    public void addInPlace(SentenceKey sentenceKey, SentenceStatistics sentenceStatistics) {
        for (Map<SentenceKey, EnsembleStatistics> statsByKey : this.impl) {
            statsByKey.computeIfAbsent(sentenceKey, key -> new EnsembleStatistics(new LinkedList<>()))
                .addInPlace(sentenceStatistics);
        }
    }

    /** @return a live view of the known sentence keys, or {@code Maybe.Nothing()} if undefined */
    public Maybe<Set<SentenceKey>> getSentenceKeys() {
        if (!this.impl.isDefined()) {
            return Maybe.Nothing();
        }
        return Maybe.Just(this.impl.get().keySet());
    }

    /**
     * Combines two sets of statistics into a new object; neither input is modified.
     *
     * <p>Ensembles for the same key are concatenated. The result is always defined, even if both inputs
     * are undefined.
     */
    public TrainingStatistics merge(TrainingStatistics trainingStatistics) {
        Map<SentenceKey, EnsembleStatistics> merged = new HashMap<>();
        for (Map<SentenceKey, EnsembleStatistics> statsByKey : this.impl) {
            for (Map.Entry<SentenceKey, EnsembleStatistics> entry : statsByKey.entrySet()) {
                merged.put(entry.getKey(), new EnsembleStatistics(entry.getValue()));
            }
        }
        for (Map<SentenceKey, EnsembleStatistics> statsByKey : trainingStatistics.impl) {
            for (Map.Entry<SentenceKey, EnsembleStatistics> entry : statsByKey.entrySet()) {
                merged.computeIfAbsent(entry.getKey(), key -> new EnsembleStatistics(new LinkedList<>()))
                    .addInPlace(entry.getValue());
            }
        }
        return new TrainingStatistics(Maybe.Just(merged));
    }

    /**
     * Lazily resolves sentence keys into (sentence, mean relation distribution) pairs. Keys whose sentence
     * cannot be recovered are skipped with a warning. The returned iterator does not support
     * {@code remove}.
     */
    private IterableIterator<Pair<CoreMap, Counter<String>>> iterableFromList(final KBPProcess kBPProcess, final List<String> list) {
        return new IterableIterator<>(new Iterator<Pair<CoreMap, Counter<String>>>() {
            private final Iterator<String> keyIter = list.iterator();
            /** The next element, computed ahead of time by {@link #hasNext()}. */
            private Maybe<Pair<CoreMap, Counter<String>>> nextGloss = Maybe.Nothing();

            @Override
            public boolean hasNext() {
                while (!this.nextGloss.isDefined() && this.keyIter.hasNext()) {
                    String key = this.keyIter.next();
                    // Look the sentence up once; the recovery may be expensive.
                    Maybe<CoreMap> sentence = kBPProcess.recoverSentenceGloss(key);
                    if (sentence.isDefined()) {
                        this.nextGloss = Maybe.Just(Pair.makePair(sentence.get(), TrainingStatistics.this.relationPredictionsForKey(key)));
                    } else {
                        Redwood.Util.warn("No sentence for gloss key: " + key);
                    }
                }
                return this.nextGloss.isDefined();
            }

            @Override
            public Pair<CoreMap, Counter<String>> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                Pair<CoreMap, Counter<String>> result = this.nextGloss.get();
                this.nextGloss = Maybe.Nothing();
                return result;
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException();
            }
        });
    }

    /** Creates an empty counter backed by a linked hash map, so it iterates in insertion order. */
    private static ClassicCounter<String> newInsertionOrderedCounter() {
        return new ClassicCounter<>(MapFactory.linkedHashMapFactory());
    }

    /** Scores every sentence by {@link EnsembleStatistics#averageKLFromMean()}. */
    private Counter<String> highKLFromMean() {
        ClassicCounter<String> scores = newInsertionOrderedCounter();
        for (Map<SentenceKey, EnsembleStatistics> statsByKey : this.impl) {
            for (Map.Entry<SentenceKey, EnsembleStatistics> entry : statsByKey.entrySet()) {
                scores.setCount(entry.getKey().sentenceHash, entry.getValue().averageKLFromMean());
            }
        }
        return scores;
    }

    /** Scores each sentence as {@code 1 - mean confidence}; sentences with no confidence are omitted. */
    private Counter<String> lowAverageConfidence() {
        ClassicCounter<String> scores = newInsertionOrderedCounter();
        for (Map<SentenceKey, EnsembleStatistics> statsByKey : this.impl) {
            for (Map.Entry<SentenceKey, EnsembleStatistics> entry : statsByKey.entrySet()) {
                for (Double confidence : entry.getValue().mean().confidence) {
                    scores.setCount(entry.getKey().sentenceHash, 1.0 - confidence);
                }
            }
        }
        return scores;
    }

    /** Gives every sentence the same score, 1.0. */
    private Counter<String> uniformRandom() {
        ClassicCounter<String> scores = newInsertionOrderedCounter();
        for (Map<SentenceKey, EnsembleStatistics> statsByKey : this.impl) {
            for (SentenceKey key : statsByKey.keySet()) {
                scores.setCount(key.sentenceHash, 1.0);
            }
        }
        return scores;
    }

    /** Resolves the keys from {@link #selectKeys} into sentences with their mean relation predictions. */
    public IterableIterator<Pair<CoreMap, Counter<String>>> selectExamples(KBPProcess kBPProcess, ActiveLearningSelectionCriterion activeLearningSelectionCriterion) {
        return iterableFromList(kBPProcess, selectKeys(activeLearningSelectionCriterion));
    }

    /** @return every sentence key, ordered by {@link Counters#toSortedList} on its uncertainty score */
    public List<String> selectKeys(ActiveLearningSelectionCriterion activeLearningSelectionCriterion) {
        return Counters.toSortedList(uncertainty(activeLearningSelectionCriterion));
    }

    /** @return every (sentence key, score) pair, ordered by {@link Counters#toSortedListWithCounts} */
    public List<Pair<String, Double>> selectWeightedKeys(ActiveLearningSelectionCriterion activeLearningSelectionCriterion) {
        return Counters.toSortedListWithCounts(uncertainty(activeLearningSelectionCriterion));
    }

    /**
     * Scores every sentence hash under the given criterion (see {@link ActiveLearningSelectionCriterion}).
     *
     * @throws IllegalArgumentException for an unknown criterion
     */
    public Counter<String> uncertainty(ActiveLearningSelectionCriterion activeLearningSelectionCriterion) {
        switch (activeLearningSelectionCriterion) {
            case HIGH_KL_FROM_MEAN:
                return highKLFromMean();
            case LOW_AVERAGE_CONFIDENCE:
                return lowAverageConfidence();
            case RANDOM_UNIFORM:
                return uniformRandom();
            default:
                throw new IllegalArgumentException("Unknown selection criterion: " + activeLearningSelectionCriterion);
        }
    }

    /** Read-only view of the keys chosen by {@link #selectWeightedKeysWithSampling}, in draw order. */
    public List<String> selectKeysWithSampling(ActiveLearningSelectionCriterion criterion, int numSamples, int seed) {
        final List<Pair<String, Double>> weightedKeys = selectWeightedKeysWithSampling(criterion, numSamples, seed);
        return new AbstractList<String>() {
            @Override
            public String get(int index) {
                return weightedKeys.get(index).first;
            }

            @Override
            public int size() {
                return weightedKeys.size();
            }
        };
    }

    /**
     * Draws up to {@code numSamples} sentence keys without replacement. Each key's probability of being
     * drawn is proportional to its uncertainty score.
     *
     * <p>Keys with a zero score are held back whenever there are more keys than requested samples and
     * not all scores are zero. Held-back keys are handed out, one per sample, only after every
     * positively-scored key has been drawn. Sampling stops early once no keys remain.
     *
     * @param criterion how to score sentences
     * @param numSamples maximum number of keys to return
     * @param seed seed for the {@link Random} that drives sampling; equal seeds give equal results
     * @return (key, score) pairs in draw order; at most {@code numSamples} of them
     * @throws IllegalArgumentException when assertions are enabled and a score is negative, infinite or NaN
     */
    public List<Pair<String, Double>> selectWeightedKeysWithSampling(ActiveLearningSelectionCriterion criterion, int numSamples, int seed) {
        List<Pair<String, Double>> result = new ArrayList<>();
        Redwood.Util.forceTrack("Sampling Keys");
        Redwood.Util.log("" + numSamples + " to collect");

        Redwood.Util.forceTrack("Computing Uncertainties");
        Counter<String> uncertainty = uncertainty(criterion);
        // Seeded sampling is only reproducible if the scores themselves are deterministic.
        assert uncertainty.equals(uncertainty(criterion));
        Redwood.Util.endTrack("Computing Uncertainties");

        // Flatten the scores into parallel key/weight lists in a fixed order (count, then key).
        double totalWeight = uncertainty.totalCount();
        boolean keepZeroWeights = totalWeight == 0.0 || uncertainty.size() <= numSamples;
        List<String> keys = new LinkedList<>();
        List<Double> weights = new LinkedList<>();
        List<String> zeroWeightKeys = new LinkedList<>();
        for (Pair<String, Double> entry : Counters.toSortedListWithCounts(uncertainty, (a, b) -> {
            int byPair = a.compareTo(b);
            return byPair != 0 ? byPair : a.first.compareTo(b.first);
        })) {
            if (entry.second != 0.0 || keepZeroWeights) {
                keys.add(entry.first);
                weights.add(entry.second);
            } else {
                zeroWeightKeys.add(entry.first);
            }
        }

        if (Utils.assertionsEnabled()) {
            for (Double weight : weights) {
                if (weight < 0.0 || Double.isInfinite(weight) || Double.isNaN(weight)) {
                    throw new IllegalArgumentException("Invalid weight: " + weight);
                }
            }
        }

        Random random = new Random(seed);
        boolean warnedAboutFallback = false;
        for (int sample = 1; sample <= numSamples; sample++) {
            if (sample % PROGRESS_LOG_INTERVAL == 0) {
                Redwood.Util.log("sampled " + (sample / PROGRESS_LOG_INTERVAL) + "k keys");
                // Re-sum periodically: the repeated subtraction below accumulates floating-point error.
                totalWeight = 0.0;
                for (double weight : weights) {
                    totalWeight += weight;
                }
            }
            if (!keys.isEmpty()) {
                assert totalWeight >= 0.0;
                assert weights.size() == keys.size();
                Pair<String, Double> drawn = drawWeighted(keys, weights, random.nextDouble() * totalWeight);
                result.add(drawn);
                totalWeight -= drawn.second;
            } else if (!zeroWeightKeys.isEmpty()) {
                if (!warnedAboutFallback) {
                    Redwood.Util.warn("No more uncertain samples left to draw from! Falling back to "
                        + zeroWeightKeys.size() + " zero-uncertainty keys");
                    warnedAboutFallback = true;
                }
                result.add(Pair.makePair(zeroWeightKeys.remove(0), 0.0));
            } else {
                break;  // every key has already been drawn
            }
        }
        Redwood.Util.endTrack("Sampling Keys");
        return result;
    }

    /**
     * Removes and returns the first key whose cumulative weight reaches {@code target}.
     *
     * <p>If {@code target} overshoots the remaining mass (floating-point drift in the caller's running
     * total), the last key is taken instead, so every call makes progress.
     *
     * @param keys non-empty; mutated (the drawn key is removed)
     * @param weights parallel to {@code keys}; mutated in step with it
     */
    private static Pair<String, Double> drawWeighted(List<String> keys, List<Double> weights, double target) {
        Iterator<String> keyIter = keys.iterator();
        Iterator<Double> weightIter = weights.iterator();
        double runningTotal = 0.0;
        while (true) {
            String key = keyIter.next();
            double weight = weightIter.next();
            runningTotal += weight;
            if (target <= runningTotal || !keyIter.hasNext()) {
                keyIter.remove();
                weightIter.remove();
                return Pair.makePair(key, weight);
            }
        }
    }

    /** Like {@link #selectExamples}, but the keys come from {@link #selectKeysWithSampling}. */
    public IterableIterator<Pair<CoreMap, Counter<String>>> selectExamplesWithSampling(KBPProcess kBPProcess, ActiveLearningSelectionCriterion activeLearningSelectionCriterion, int i, int i2) {
        return iterableFromList(kBPProcess, selectKeysWithSampling(activeLearningSelectionCriterion, i, i2));
    }

    /**
     * @return the ensemble-mean relation distribution for the sentence with hash {@code str}
     * @throws IllegalArgumentException if the statistics are undefined or have no entry for {@code str}
     */
    public Counter<String> relationPredictionsForKey(String str) {
        if (!this.impl.isDefined()) {
            throw new IllegalArgumentException("Training statistics is not defined");
        }
        EnsembleStatistics ensemble = this.impl.get().get(new SentenceKey(str));
        if (ensemble == null) {
            throw new IllegalArgumentException("No training statistics for sentence key: " + str);
        }
        return ensemble.mean().relationDistribution;
    }

    /**
     * Normalizes every member's relation distribution in place.
     *
     * <p>With assertions enabled, this also checks that no distribution is uniform before
     * normalization. It also checks that each ensemble mean sums to 1 and is not uniform.
     */
    public void validate() {
        for (Map<SentenceKey, EnsembleStatistics> statsByKey : this.impl) {
            for (EnsembleStatistics ensemble : statsByKey.values()) {
                for (SentenceStatistics stats : ensemble.statisticsForClassifiers) {
                    assert !Counters.isUniformDistribution(stats.relationDistribution, 1.0E-5);
                    Counters.normalize(stats.relationDistribution);
                    assert Math.abs(stats.relationDistribution.totalCount() - 1.0) < 1.0E-5;
                }
                assert Math.abs(ensemble.mean().relationDistribution.totalCount() - 1.0) < 1.0E-5;
                assert !Counters.isUniformDistribution(ensemble.mean().relationDistribution, 1.0E-5);
            }
        }
    }

    /** @return defined statistics backed by a new, empty {@link HashMap} */
    public static TrainingStatistics empty() {
        return new TrainingStatistics(Maybe.Just(new HashMap<SentenceKey, EnsembleStatistics>()));
    }

    /** @return statistics with no backing map */
    public static TrainingStatistics undefined() {
        return new TrainingStatistics(Maybe.Nothing());
    }
}
