package edu.stanford.nlp.kbp.slotfilling.classify;

import edu.stanford.nlp.kbp.common.Maybe;
import edu.stanford.nlp.kbp.common.Props;
import edu.stanford.nlp.kbp.common.SentenceGroup;
import edu.stanford.nlp.kbp.slotfilling.ir.KBPRelationProvenance;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/classify/PerceptronExtractor.class */
/**
 * Averaged-perceptron relation extractor.
 *
 * <p>The model keeps one {@link LabelWeights} vector per relation label, plus one for the NIL label
 * {@code "_NR"}. Training data comes in groups: each group is an {@code int[][]} whose rows are the
 * mentions of the group, each mention being an array of feature ids. Depending on {@link #modelType},
 * training either updates every mention independently ({@code PERCEPTRON}, {@code PERCEPTRON_INC}) or
 * jointly under an "at least once" assumption ({@code AT_LEAST_ONCE_INC}).
 *
 * <p>At classification time each mention is scored with the averaged weights, the scores are turned into
 * probabilities with {@code softmax}, and the per-mention winning labels are combined with a noisy-or.
 */
public class PerceptronExtractor extends RelationClassifier {
    private static final long serialVersionUID = 1;

    /** Label name used for "no relation"; added to the label index at the start of training. */
    private static final String NIL_LABEL = "_NR";

    /** Key under which the update-statistics counters keep their total over all labels. */
    private static final int LABEL_ALL = -1;

    /**
     * When true, the joint trainer also applies a soft update (predicted probability minus one) for
     * labels that are neither known positives nor known negatives. Never enabled in this file.
     */
    private static final boolean SOFT_UNKNOWN = false;

    /** One weight vector per label, indexed by label id in {@link #labelIndex}. */
    LabelWeights[] zWeights;
    /** Label names, including {@link #NIL_LABEL}. */
    Index<String> labelIndex;
    /** Feature names; positions in this index are the feature ids stored in the data arrays. */
    Index<String> zFeatureIndex;
    /** Id of {@link #NIL_LABEL} in {@link #labelIndex}. */
    int nilIndex;
    final ModelType modelType;
    final int epochs;
    final boolean softmaxEnabled;
    /** Parameter forwarded to {@code softmax} when converting scores to probabilities. */
    final double gamma;
    final boolean verbose;
    private Counter<Integer> posUpdateStats;
    private Counter<Integer> negUpdateStats;
    private Counter<Integer> unknownUpdateStats;

    /**
     * Weight vector for a single label, plus the running sum needed by the averaged perceptron.
     *
     * <p>Averaging is lazy: rather than adding {@link #weights} to {@link #avgWeights} after every
     * iteration, the object counts how many iterations the current vector survived unchanged and adds
     * {@code weights * survivalIterations} just before the vector changes (and once more when training
     * finishes, see {@link #finishAveraging()}).
     *
     * <p>The field names and types form the Java-serialized representation; do not change them.
     */
    public static class LabelWeights implements Serializable {
        private static final long serialVersionUID = 1;
        /** Current (non-averaged) weights; set to {@code null} by {@link #clear()}. */
        double[] weights;
        /** Number of iterations the current {@link #weights} survived without being updated. */
        int survivalIterations;
        /** Accumulated weights; divided by the iteration count in {@link #normalize(double)}. */
        double[] avgWeights;

        /** Creates all-zero weight and average vectors of the given length. */
        LabelWeights(int numFeatures) {
            this.weights = new double[numFeatures];
            this.survivalIterations = 0;
            this.avgWeights = new double[numFeatures];
        }

        /** Drops the raw weights; only {@link #avgWeights} are used for classification. */
        void clear() {
            this.weights = null;
        }

        /** Records that the current weights survived one more iteration. */
        void updateSurvivalIterations() {
            this.survivalIterations++;
        }

        /** Adds {@code weights * survivalIterations} to the running sum. */
        private void addToAverage() {
            double survived = this.survivalIterations;
            for (int f = 0; f < this.weights.length; f++) {
                this.avgWeights[f] += this.weights[f] * survived;
            }
        }

        /**
         * Flushes the contribution of the current weights into the running sum. Must be called once at the
         * end of training, before {@link #normalize(double)}; otherwise the final weight vector (the one
         * that survived since the last update) is never included in the average.
         */
        void finishAveraging() {
            addToAverage();
            this.survivalIterations = 0;
        }

        /**
         * Adds {@code delta} to the weight of every feature id in {@code features} (repeated ids are
         * updated repeatedly), after first folding the outgoing vector into the average.
         *
         * @throws UnsupportedOperationException if a feature id is beyond the vector length
         */
        void update(int[] features, double delta) {
            addToAverage();
            for (int f : features) {
                // A feature id equal to the length is already out of range, hence >=.
                if (f >= this.weights.length) {
                    expand();
                }
                this.weights[f] += delta;
            }
            this.survivalIterations = 0;
        }

        /** Growing the vector is not implemented. */
        private void expand() {
            throw new UnsupportedOperationException("ERROR: LabelWeights.expand() not supported yet!");
        }

        /** Dot product of a feature-count vector with the raw (non-averaged) weights. */
        double dotProduct(Counter<Integer> counter) {
            return dotProduct(counter, this.weights);
        }

        /** Divides the accumulated weights by {@code iterations}; does nothing unless it is positive. */
        void normalize(double iterations) {
            if (iterations > 0.0d) {
                for (int f = 0; f < this.avgWeights.length; f++) {
                    this.avgWeights[f] /= iterations;
                }
            }
        }

        /**
         * Dot product of the averaged weights with the counts of the given feature names. Names absent from
         * {@code featureIndex} (negative {@code indexOf}) are ignored.
         */
        double avgDotProduct(Collection<String> features, Index<String> featureIndex) {
            Counter<Integer> featureCounts = new ClassicCounter<>();
            for (String feature : features) {
                int id = featureIndex.indexOf(feature);
                if (id >= 0) {
                    featureCounts.incrementCount(id);
                }
            }
            return dotProduct(featureCounts, this.avgWeights);
        }

        /**
         * Sparse dot product: sum of {@code value * vector[key]} over the counter's entries.
         *
         * @throws RuntimeException on a null key/value, a null vector, or a key outside the vector
         */
        static double dotProduct(Counter<Integer> counter, double[] vector) {
            double sum = 0.0d;
            for (Map.Entry<Integer, Double> entry : counter.entrySet()) {
                Integer key = entry.getKey();
                Double value = entry.getValue();
                if (key == null) {
                    throw new RuntimeException("NULL key in " + key + "/" + value);
                }
                if (value == null) {
                    throw new RuntimeException("NULL value in " + key + "/" + value);
                }
                if (vector == null) {
                    throw new RuntimeException("NULL weights!");
                }
                if (key < 0 || key >= vector.length) {
                    throw new RuntimeException("Invalid key " + key + ". Should be >= 0 and < " + vector.length);
                }
                sum += value * vector[key];
            }
            return sum;
        }
    }

    /**
     * Reads the training configuration from {@code Props} and logs it. {@code properties} is only
     * consulted for the logged normalization type.
     */
    public PerceptronExtractor(Properties properties) throws IOException {
        Redwood.Util.log(new Object[]{"PerceptronExtractor configured with the following properties:"});
        this.epochs = Props.PERCEPTRON_EPOCHS;
        Redwood.Util.log(new Object[]{"epochs = " + this.epochs});
        this.softmaxEnabled = Props.PERCEPTRON_SOFTMAX;
        Redwood.Util.log(new Object[]{"softmaxEnabled = " + this.softmaxEnabled});
        Redwood.Util.log(new Object[]{"normType = " + properties.getProperty(Props.PERCEPTRON_NORMALIZE, "L2J")});
        this.modelType = Props.TRAIN_MODEL;
        Redwood.Util.log(new Object[]{"modelType = " + this.modelType});
        this.gamma = 1.0d;
        Redwood.Util.log(new Object[]{"gamma = " + this.gamma});
        this.verbose = false;
    }

    /**
     * Writes the label count, every label's weights, the label index and the feature index.
     *
     * <p>The raw (non-averaged) weights of every label are discarded ({@link LabelWeights#clear()}) before
     * writing, so after saving only the averaged weights remain in memory.
     */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public void save(ObjectOutputStream out) throws IOException {
        assert this.zWeights != null;
        for (LabelWeights labelWeights : this.zWeights) {
            labelWeights.clear();
        }
        out.writeInt(this.zWeights.length);
        for (LabelWeights labelWeights : this.zWeights) {
            out.writeObject(labelWeights);
        }
        out.writeObject(this.labelIndex);
        out.writeObject(this.zFeatureIndex);
    }

    /** Reads a model in the format written by {@link #save(ObjectOutputStream)} and recomputes {@link #nilIndex}. */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public void load(ObjectInputStream in) throws IOException, ClassNotFoundException {
        this.zWeights = new LabelWeights[in.readInt()];
        for (int label = 0; label < this.zWeights.length; label++) {
            this.zWeights[label] = (LabelWeights) in.readObject();
        }
        this.labelIndex = ErasureUtils.uncheckedCast(in.readObject());
        this.nilIndex = this.labelIndex.indexOf(NIL_LABEL);
        this.zFeatureIndex = ErasureUtils.uncheckedCast(in.readObject());
    }

    /**
     * Trains the model for {@link #epochs} passes over {@code dataset}, then averages the weights.
     *
     * <p>Note: the dataset's own label index is reused and {@link #NIL_LABEL} is added to it.
     *
     * @return {@code TrainingStatistics.undefined()}
     * @throws RuntimeException if {@link #modelType} is not one of the supported perceptron variants
     */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public TrainingStatistics train(KBPDataset<String, String> dataset) {
        Redwood.Util.log(new Object[]{"Training the \"at least once\" model using " + dataset.featureIndex().size() + " features and the following labels: " + dataset.labelIndex.toString()});
        this.labelIndex = dataset.labelIndex;
        this.labelIndex.add(NIL_LABEL);
        this.nilIndex = this.labelIndex.indexOf(NIL_LABEL);
        this.zFeatureIndex = dataset.featureIndex();
        int numFeatures = dataset.featureIndex().size();
        this.zWeights = new LabelWeights[this.labelIndex.size()];
        for (int label = 0; label < this.zWeights.length; label++) {
            this.zWeights[label] = new LabelWeights(numFeatures);
        }

        int totalIterations = 0;
        for (int epoch = 0; epoch < this.epochs; epoch++) {
            Redwood.Util.log(new Object[]{"Started epoch #" + epoch + "..."});
            dataset.randomize(epoch);
            this.posUpdateStats = new ClassicCounter<>();
            this.negUpdateStats = new ClassicCounter<>();
            this.unknownUpdateStats = new ClassicCounter<>();
            // Fetched after randomize() so the arrays reflect this epoch's ordering.
            int[][][] data = dataset.getDataArray();
            Set<Integer>[] positives = dataset.getPositiveLabelsArray();
            Set<Integer>[] negatives = dataset.getNegativeLabelsArray();
            for (int g = 0; g < dataset.size(); g++) {
                int[][] group = data[g];
                Set<Integer> yPos = positives[g];
                Set<Integer> yNeg = negatives[g];
                totalIterations++;
                if (this.verbose) {
                    inputStats(g, yPos, yNeg, group);
                }
                if (this.modelType == ModelType.AT_LEAST_ONCE_INC) {
                    trainJointlyOneGroupIncomplete(group, yPos, yNeg);
                } else if (this.modelType == ModelType.PERCEPTRON) {
                    trainLocallyOneGroup(group, yPos, yNeg, false);
                } else if (this.modelType == ModelType.PERCEPTRON_INC) {
                    trainLocallyOneGroup(group, yPos, yNeg, true);
                } else {
                    throw new RuntimeException("Unsupported model type: " + this.modelType);
                }
                if (this.verbose) {
                    System.err.println("Group #" + g + " completed.");
                    System.err.println("=============================================================");
                }
                // Every label's current vector survived this group.
                for (LabelWeights labelWeights : this.zWeights) {
                    labelWeights.updateSurvivalIterations();
                }
            }
            Redwood.Util.log(new Object[]{"Epoch #" + epoch + " completed. Inspected " + dataset.size() + " datum groups. Performed " + this.posUpdateStats.getCount(LABEL_ALL) + " ++ updates and " + this.negUpdateStats.getCount(LABEL_ALL) + " -- updates and " + this.unknownUpdateStats.getCount(LABEL_ALL) + " unknown updates."});
            computeTrainingPerformance(dataset);
        }
        Redwood.Util.log(new Object[]{"Run model through " + totalIterations + " iterations."});
        for (LabelWeights labelWeights : this.zWeights) {
            // Include the last surviving vector before dividing by the number of iterations.
            labelWeights.finishAveraging();
            labelWeights.normalize(totalIterations);
        }
        printAvgVectors();
        this.statistics = Maybe.Just(TrainingStatistics.undefined());
        return TrainingStatistics.undefined();
    }

    /**
     * Mention-level perceptron update for one group: every mention is treated as if it expressed all
     * of the group's positive labels (or NIL when there are none).
     *
     * @param incompleteNegatives when true, a wrong prediction is penalized only if it is a known negative
     */
    private void trainLocallyOneGroup(int[][] group, Set<Integer> yPos, Set<Integer> yNeg, boolean incompleteNegatives) {
        List<Counter<Integer>> zScores = estimateZ(group);
        for (int m = 0; m < group.length; m++) {
            int[] datum = group[m];
            List<Pair<Integer, Double>> ranked = Counters.toDescendingMagnitudeSortedListWithCounts(zScores.get(m));
            int predicted = ranked.isEmpty() ? this.nilIndex : ranked.get(0).first();

            // Promote every positive label the model failed to predict.
            for (Integer positive : yPos) {
                int label = positive;
                if (label != this.nilIndex && label != predicted) {
                    this.zWeights[label].update(datum, 1.0d);
                    if (this.verbose) {
                        System.err.println("Update +++ on label " + label);
                    }
                    this.posUpdateStats.incrementCount(label);
                    this.posUpdateStats.incrementCount(LABEL_ALL);
                }
            }
            // No positive labels: the mention should have been NIL.
            if (yPos.isEmpty() && predicted != this.nilIndex) {
                this.zWeights[this.nilIndex].update(datum, 1.0d);
                if (this.verbose) {
                    System.err.println("Update +++ on label NIL");
                }
                this.posUpdateStats.incrementCount(this.nilIndex);
            }
            // Demote a wrong non-NIL prediction.
            if (predicted != this.nilIndex && !yPos.contains(predicted) && (!incompleteNegatives || yNeg.contains(predicted))) {
                this.zWeights[predicted].update(datum, -1.0d);
                if (this.verbose) {
                    System.err.println("Update --- on label " + predicted);
                }
                this.negUpdateStats.incrementCount(predicted);
                this.negUpdateStats.incrementCount(LABEL_ALL);
            }
            // Predicted NIL although the group has positive labels.
            if (predicted == this.nilIndex && !yPos.isEmpty()) {
                this.zWeights[this.nilIndex].update(datum, -1.0d);
                if (this.verbose) {
                    System.err.println("Update --- on label NIL");
                }
                this.negUpdateStats.incrementCount(predicted);
            }
        }
    }

    /** Returns a copy of {@code zPredicted} in which known negative and positive labels are replaced by NIL. */
    private int[] generateZUnknown(int[] zPredicted, Set<Integer> yPos, Set<Integer> yNeg) {
        int[] zUnknown = removeGold(removeGold(Arrays.copyOf(zPredicted, zPredicted.length), yNeg), yPos);
        if (this.verbose) {
            System.err.println("zUnknown after removing y+/-: " + arrayToString(zUnknown));
        }
        return zUnknown;
    }

    /**
     * Joint ("at least once") update for one group with incomplete negative information. The model is
     * updated only when a positive label was not predicted or a negative label was predicted.
     */
    private void trainJointlyOneGroupIncomplete(int[][] group, Set<Integer> yPos, Set<Integer> yNeg) {
        List<Counter<Integer>> zScores = estimateZ(group);
        int[] zPredicted = generateZPredicted(zScores);
        if (this.verbose) {
            predictionZStats(zPredicted);
        }
        Set<Integer> yPredicted = estimateY(zPredicted).keySet();
        if (this.verbose) {
            predictionYStats(yPredicted);
        }
        if (!updateCondition(yPredicted, yPos, yNeg)) {
            if (this.verbose) {
                System.err.println("No update necessary.");
            }
            return;
        }
        if (this.verbose) {
            System.err.println("yUpdate: " + computeUpdateY(yPredicted, yPos, yNeg));
        }
        Set<Integer>[] zUpdate = generateZUpdate(generateZUnknown(zPredicted, yPos, yNeg), yPos, zScores);
        if (this.verbose) {
            System.err.println("zUpdate after adding y+: " + arrayToString(zUpdate));
        }
        Set<Integer> yUnknown = null;
        List<Counter<Integer>> zProbabilities = null;
        if (SOFT_UNKNOWN) {
            // Labels in the update targets that are neither known positives nor known negatives.
            yUnknown = new HashSet<>();
            for (Set<Integer> targets : zUpdate) {
                for (Integer label : targets) {
                    if (!yPos.contains(label) && !yNeg.contains(label)) {
                        yUnknown.add(label);
                    }
                }
            }
            if (this.verbose) {
                System.err.println("yUnknown: " + yUnknown);
            }
            zProbabilities = new ArrayList<>();
            for (Counter<Integer> scores : zScores) {
                zProbabilities.add(toProbabilities(scores));
            }
        }
        updateZModel(zUpdate, zPredicted, group, yUnknown, zProbabilities);
    }

    /** Applies {@code softmax} (with {@link #gamma}) over the scores present in {@code scores}. */
    private Counter<Integer> toProbabilities(Counter<Integer> scores) {
        List<Double> allScores = new ArrayList<>();
        for (Integer label : scores.keySet()) {
            allScores.add(scores.getCount(label));
        }
        Counter<Integer> probabilities = new ClassicCounter<>();
        for (Integer label : scores.keySet()) {
            probabilities.setCount(label, softmax(scores.getCount(label), allScores, this.gamma));
        }
        return probabilities;
    }

    /** Dumps the non-zero averaged weights of every label to stderr. */
    private void printAvgVectors() {
        for (int label = 0; label < this.zWeights.length; label++) {
            System.err.print("AVG VECTOR #" + label + ":");
            double[] avg = this.zWeights[label].avgWeights;
            for (int f = 0; f < avg.length; f++) {
                double value = avg[f];
                if (value != 0.0d) {
                    System.err.print(" " + f + ":" + value);
                }
            }
            System.err.println();
        }
    }

    /** Logs micro-averaged precision/recall/F1 of the current (non-averaged) weights on the training groups. */
    private void computeTrainingPerformance(KBPDataset<String, String> dataset) {
        Counter<Integer> goldCounts = new ClassicCounter<>();
        Counter<Integer> predictedCounts = new ClassicCounter<>();
        Counter<Integer> correctCounts = new ClassicCounter<>();
        int[][][] data = dataset.getDataArray();
        Set<Integer>[] positives = dataset.getPositiveLabelsArray();
        for (int g = 0; g < dataset.size(); g++) {
            Set<Integer> gold = positives[g];
            Set<Integer> predicted = estimateY(generateZPredicted(estimateZ(data[g]))).keySet();
            if (this.verbose) {
                predictionYStats(predicted);
            }
            for (Integer label : gold) {
                goldCounts.incrementCount(label);
                goldCounts.incrementCount(LABEL_ALL);
            }
            for (Integer label : predicted) {
                predictedCounts.incrementCount(label);
                predictedCounts.incrementCount(LABEL_ALL);
                if (gold.contains(label)) {
                    correctCounts.incrementCount(label);
                    correctCounts.incrementCount(LABEL_ALL);
                }
            }
        }
        Triple<Double, Double, Double> score = computeScore(LABEL_ALL, goldCounts, predictedCounts, correctCounts);
        Redwood.Util.log(new Object[]{"Overall score: P " + score.first() + " R " + score.second() + " F1 " + score.third()});
    }

    /** Precision, recall and F1 for {@code label}; each is 0 when its denominator is 0. */
    private static Triple<Double, Double, Double> computeScore(int label, Counter<Integer> gold, Counter<Integer> predicted, Counter<Integer> correct) {
        double goldCount = gold.getCount(label);
        double predictedCount = predicted.getCount(label);
        double correctCount = correct.getCount(label);
        double precision = predictedCount > 0.0d ? correctCount / predictedCount : 0.0d;
        double recall = goldCount > 0.0d ? correctCount / goldCount : 0.0d;
        double f1 = precision + recall > 0.0d ? 2.0d * precision * recall / (precision + recall) : 0.0d;
        return new Triple<>(precision, recall, f1);
    }

    /** Replaces, in place, every entry of {@code labels} that is in {@code gold} with NIL; returns the same array. */
    private int[] removeGold(int[] labels, Set<Integer> gold) {
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] != this.nilIndex && gold.contains(labels[i])) {
                labels[i] = this.nilIndex;
            }
        }
        return labels;
    }

    /**
     * Builds the target label set for every mention of a group, used by the joint update.
     *
     * <ol>
     *   <li>Greedily matches positive labels to distinct mentions: repeatedly pick the highest-scoring
     *       (unmatched mention, unassigned positive label) pair, as long as its score exceeds
     *       {@code Double.MIN_VALUE}.</li>
     *   <li>Every unmatched mention gets its {@code zUnknown} label if that is not NIL; otherwise the
     *       best-ranked positive label present in its scores, or nothing if it scores no positive label.</li>
     *   <li>Positive labels still unassigned are added to every mention not matched in step 1.</li>
     * </ol>
     *
     * @param zUnknown per-mention predictions with known labels replaced by NIL; may be {@code null}
     * @param yPos the group's positive labels
     * @param zScores per-mention label scores, as produced by {@code estimateZ}
     * @return one (possibly empty) target set per mention
     */
    private Set<Integer>[] generateZUpdate(int[] zUnknown, Set<Integer> yPos, List<Counter<Integer>> zScores) {
        int numMentions = zScores.size();
        @SuppressWarnings("unchecked")
        Set<Integer>[] zUpdate = new Set[numMentions];
        for (int m = 0; m < numMentions; m++) {
            zUpdate[m] = new HashSet<>();
        }

        // Step 1: greedy one-to-one matching of positive labels to mentions.
        Set<Integer> unassignedGold = new HashSet<>(yPos);
        Set<Integer> matchedMentions = new HashSet<>();
        while (!unassignedGold.isEmpty() && matchedMentions.size() < numMentions) {
            int bestMention = -1;
            int bestLabel = -1;
            // Double.MIN_VALUE is the smallest POSITIVE double, so zero (absent) scores never qualify.
            double bestScore = Double.MIN_VALUE;
            for (int m = 0; m < numMentions; m++) {
                if (matchedMentions.contains(m)) {
                    continue;
                }
                for (Integer label : unassignedGold) {
                    double score = zScores.get(m).getCount(label);
                    if (score > bestScore) {
                        bestScore = score;
                        bestMention = m;
                        bestLabel = label;
                    }
                }
            }
            if (bestMention == -1) {
                break;
            }
            zUpdate[bestMention].add(bestLabel);
            matchedMentions.add(bestMention);
            unassignedGold.remove(bestLabel);
        }

        // Step 2: targets for the mentions that were not matched.
        for (int m = 0; m < numMentions; m++) {
            if (matchedMentions.contains(m)) {
                continue;
            }
            if (zUnknown == null || zUnknown[m] == this.nilIndex) {
                // First positive label in this mention's ranking; if there is none the target stays empty.
                for (Pair<Integer, Double> prediction : sortPredictions(zScores.get(m))) {
                    Integer label = prediction.first();
                    if (yPos.contains(label)) {
                        zUpdate[m].add(label);
                        unassignedGold.remove(label);
                        break;
                    }
                }
            } else {
                zUpdate[m].add(zUnknown[m]);
            }
        }

        // Step 3: spread any positive labels nobody received over all unmatched mentions.
        if (!unassignedGold.isEmpty()) {
            for (int m = 0; m < numMentions; m++) {
                if (!matchedMentions.contains(m)) {
                    zUpdate[m].addAll(unassignedGold);
                }
            }
        }
        return zUpdate;
    }

    /** Space-separated {@code toString()} of each set. */
    private static String arrayToString(Set<Integer>[] sets) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < sets.length; i++) {
            if (i > 0) {
                sb.append(" ");
            }
            sb.append(sets[i]);
        }
        return sb.toString();
    }

    /** Space-separated values. */
    private static String arrayToString(int[] values) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                sb.append(" ");
            }
            sb.append(values[i]);
        }
        return sb.toString();
    }

    /** Verbose dump of one training group. */
    private void inputStats(int groupId, Set<Integer> yPos, Set<Integer> yNeg, int[][] group) {
        System.err.println("Group #" + groupId + " with " + group.length + " datums.");
        System.err.println("y+: " + yPos);
        System.err.println("y-: " + yNeg);
        for (int m = 0; m < group.length; m++) {
            System.err.print("Datum #" + m + ":");
            for (int feature : group[m]) {
                System.err.print(" " + feature);
            }
            System.err.println();
        }
        System.err.println();
    }

    /** Verbose dump of per-mention predictions. */
    private void predictionZStats(int[] zPredicted) {
        System.err.print("Predicted z:");
        for (int label : zPredicted) {
            System.err.print(" " + label);
        }
        System.err.println("\n");
    }

    /** Verbose dump of group-level predictions. */
    private void predictionYStats(Set<Integer> yPredicted) {
        System.err.println("Predicted y: " + yPredicted);
    }

    /** Scores every mention of a group; see {@link #estimateZ(int[])}. */
    private List<Counter<Integer>> estimateZ(int[][] group) {
        List<Counter<Integer>> zScores = new ArrayList<>(group.length);
        for (int[] datum : group) {
            zScores.add(estimateZ(datum));
        }
        return zScores;
    }

    /**
     * Scores one mention with the raw weights of every label. Only labels with a strictly positive score
     * are stored in the returned counter.
     */
    private Counter<Integer> estimateZ(int[] datum) {
        Counter<Integer> featureCounts = new ClassicCounter<>();
        for (int feature : datum) {
            featureCounts.incrementCount(feature);
        }
        Counter<Integer> scores = new ClassicCounter<>();
        for (int label = 0; label < this.zWeights.length; label++) {
            double score = this.zWeights[label].dotProduct(featureCounts);
            if (score > 0.0d) {
                scores.setCount(label, score);
            }
        }
        return scores;
    }

    /** Highest-scoring label (ties broken by larger label id); {@code scores} must be non-empty. */
    private static int pickBestLabel(Counter<Integer> scores) {
        assert scores.size() > 0;
        return sortPredictions(scores).iterator().next().first();
    }

    /** Best label per mention, or NIL for mentions with no positive score. */
    private int[] generateZPredicted(List<Counter<Integer>> zScores) {
        int[] zPredicted = new int[zScores.size()];
        for (int m = 0; m < zScores.size(); m++) {
            Counter<Integer> scores = zScores.get(m);
            zPredicted[m] = scores.size() > 0 ? pickBestLabel(scores) : this.nilIndex;
        }
        return zPredicted;
    }

    /** Group-level prediction; see {@link #deterministicEstimateY(int[])}. */
    private Counter<Integer> estimateY(int[] zPredicted) {
        return deterministicEstimateY(zPredicted);
    }

    /** Every non-NIL label predicted for at least one mention, each with count 1. */
    private Counter<Integer> deterministicEstimateY(int[] zPredicted) {
        Counter<Integer> yPredicted = new ClassicCounter<>();
        for (int label : zPredicted) {
            if (label != this.nilIndex) {
                yPredicted.setCount(label, 1.0d);
            }
        }
        return yPredicted;
    }

    /**
     * Labels reported as "yUpdate" in verbose mode: positive labels that are unpredicted or not in
     * {@code yNeg}, plus predicted labels not in {@code yNeg}.
     */
    private static Set<Integer> computeUpdateY(Set<Integer> yPredicted, Set<Integer> yPos, Set<Integer> yNeg) {
        Set<Integer> yUpdate = new HashSet<>();
        for (Integer label : yPos) {
            if (!yPredicted.contains(label) || !yNeg.contains(label)) {
                yUpdate.add(label);
            }
        }
        for (Integer label : yPredicted) {
            if (!yNeg.contains(label)) {
                yUpdate.add(label);
            }
        }
        return yUpdate;
    }

    /** True if some positive label was not predicted, or some negative label was predicted. */
    private static boolean updateCondition(Set<Integer> yPredicted, Set<Integer> yPos, Set<Integer> yNeg) {
        for (Integer label : yPos) {
            if (!yPredicted.contains(label)) {
                return true;
            }
        }
        for (Integer label : yNeg) {
            if (yPredicted.contains(label)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Perceptron update of every mention towards its target set {@code zUpdate[m]} and away from its
     * prediction {@code zPredicted[m]}. With {@link #SOFT_UNKNOWN}, predicted labels in {@code yUnknown}
     * additionally get an update of (probability - 1).
     */
    private void updateZModel(Set<Integer>[] zUpdate, int[] zPredicted, int[][] group, Set<Integer> yUnknown, List<Counter<Integer>> zProbabilities) {
        assert zUpdate.length == group.length;
        assert zPredicted.length == group.length;
        if (this.verbose) {
            System.err.println("Updating model using:\n\tgoldZ = " + arrayToString(zUpdate) + "\n\tpredZ = " + arrayToString(zPredicted));
        }
        for (int m = 0; m < group.length; m++) {
            Set<Integer> targets = zUpdate[m];
            int predicted = zPredicted[m];
            int[] datum = group[m];
            // Demote a non-NIL prediction that is not a target.
            if (predicted != this.nilIndex && !targets.contains(predicted)) {
                this.zWeights[predicted].update(datum, -1.0d);
                if (this.verbose) {
                    System.err.println("Update --- on label " + predicted);
                }
                this.negUpdateStats.incrementCount(predicted);
                this.negUpdateStats.incrementCount(LABEL_ALL);
            }
            // Demote NIL when the mention has targets.
            if (predicted == this.nilIndex && !targets.isEmpty()) {
                this.zWeights[this.nilIndex].update(datum, -1.0d);
                if (this.verbose) {
                    System.err.println("Update --- on label NIL");
                }
                this.negUpdateStats.incrementCount(predicted);
            }
            if (SOFT_UNKNOWN && yUnknown != null && yUnknown.contains(predicted)) {
                this.zWeights[predicted].update(datum, zProbabilities.get(m).getCount(predicted) - 1.0d);
                this.unknownUpdateStats.incrementCount(predicted);
                this.unknownUpdateStats.incrementCount(LABEL_ALL);
            }
            // Promote every non-NIL target that was not predicted.
            for (Integer target : targets) {
                int label = target;
                if (label != this.nilIndex && label != predicted) {
                    this.zWeights[label].update(datum, 1.0d);
                    if (this.verbose) {
                        System.err.println("Update +++ on label " + label);
                    }
                    this.posUpdateStats.incrementCount(label);
                    this.posUpdateStats.incrementCount(LABEL_ALL);
                }
            }
            // No targets: promote NIL if something else was predicted.
            if (targets.isEmpty() && predicted != this.nilIndex) {
                this.zWeights[this.nilIndex].update(datum, 1.0d);
                if (this.verbose) {
                    System.err.println("Update +++ on label NIL");
                }
                this.posUpdateStats.incrementCount(this.nilIndex);
            }
        }
    }

    /**
     * Classifies a group of mentions (each a collection of feature names). Each mention contributes its
     * top label and that label's probability; per-label probabilities are combined with a noisy-or,
     * {@code 1 - prod(1 - p)}. Mentions whose top label is {@code "_NR"} are ignored.
     */
    public Counter<String> classifyMentions(List<Collection<String>> list) {
        // Per label: product of (1 - p) over the mentions that voted for it.
        Counter<String> probNone = new ClassicCounter<>();
        for (Collection<String> mention : list) {
            Counter<String> local = classifyLocally(mention);
            Pair<String, Double> best = JointBayesRelationExtractor.sortPredictions(local).get(0);
            String label = best.first();
            double probability = best.second();
            if (!label.equals(NIL_LABEL)) {
                double previous = probNone.containsKey(label) ? probNone.getCount(label) : 1.0d;
                probNone.setCount(label, previous * (1.0d - probability));
            }
        }
        Counter<String> result = new ClassicCounter<>();
        for (String label : probNone.keySet()) {
            result.setCount(label, 1.0d - probNone.getCount(label));
        }
        return result;
    }

    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public Counter<Pair<String, Maybe<KBPRelationProvenance>>> classifyRelations(SentenceGroup sentenceGroup, Maybe<CoreMap[]> maybe) {
        return RelationClassifier.firstProvenance(classifyMentions(RelationClassifier.tupleToFeatureList(sentenceGroup)), sentenceGroup);
    }

    /** Returns the counter's entries sorted by descending score, ties broken by descending label id. */
    private static List<Pair<Integer, Double>> sortPredictions(Counter<Integer> scores) {
        List<Pair<Integer, Double>> predictions = new ArrayList<>();
        for (Integer label : scores.keySet()) {
            predictions.add(new Pair<>(label, scores.getCount(label)));
        }
        sortPredictions(predictions);
        return predictions;
    }

    /**
     * Sorts in place by descending score, then descending label id. Uses {@code Double.compare} so that
     * NaN and signed zeros yield a consistent total order (required by {@code List.sort}).
     */
    private static void sortPredictions(List<Pair<Integer, Double>> predictions) {
        predictions.sort((a, b) -> {
            int byScore = Double.compare(b.second(), a.second());
            return byScore != 0 ? byScore : Integer.compare(b.first(), a.first());
        });
    }

    /** Softmax-normalized averaged scores of one mention over all labels, keyed by label name. */
    private Counter<String> classifyLocally(Collection<String> features) {
        List<Double> allScores = new ArrayList<>(this.zWeights.length);
        for (LabelWeights labelWeights : this.zWeights) {
            allScores.add(labelWeights.avgDotProduct(features, this.zFeatureIndex));
        }
        Counter<String> probabilities = new ClassicCounter<>();
        for (int label = 0; label < allScores.size(); label++) {
            probabilities.setCount(this.labelIndex.get(label), softmax(allScores.get(label), allScores, this.gamma));
        }
        return probabilities;
    }

    /** Loads a serialized extractor via {@code RelationClassifier.load}. */
    public static PerceptronExtractor load(String str, Properties properties) throws IOException, ClassNotFoundException {
        return (PerceptronExtractor) RelationClassifier.load(str, properties, PerceptronExtractor.class);
    }
}
