package edu.stanford.nlp.kbp.slotfilling.classify;

import edu.stanford.nlp.classify.Dataset;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.classify.WeightedDataset;
import edu.stanford.nlp.kbp.common.Maybe;
import edu.stanford.nlp.kbp.common.Pointer;
import edu.stanford.nlp.kbp.common.Props;
import edu.stanford.nlp.kbp.common.RelationType;
import edu.stanford.nlp.kbp.common.SentenceGroup;
import edu.stanford.nlp.kbp.common.Utils;
import edu.stanford.nlp.kbp.slotfilling.classify.TrainingStatistics;
import edu.stanford.nlp.kbp.slotfilling.ir.KBPRelationProvenance;
import edu.stanford.nlp.kbp.slotfilling.train.KBPTrainer;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.BoundedPriorityQueue;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Execution;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.MetaClass;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Sets;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/classify/JointBayesRelationExtractor.class */
public class JointBayesRelationExtractor extends RelationClassifier {
    private static final long serialVersionUID = -7961154075748697901L;
    // Logging channel; initialized outside this view (presumably a static initializer).
    private static final Redwood.RedwoodChannels logger;
    // True when Props.TRAIN_MODEL is ModelType.ENSEMBLE. Together with TRAIN_JOINTBAYES_MULTITHREAD it
    // makes trainer tasks call Redwood.finishThread() instead of opening/closing Redwood tracks.
    private final boolean partOfEnsemble;
    // Initialized outside this view; not read in the visible code.
    private static final LOCAL_CLASSIFICATION_MODE localClassificationMode;
    // Local sentence-level (Z) classifiers, one per cross-validation fold, indexed by fold number.
    protected LinearClassifier<String, String>[] zClassifiers;
    // Optional single Z classifier; when non-null, train() uses it instead of the per-fold
    // classifiers to score unknown Y labels during relabeling.
    LinearClassifier<String, String> zSingleClassifier;
    // Group-level (Y) classifiers keyed by relation label name.
    protected Map<String, LinearClassifier<String, String>> yClassifiers;
    // Feature index used when building Z training datasets; taken from the training data in train().
    private Index<String> featureIndex;
    // Index of Y (relation) labels; taken from the training dataset's label index in train().
    private Index<String> yLabelIndex;
    // Index of Z labels: the Y labels plus "_NR" (presumably "no relation").
    protected Index<String> zLabelIndex;
    // NOTE(review): names suggest Y-classifier feature names; initialized and used outside this view.
    private static String ATLEASTONCE_FEAT;
    private static String NONE_FEAT;
    private static String UNIQUE_FEAT;
    private static String SIGMOID_FEAT;
    private static List<String> Y_FEATURES_FOR_INITIAL_MODEL;
    // Maximum number of EM epochs (Props.TRAIN_JOINTBAYES_EPOCHS).
    private final int numberOfTrainEpochs;
    // Number of cross-validation folds; the constructor rejects values below 2.
    protected int numberOfFolds;
    // When true, train() returns right after initializing the local classifiers.
    private final boolean onlyLocalTraining;
    // Path from which initial (locally trained) models are loaded and to which they are saved.
    private final String initialModelPath;
    // Props.TRAIN_JOINTBAYES_ZSIGMA; presumably the Z classifiers' regularization sigma (used outside this view).
    private final double zSigma;
    // Sigma passed to the Y classifier factory in train(); fixed at 1.0 by the constructor.
    private final double ySigma;
    // Reset every epoch; an epoch with zero updates stops EM. Incremented outside this view.
    private AtomicInteger zUpdatesInOneEpoch;
    // Filter applied to datum groups (instantiated reflectively from Props.TRAIN_JOINTBAYES_FILTER).
    private final LocalFilter localDataFilter;
    // Z-label inference algorithm used in the E-step.
    private final InferenceType inferenceType;
    // Copied from Props.TRAIN_JOINTBAYES_TRAINY; not read in the visible code.
    private final boolean trainY;
    // Co-occurrence feature names ("co:s|A|d|B|") for label pairs seen together as positives.
    protected Set<String> knownDependencies;
    // Model path built by makeModelPath; used outside this view.
    private String serializedModelPath;
    // Worker thread count; values <= 0 are replaced by availableProcessors() in train().
    private int numberOfThreads;
    // Minimizer choice for Z classifiers; presumably consumed by getZClassifierFactory (not visible).
    private final KBPTrainer.MinimizerType zClassifierMinimizerType;
    // Monitor guarding worker-thread writes to zClassifiers, yClassifiers and the Y training datasets.
    private final Object lock;
    // Not referenced in the visible code.
    private static final double BIG_WEIGHT = 10.0d;
    // Decompiler-emitted assertion-status flag; code throws AssertionError only when it is false.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/classify/JointBayesRelationExtractor$AllFilter.class */
    /**
     * A {@link LocalFilter} whose Z filter never rejects anything; the Y filter is the inherited default.
     */
    public static class AllFilter extends LocalFilter {
        @Override // edu.stanford.nlp.kbp.slotfilling.classify.JointBayesRelationExtractor.LocalFilter
        public boolean filterZ(int[][] sentences, Set<Integer> labels) {
            // Every group passes, regardless of size or labels.
            final boolean accept = true;
            return accept;
        }
    }

    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/classify/JointBayesRelationExtractor$InferenceType.class */
    /** Algorithm used to infer the latent sentence-level Z labels in the E-step. */
    public enum InferenceType {
        /** Dispatches to {@code inferZLabels}. */
        SLOW,
        /** Dispatches to {@code inferZLabelsStable}. */
        STABLE,
    }

    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/classify/JointBayesRelationExtractor$Initialization.class */
    /**
     * Initialization strategies. NOTE(review): not referenced in the visible code; names suggest
     * distant-supervision vs. supervised initialization.
     */
    public enum Initialization {
        DISTSUP,
        SUPERVISED,
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/classify/JointBayesRelationExtractor$LOCAL_CLASSIFICATION_MODE.class */
    /**
     * How local (sentence-level) classification is performed; type of the static
     * {@code localClassificationMode} setting. NOTE(review): semantics are defined outside this view.
     */
    public enum LOCAL_CLASSIFICATION_MODE {
        WEIGHTED_VOTE,
        SINGLE_MODEL,
    }

    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/classify/JointBayesRelationExtractor$LargeFilter.class */
    /**
     * A {@link LocalFilter} that keeps only groups with at most {@code threshold} sentences, for both
     * the Z and the Y filter.
     */
    public static class LargeFilter extends LocalFilter {
        /** Maximum number of sentences a group may have and still pass. */
        final int threshold;

        public LargeFilter(int threshold) {
            this.threshold = threshold;
        }

        /** True when {@code sentences} does not exceed the configured size. */
        private boolean withinThreshold(int[][] sentences) {
            return !(sentences.length > this.threshold);
        }

        @Override // edu.stanford.nlp.kbp.slotfilling.classify.JointBayesRelationExtractor.LocalFilter
        public boolean filterZ(int[][] sentences, Set<Integer> labels) {
            return withinThreshold(sentences);
        }

        @Override // edu.stanford.nlp.kbp.slotfilling.classify.JointBayesRelationExtractor.LocalFilter
        public boolean filterY(int[][] sentences, Set<Integer> labels) {
            return withinThreshold(sentences);
        }
    }

    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/classify/JointBayesRelationExtractor$LocalFilter.class */
    /**
     * Predicate over a datum group (its per-sentence feature arrays plus a label set).
     * In train(), a group is kept for Y training when {@link #filterY} returns true.
     */
    public static abstract class LocalFilter {
        /** Decides whether the group passes the Z filter. */
        public abstract boolean filterZ(int[][] sentences, Set<Integer> labels);

        /** Decides whether the group passes the Y filter; by default every group passes. */
        public boolean filterY(int[][] sentences, Set<Integer> labels) {
            final boolean keepByDefault = true;
            return keepByDefault;
        }
    }

    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/classify/JointBayesRelationExtractor$RedundancyFilter.class */
    /**
     * A {@link LocalFilter} whose Z filter passes groups having at most one label and more than one
     * sentence.
     */
    public static class RedundancyFilter extends LocalFilter {
        @Override // edu.stanford.nlp.kbp.slotfilling.classify.JointBayesRelationExtractor.LocalFilter
        public boolean filterZ(int[][] sentences, Set<Integer> labels) {
            if (labels.size() > 1) {
                return false;
            }
            return sentences.length > 1;
        }
    }

    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/classify/JointBayesRelationExtractor$SingleFilter.class */
    /** A {@link LocalFilter} whose Z filter passes groups having at most one label. */
    public static class SingleFilter extends LocalFilter {
        @Override // edu.stanford.nlp.kbp.slotfilling.classify.JointBayesRelationExtractor.LocalFilter
        public boolean filterZ(int[][] sentences, Set<Integer> labels) {
            return !(labels.size() > 1);
        }
    }

    /**
     * Creates an extractor whose {@code train} runs the full EM procedure rather than stopping after
     * local initialization. Equivalent to {@code this(properties, false)}.
     */
    public JointBayesRelationExtractor(Properties properties) {
        this(properties, false);
    }

    /**
     * Creates an extractor configured from the static {@link Props} settings.
     *
     * <p>NOTE(review): {@code properties} is not read here; all configuration comes from {@code Props}.
     *
     * @param properties currently unused
     * @param onlyLocalTraining if true, {@code train} returns right after initializing the local classifiers
     * @throws IllegalArgumentException if fewer than two cross-validation folds are configured
     */
    public JointBayesRelationExtractor(Properties properties, boolean onlyLocalTraining) {
        this.partOfEnsemble = Props.TRAIN_MODEL == ModelType.ENSEMBLE;
        this.zUpdatesInOneEpoch = new AtomicInteger(0);
        // Private monitor for worker-thread updates. A zero-length array is Serializable (unlike
        // new Object()) and, unlike the interned String literal used previously, is not shared with
        // every other instance (e.g. ensemble members) or unrelated code locking on the same literal.
        this.lock = new Object[0];

        // Base name of the model files, with any serialization extension stripped.
        String modelName = "kbp_relation_model";
        if (modelName.endsWith(Props.SER_EXT)) {
            modelName = modelName.substring(0, modelName.length() - Props.SER_EXT.length());
        }
        double negativesRatio = Props.TRAIN_NEGATIVES_RATIO;

        if (Props.TRAIN_JOINTBAYES_LOADINITMODEL_FILE != null) {
            this.initialModelPath = Props.TRAIN_JOINTBAYES_LOADINITMODEL_FILE;
        } else {
            this.initialModelPath = makeInitialModelPath(Props.KBP_MODEL_DIR.getPath(), modelName, ModelType.JOINT_BAYES, negativesRatio);
        }
        this.numberOfTrainEpochs = Props.TRAIN_JOINTBAYES_EPOCHS;
        this.numberOfFolds = Props.TRAIN_JOINTBAYES_FOLDS;
        if (this.numberOfFolds < 2) {
            throw new IllegalArgumentException("Must have at least two folds: " + this.numberOfFolds);
        }
        this.numberOfThreads = Props.TRAIN_JOINTBAYES_MULTITHREAD ? Execution.threads : 1;
        this.zSigma = Props.TRAIN_JOINTBAYES_ZSIGMA;
        this.ySigma = 1.0d;
        // The filter class is configured by name and instantiated reflectively.
        this.localDataFilter = (LocalFilter) MetaClass.create(Props.TRAIN_JOINTBAYES_FILTER).createInstance(new Object[0]);
        this.inferenceType = Props.TRAIN_JOINTBAYES_INFERENCETYPE;
        this.trainY = Props.TRAIN_JOINTBAYES_TRAINY;
        this.onlyLocalTraining = onlyLocalTraining;
        this.serializedModelPath = makeModelPath(Props.KBP_MODEL_DIR.getPath(), modelName, ModelType.JOINT_BAYES, negativesRatio);
        this.zClassifierMinimizerType = Props.TRAIN_JOINTBAYES_ZMINIMIZER;
        Redwood.Util.log(new Object[]{Redwood.Util.BLUE, "y features: " + StringUtils.join(Props.TRAIN_JOINTBAYES_YFEATURES, " | ")});
    }

    /**
     * Builds {@code dir/name.TYPE.PCT.initial<SER_EXT>}, where PCT is {@code (int) (100 * ratio)}.
     * NOTE(review): the int cast truncates, so a ratio such as 0.29 can produce 28 rather than 29.
     */
    private static String makeInitialModelPath(String dir, String name, ModelType modelType, double ratio) {
        int percent = (int) (100.0d * ratio);
        return new StringBuilder()
                .append(dir)
                .append(File.separator)
                .append(name)
                .append('.')
                .append(modelType)
                .append('.')
                .append(percent)
                .append(".initial")
                .append(Props.SER_EXT)
                .toString();
    }

    /**
     * Builds {@code dir/name.TYPE.PCT<SER_EXT>}, where PCT is {@code (int) (100 * ratio)} (truncated).
     */
    private static String makeModelPath(String dir, String name, ModelType modelType, double ratio) {
        int percent = (int) (100.0d * ratio);
        return new StringBuilder()
                .append(dir)
                .append(File.separator)
                .append(name)
                .append('.')
                .append(modelType)
                .append('.')
                .append(percent)
                .append(Props.SER_EXT)
                .toString();
    }

    /**
     * Returns the first (inclusive) index of fold {@code fold} when {@code datasetSize} items are split
     * into {@code numberOfFolds} contiguous folds of size {@code datasetSize / numberOfFolds}.
     */
    private int foldStart(int fold, int datasetSize) {
        int foldSize = datasetSize / this.numberOfFolds;
        if (!$assertionsDisabled) {
            if (foldSize <= 0) {
                throw new AssertionError();
            }
        }
        int start = fold * foldSize;
        if (!$assertionsDisabled && start >= datasetSize) {
            throw new AssertionError();
        }
        return start;
    }

    /**
     * Returns the end (exclusive) index of fold {@code fold}; the last fold absorbs the remainder and
     * always ends at {@code datasetSize}.
     */
    private int foldEnd(int fold, int datasetSize) {
        boolean isLastFold = fold == this.numberOfFolds - 1;
        if (isLastFold) {
            return datasetSize;
        }
        int foldSize = datasetSize / this.numberOfFolds;
        if (!$assertionsDisabled) {
            if (foldSize <= 0) {
                throw new AssertionError();
            }
        }
        int end = (fold + 1) * foldSize;
        if (!$assertionsDisabled && end > datasetSize) {
            throw new AssertionError();
        }
        return end;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v3, types: [int[], int[][]] */
    private int[][] initializeZLabels(KBPDataset<String, String> kBPDataset) {
        ?? r0 = new int[kBPDataset.getDataArray().length];
        for (int i = 0; i < this.numberOfFolds; i++) {
            LinearClassifier<String, String> linearClassifier = this.zClassifiers[i];
            if (!$assertionsDisabled && linearClassifier == null) {
                throw new AssertionError();
            }
            for (int foldStart = foldStart(i, kBPDataset.getDataArray().length); foldStart < foldEnd(i, kBPDataset.getDataArray().length); foldStart++) {
                int[][] iArr = kBPDataset.getDataArray()[foldStart];
                r0[foldStart] = new int[iArr.length];
                for (int i2 = 0; i2 < iArr.length; i2++) {
                    int[] iArr2 = iArr[i2];
                    for (int i3 : iArr2) {
                        if (i3 >= this.featureIndex.size()) {
                            logger.log(new Object[]{i3 + "\t" + this.featureIndex.size()});
                        }
                        if (!$assertionsDisabled && i3 >= this.featureIndex.size()) {
                            throw new AssertionError();
                        }
                        if (!$assertionsDisabled && i3 >= linearClassifier.featureIndex().size()) {
                            throw new AssertionError();
                        }
                        if (!$assertionsDisabled && i3 < 0) {
                            throw new AssertionError();
                        }
                    }
                    List<Pair<String, Double>> sortPredictions = sortPredictions((Counter<String>) linearClassifier.scoresOf(iArr2));
                    int indexOf = this.zLabelIndex.indexOf(sortPredictions.get(0).first());
                    if (indexOf < 0) {
                        throw new IllegalStateException("Unknown relation: " + ((String) sortPredictions.get(0).first()));
                    }
                    if (!$assertionsDisabled && indexOf == -1) {
                        throw new AssertionError();
                    }
                    r0[foldStart][i2] = indexOf;
                }
            }
        }
        return r0;
    }

    /**
     * Rebuilds {@code knownDependencies}. It records a co-occurrence feature for every ordered pair of
     * distinct labels that appear together in some group's positive label set.
     */
    private void detectDependencyYFeatures(KBPDataset<String, String> kBPDataset) {
        this.knownDependencies = new HashSet<>();
        for (int datum = 0; datum < kBPDataset.size(); datum++) {
            Set<Integer> positives = kBPDataset.getPositiveLabelsArray()[datum];
            for (int source : positives) {
                String sourceLabel = (String) kBPDataset.labelIndex().get(source);
                for (int target : positives) {
                    if (source == target) {
                        continue;
                    }
                    String feature = makeCoocurrenceFeature(sourceLabel, (String) kBPDataset.labelIndex().get(target));
                    logger.debug(new Object[]{"FOUND COOC: " + feature});
                    this.knownDependencies.add(feature);
                }
            }
        }
    }

    /** Formats the co-occurrence feature name {@code co:s|SOURCE|d|TARGET|}. */
    private static String makeCoocurrenceFeature(String source, String target) {
        // The trailing empty element yields the closing '|'.
        return String.join("|", "co:s", source, "d", target, "");
    }

    /**
     * Logs the Z label distribution of {@code dataset}, then returns a task that retrains the Z
     * classifier for {@code fold}.
     *
     * <p>The task trains on the fold's training arrays (from {@code makeTrainLabelArrayForFold} /
     * {@code makeTrainDataArrayForFold}), warm-started from the fold's current weights if a classifier
     * already exists. It publishes the result into {@code zClassifiers[fold]} under {@code lock}.
     *
     * @param epoch current EM epoch (used only in the log track name)
     * @param fold fold whose classifier is retrained
     */
    private Runnable createZClassifierTrainer(LinearClassifierFactory<String, String> linearClassifierFactory, Dataset<String, String> dataset, int epoch, int fold) {
        final int[][] features = dataset.getDataArray();
        final int[] labels = dataset.getLabelsArray();
        final LinearClassifier<String, String> previous = this.zClassifiers[fold];
        final double[][] initialWeights = (previous == null) ? null : previous.weights();

        ClassicCounter<String> labelCounts = new ClassicCounter<>();
        for (int label : labels) {
            labelCounts.incrementCount(this.zLabelIndex.get(label));
        }

        boolean showTracks = !(this.partOfEnsemble && Props.TRAIN_JOINTBAYES_MULTITHREAD);
        if (showTracks) {
            Redwood.Util.startTrack(new Object[]{"Relation Label Distribution"});
        }
        DecimalFormat percentFormat = new DecimalFormat("000.00%");
        double total = labelCounts.totalCount();
        for (Pair<String, Double> entry : Counters.toSortedListWithCounts(labelCounts)) {
            logger.log(new Object[]{"" + percentFormat.format(entry.second / total) + ": " + entry.first});
        }
        if (showTracks) {
            Redwood.Util.endTrack("Relation Label Distribution");
        }

        return () -> {
            String trackName = "EPOCH " + epoch + ": Training Z classifier for fold #" + fold;
            if (!(this.partOfEnsemble && Props.TRAIN_JOINTBAYES_MULTITHREAD)) {
                Redwood.Util.startTrack(new Object[]{trackName});
            }
            Dataset<String, String> foldTrainSet = new Dataset<>(this.zLabelIndex, makeTrainLabelArrayForFold(labels, fold), this.featureIndex, makeTrainDataArrayForFold(features, fold));
            LinearClassifier<String, String> trained = linearClassifierFactory.trainClassifierWithInitialWeights(foldTrainSet, initialWeights);
            synchronized (this.lock) {
                this.zClassifiers[fold] = trained;
            }
            if (!(this.partOfEnsemble && Props.TRAIN_JOINTBAYES_MULTITHREAD)) {
                Redwood.endTrack(trackName);
            } else {
                Redwood.finishThread();
            }
        };
    }

    /**
     * Returns a task that trains the Y classifier for relation {@code yLabel} on {@code map.get(yLabel)}
     * and stores it in {@code yClassifiers} under {@code lock}.
     *
     * @throws RuntimeException if the label's training set is empty (unless {@code Props.JUNIT})
     */
    private Runnable createYClassifierTrainer(LinearClassifierFactory<String, String> linearClassifierFactory, Map<String, RVFDataset<String, String>> map, String yLabel, int epoch) {
        final RVFDataset<String, String> trainSet = map.get(yLabel);
        if (trainSet.size() == 0 && !Props.JUNIT) {
            logger.debug(new Object[]{"Empty train set.  yLabel=" + yLabel});
            throw new RuntimeException("[JointBayesRelationExtractor.createYClassifierTrainer] Empty train set.  yLabel=" + yLabel);
        }

        // Detect a degenerate training set in which every datum has the same label.
        int[] labels = trainSet.getLabelsArray();
        boolean hasVariation = false;
        for (int k = 1; k < labels.length && !hasVariation; k++) {
            hasVariation = labels[k] != labels[k - 1];
        }
        if (!hasVariation) {
            Object firstLabel = labels.length > 0 ? (Object) Integer.valueOf(labels[0]) : "none";
            logger.debug(new Object[]{"Train set all same value.  val=" + firstLabel + "  yLabel=" + yLabel});
        }

        return () -> {
            String trackName = "EPOCH " + epoch + ": Training Y classifier for label " + yLabel;
            if (!(this.partOfEnsemble && Props.TRAIN_JOINTBAYES_MULTITHREAD)) {
                Redwood.Util.startTrack(new Object[]{trackName});
            }
            LinearClassifier<String, String> trained = linearClassifierFactory.trainClassifier(trainSet);
            synchronized (this.lock) {
                this.yClassifiers.put(yLabel, trained);
            }
            if (!(this.partOfEnsemble && Props.TRAIN_JOINTBAYES_MULTITHREAD)) {
                Redwood.Util.endTrack(trackName);
            } else {
                Redwood.finishThread();
            }
        };
    }

    /**
     * Returns a task that runs the E-step for group {@code groupIndex}.
     *
     * <p>While holding the group's sentence array as a monitor, the task does the following:
     * <ol>
     *   <li>Randomizes the group.</li>
     *   <li>Predicts local Z labels into {@code zOnlyLabels[groupIndex]}.</li>
     *   <li>Infers joint Z labels into {@code jointLabels[groupIndex]} using {@code inferenceType}.</li>
     *   <li>Publishes (randomizeGroup result, per-sentence score counters, inference result) into {@code pointer}.</li>
     *   <li>Under {@code lock}, adds one positive Y datum per positive label and one negative Y datum per
     *       negative label to the matching dataset in {@code map}.</li>
     * </ol>
     *
     * @param epoch current EM epoch, forwarded to the randomization and inference helpers
     * @throws AssertionError (assertions enabled) if the sentence count differs from the annotated-label count
     */
    private Runnable createZLabeller(LinearClassifier<String, String> linearClassifier, Map<String, RVFDataset<String, String>> map, KBPDataset<String, String> kBPDataset, int[][] zOnlyLabels, int[][] jointLabels, int epoch, int groupIndex, Pointer<Triple<int[], Counter<String>[], double[]>> pointer) {
        final int[][] group = kBPDataset.getDataArray()[groupIndex];
        final Set<Integer> positiveLabels = kBPDataset.getPositiveLabelsArray()[groupIndex];
        final Set<Integer> negativeLabels = kBPDataset.getNegativeLabelsArray()[groupIndex];
        final int[] groupZOnlyLabels = zOnlyLabels[groupIndex];
        final int[] groupJointLabels = jointLabels[groupIndex];
        final Maybe<String>[] annotatedLabels = kBPDataset.getAnnotatedLabels(groupIndex);
        if (!$assertionsDisabled && group.length != annotatedLabels.length) {
            throw new AssertionError();
        }
        return () -> {
            Counter<String>[] zScores = ErasureUtils.uncheckedCast(new Counter[group.length]);
            synchronized (group) {
                int[] permutation = randomizeGroup(group, annotatedLabels, epoch);
                predictZLabels(group, groupZOnlyLabels, linearClassifier);
                double[] inferenceResult;
                switch (this.inferenceType) {
                    case SLOW:
                        inferenceResult = inferZLabels(group, positiveLabels, negativeLabels, groupJointLabels, annotatedLabels, zScores, linearClassifier, epoch);
                        break;
                    case STABLE:
                        inferenceResult = inferZLabelsStable(group, positiveLabels, negativeLabels, groupJointLabels, annotatedLabels, zScores, linearClassifier, epoch);
                        break;
                    default:
                        throw new RuntimeException("ERROR: unknown inference type: " + this.inferenceType);
                }
                // The previous decompiled code cast this Triple to Pointer, which cannot compile.
                pointer.set(Triple.makeTriple(permutation, zScores, inferenceResult));
                synchronized (this.lock) {
                    for (int label : positiveLabels) {
                        String yLabel = this.yLabelIndex.get(label);
                        addYDatum(map.get(yLabel), yLabel, groupJointLabels, zScores, true);
                    }
                    for (int label : negativeLabels) {
                        String yLabel = this.yLabelIndex.get(label);
                        addYDatum(map.get(yLabel), yLabel, groupJointLabels, zScores, false);
                    }
                }
            }
        };
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v97, types: [int[], int[][]] */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public TrainingStatistics train(KBPDataset<String, String> kBPDataset) {
        if (this.numberOfThreads <= 0) {
            this.numberOfThreads = Runtime.getRuntime().availableProcessors();
        }
        logger.log(new Object[]{"Number of threads is " + this.numberOfThreads});
        Redwood.Util.forceTrack("Filtering data");
        if (this.localDataFilter instanceof LargeFilter) {
            ArrayList arrayList = new ArrayList();
            ArrayList arrayList2 = new ArrayList();
            ArrayList arrayList3 = new ArrayList();
            ArrayList arrayList4 = new ArrayList();
            ArrayList arrayList5 = new ArrayList();
            ArrayList arrayList6 = new ArrayList();
            for (int i = 0; i < kBPDataset.size(); i++) {
                if (this.localDataFilter.filterY(kBPDataset.getDataArray()[i], kBPDataset.getPositiveLabelsArray()[i])) {
                    arrayList.add(kBPDataset.getDataArray()[i]);
                    arrayList3.add(kBPDataset.getPositiveLabelsArray()[i]);
                    arrayList4.add(kBPDataset.getNegativeLabelsArray()[i]);
                    arrayList5.add(kBPDataset.getUnknownLabelsArray()[i]);
                    arrayList6.add(kBPDataset.getAnnotatedLabels(i));
                    arrayList2.add(kBPDataset.getSentenceGlossKey(i));
                }
            }
            kBPDataset = new KBPDataset<>((int[][][]) arrayList.toArray(new int[arrayList.size()]), kBPDataset.featureIndex(), kBPDataset.labelIndex, (Set[]) arrayList3.toArray((Object[]) ErasureUtils.uncheckedCast(new Set[arrayList3.size()])), (Set[]) arrayList4.toArray((Object[]) ErasureUtils.uncheckedCast(new Set[arrayList4.size()])), (Set[]) arrayList5.toArray((Object[]) ErasureUtils.uncheckedCast(new Set[arrayList5.size()])), (Maybe[][]) arrayList6.toArray((Object[]) ErasureUtils.uncheckedCast(new Maybe[arrayList6.size()])), (String[][]) arrayList2.toArray(new String[arrayList2.size()]));
        }
        Redwood.Util.endTrack("Filtering data");
        LinearClassifierFactory<String, String> zClassifierFactory = getZClassifierFactory();
        LinearClassifierFactory linearClassifierFactory = new LinearClassifierFactory(1.0E-4d, false, this.ySigma);
        zClassifierFactory.setVerbose(false);
        linearClassifierFactory.setVerbose(false);
        logger.log(new Object[]{"Created classifiers"});
        Redwood.Util.startTrack(new Object[]{"Initializing classifiers"});
        boolean z = true;
        if (this.initialModelPath != null && new File(this.initialModelPath).exists() && Props.TRAIN_JOINTBAYES_LOADINITMODEL) {
            try {
                loadInitialModels(this.initialModelPath);
                if (kBPDataset.featureIndex().size() > this.featureIndex.size()) {
                    logger.warn(new Object[]{Redwood.Util.RED, "Loaded an initial model with fewer features than the dataset! Ignoring..."});
                } else {
                    z = false;
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        } else if (Props.TRAIN_JOINTBAYES_LOADINITMODEL) {
            if (this.initialModelPath == null) {
                logger.warn(new Object[]{Redwood.Util.RED, "Cannot load initial model: no initial model path"});
            } else if (!new File(this.initialModelPath).exists()) {
                logger.warn(new Object[]{Redwood.Util.RED, "Cannot load initial model: model does not exist at " + this.initialModelPath});
            }
        }
        if (z) {
            this.featureIndex = kBPDataset.featureIndex();
            this.yLabelIndex = kBPDataset.labelIndex;
            this.zLabelIndex = new HashIndex(this.yLabelIndex);
            this.zLabelIndex.add("_NR");
            this.zClassifiers = initializeZClassifierLocally(kBPDataset, this.featureIndex, this.zLabelIndex);
            this.yClassifiers = initializeYClassifiersWithAtLeastOnce(this.yLabelIndex);
            if (this.initialModelPath != null) {
                try {
                    saveInitialModels(this.initialModelPath);
                } catch (IOException e2) {
                    logger.err(new Object[]{Redwood.Util.RED, "Could not save initial model: " + e2.getMessage()});
                }
            }
        }
        Redwood.Util.endTrack("Initializing classifiers");
        if (this.onlyLocalTraining) {
            return TrainingStatistics.undefined();
        }
        detectDependencyYFeatures(kBPDataset);
        for (String str : this.yLabelIndex) {
            int indexOf = this.yLabelIndex.indexOf(str);
            if (indexOf < 0) {
                throw new IllegalStateException("Unknown relation: " + str);
            }
            logger.log(new Object[]{"YLABELINDEX " + str + " = " + indexOf});
        }
        int i2 = 0;
        for (int[][] iArr : kBPDataset.getDataArray()) {
            i2 += iArr.length;
        }
        int[][] initializeZLabels = initializeZLabels(kBPDataset);
        computeConfusionMatrixForCounts("LOCAL", initializeZLabels, kBPDataset.getPositiveLabelsArray());
        computeYScore("LOCAL", initializeZLabels, kBPDataset.getPositiveLabelsArray());
        Map<String, RVFDataset<String, String>> initializeYDatasets = initializeYDatasets();
        TrainingStatistics empty = TrainingStatistics.empty();
        Pair[] pairArr = new Pair[kBPDataset.size()];
        boolean z2 = Props.TRAIN_UNLABELED;
        if (z2) {
            kBPDataset.finalizeLabels();
        }
        Redwood.Util.startTrack(new Object[]{"EM"});
        int i3 = 0;
        while (true) {
            if (i3 >= this.numberOfTrainEpochs) {
                break;
            }
            this.zUpdatesInOneEpoch = new AtomicInteger(0);
            logger.log(new Object[]{"***EPOCH " + i3 + "***"});
            ?? r0 = new int[initializeZLabels.length];
            for (int i4 = 0; i4 < initializeZLabels.length; i4++) {
                r0[i4] = new int[initializeZLabels[i4].length];
            }
            Redwood.Util.forceTrack("E-Step");
            if (z2 && i3 > 0) {
                Redwood.Util.forceTrack("Guessing Y labels");
                kBPDataset.restoreLabels();
                Set[] positiveLabelsArray = kBPDataset.getPositiveLabelsArray();
                Set<Integer>[] negativeLabelsArray = kBPDataset.getNegativeLabelsArray();
                Set<Integer>[] unknownLabelsArray = kBPDataset.getUnknownLabelsArray();
                int countLabels = kBPDataset.countLabels(positiveLabelsArray);
                int countLabels2 = kBPDataset.countLabels(negativeLabelsArray);
                int countLabels3 = kBPDataset.countLabels(unknownLabelsArray);
                int[][][] dataArray = kBPDataset.getDataArray();
                int length = (int) (Props.TRAIN_JOINTBAYES_PERCENT_POSITIVE * unknownLabelsArray.length * kBPDataset.labelIndex.size());
                int i5 = length - countLabels;
                Redwood.Util.log(new Object[]{"Before relabeling: " + countLabels + " positive, " + countLabels2 + " negative, " + countLabels3 + " unknown"});
                Redwood.Util.log(new Object[]{"Relabeling parameters: " + Props.TRAIN_JOINTBAYES_PERCENT_POSITIVE + " theta, " + unknownLabelsArray.length + " groups, " + kBPDataset.labelIndex().size() + " labels"});
                if (i5 > 0) {
                    Redwood.Util.log(new Object[]{"Target " + length + " positive, need to change " + i5 + " unknown"});
                    BoundedPriorityQueue boundedPriorityQueue = new BoundedPriorityQueue(i5, (triple, triple2) -> {
                        return triple.compareTo(triple2);
                    });
                    if (this.zSingleClassifier != null) {
                        for (int i6 = 0; i6 < unknownLabelsArray.length; i6++) {
                            Counter<Integer> computeYLogProbs = computeYLogProbs(this.zSingleClassifier, dataArray[i6], Sets.diff(unknownLabelsArray[i6], positiveLabelsArray[i6]));
                            Iterator it = computeYLogProbs.keySet().iterator();
                            while (it.hasNext()) {
                                int intValue = ((Integer) it.next()).intValue();
                                boundedPriorityQueue.add(Triple.makeTriple(Double.valueOf(computeYLogProbs.getCount(Integer.valueOf(intValue))), Integer.valueOf(i6), Integer.valueOf(intValue)));
                            }
                        }
                    } else {
                        for (int i7 = 0; i7 < this.numberOfFolds; i7++) {
                            int foldStart = foldStart(i7, kBPDataset.getDataArray().length);
                            int foldEnd = foldEnd(i7, kBPDataset.getDataArray().length);
                            for (int i8 = foldStart; i8 < foldEnd; i8++) {
                                Counter<Integer> computeYLogProbs2 = computeYLogProbs(this.zClassifiers[i7], dataArray[i8], Sets.diff(unknownLabelsArray[i8], positiveLabelsArray[i8]));
                                Iterator it2 = computeYLogProbs2.keySet().iterator();
                                while (it2.hasNext()) {
                                    int intValue2 = ((Integer) it2.next()).intValue();
                                    boundedPriorityQueue.add(Triple.makeTriple(Double.valueOf(computeYLogProbs2.getCount(Integer.valueOf(intValue2))), Integer.valueOf(i8), Integer.valueOf(intValue2)));
                                }
                            }
                        }
                    }
                    Iterator<E> it3 = boundedPriorityQueue.iterator();
                    while (it3.hasNext()) {
                        Triple triple3 = (Triple) it3.next();
                        logger.debug(new Object[]{"Relabel datum " + triple3.second + " as belonging to " + ((String) kBPDataset.labelIndex().get(((Integer) triple3.third).intValue())) + ": logProb " + triple3.first});
                        positiveLabelsArray[((Integer) triple3.second).intValue()].add(triple3.third);
                        negativeLabelsArray[((Integer) triple3.second).intValue()].remove(triple3.third);
                    }
                    countLabels = kBPDataset.countLabels(positiveLabelsArray);
                    Redwood.Util.log(new Object[]{"After relabeling: " + countLabels + " positive, " + kBPDataset.countLabels(negativeLabelsArray) + " negative, " + boundedPriorityQueue.size() + " changed"});
                } else {
                    Redwood.Util.log(new Object[]{"No relabeling: target of " + length + " reached"});
                }
                for (int i9 = 0; i9 < unknownLabelsArray.length; i9++) {
                    Iterator<Integer> it4 = unknownLabelsArray[i9].iterator();
                    while (it4.hasNext()) {
                        int intValue3 = it4.next().intValue();
                        if (!positiveLabelsArray[i9].contains(Integer.valueOf(intValue3))) {
                            negativeLabelsArray[i9].add(Integer.valueOf(intValue3));
                        }
                    }
                }
                Redwood.Util.log(new Object[]{"After marking unknowns negative: " + countLabels + " positive, " + kBPDataset.countLabels(negativeLabelsArray) + " negative"});
                Redwood.Util.endTrack("Guessing Y labels");
            }
            for (int i10 = 0; i10 < this.numberOfFolds; i10++) {
                LinearClassifier<String, String> linearClassifier = this.zClassifiers[i10];
                int foldStart2 = foldStart(i10, kBPDataset.getDataArray().length);
                int foldEnd2 = foldEnd(i10, kBPDataset.getDataArray().length);
                ArrayList arrayList7 = new ArrayList();
                Pointer[] pointerArr = new Pointer[foldEnd2 - foldStart2];
                for (int i11 = foldStart2; i11 < foldEnd2; i11++) {
                    pointerArr[i11 - foldStart2] = new Pointer();
                    arrayList7.add(createZLabeller(linearClassifier, initializeYDatasets, kBPDataset, r0, initializeZLabels, i3, i11, pointerArr[i11 - foldStart2]));
                }
                Redwood.Util.threadAndRun("EPOCH " + i3 + ": Inferring hidden sentence labels Z_i's", arrayList7, this.numberOfThreads);
                Redwood.Util.startTrack(new Object[]{"Updating training statistics"});
                for (int i12 = foldStart2; i12 < foldEnd2; i12++) {
                    Triple triple4 = (Triple) pointerArr[i12 - foldStart2].dereference().orCrash();
                    int[] iArr2 = (int[]) triple4.first;
                    Counter[] counterArr = new Counter[iArr2.length];
                    double[] dArr = new double[iArr2.length];
                    for (int i13 = 0; i13 < iArr2.length; i13++) {
                        counterArr[iArr2[i13]] = ((Counter[]) triple4.second)[i13];
                        dArr[iArr2[i13]] = ((double[]) triple4.third)[i13];
                    }
                    pairArr[i12] = Pair.makePair(counterArr, dArr);
                }
                Redwood.Util.endTrack("Updating training statistics");
            }
            computeConfusionMatrixForCounts("EPOCH " + i3, initializeZLabels, kBPDataset.getPositiveLabelsArray());
            computeConfusionMatrixForCounts("(Z ONLY) EPOCH " + i3, r0, kBPDataset.getPositiveLabelsArray());
            computeYScore("EPOCH " + i3, initializeZLabels, kBPDataset.getPositiveLabelsArray());
            computeYScore("(Z ONLY) EPOCH " + i3, r0, kBPDataset.getPositiveLabelsArray());
            logger.log(new Object[]{"In epoch #" + i3 + " zUpdatesInOneEpoch = " + this.zUpdatesInOneEpoch});
            if (this.zUpdatesInOneEpoch.get() == 0) {
                logger.log(new Object[]{"Stopping training. Did not find any changes in the Z labels!"});
                Redwood.Util.endTrack("E-Step");
                break;
            }
            Dataset<String, String> initializeZDataset = initializeZDataset(i2, initializeZLabels, kBPDataset.getDataArray());
            Redwood.Util.endTrack("E-Step");
            Redwood.Util.startTrack(new Object[]{"M-STEP"});
            ArrayList arrayList8 = new ArrayList();
            for (int i14 = 0; i14 < this.numberOfFolds; i14++) {
                arrayList8.add(createZClassifierTrainer(zClassifierFactory, initializeZDataset, i3, i14));
            }
            if (this.partOfEnsemble && Props.TRAIN_JOINTBAYES_MULTITHREAD) {
                Redwood.Util.log(new Object[]{"EPOCH " + i3 + ": Training Z classifiers"});
                ExecutorService newFixedThreadPool = Executors.newFixedThreadPool(Math.max(1, Execution.threads / Props.TRAIN_ENSEMBLE_NUMCOMPONENTS));
                Iterator it5 = arrayList8.iterator();
                while (it5.hasNext()) {
                    newFixedThreadPool.submit((Runnable) it5.next());
                }
                newFixedThreadPool.shutdown();
                try {
                    newFixedThreadPool.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS);
                } catch (InterruptedException e3) {
                    throw new RuntimeException(e3);
                }
            } else {
                Redwood.Util.threadAndRun("EPOCH " + i3 + ": Training Z classifiers", arrayList8, this.numberOfThreads);
            }
            if (this.trainY) {
                ArrayList arrayList9 = new ArrayList();
                Iterator it6 = this.yLabelIndex.iterator();
                while (it6.hasNext()) {
                    arrayList9.add(createYClassifierTrainer(linearClassifierFactory, initializeYDatasets, (String) it6.next(), i3));
                }
                if (this.partOfEnsemble && Props.TRAIN_JOINTBAYES_MULTITHREAD) {
                    Redwood.Util.log(new Object[]{"EPOCH " + i3 + ": Training Y classifiers"});
                    ExecutorService newFixedThreadPool2 = Executors.newFixedThreadPool(Math.max(1, Execution.threads / Props.TRAIN_ENSEMBLE_NUMCOMPONENTS));
                    Iterator it7 = arrayList9.iterator();
                    while (it7.hasNext()) {
                        newFixedThreadPool2.submit((Runnable) it7.next());
                    }
                    newFixedThreadPool2.shutdown();
                    try {
                        newFixedThreadPool2.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS);
                    } catch (InterruptedException e4) {
                        throw new RuntimeException(e4);
                    }
                } else {
                    Redwood.Util.threadAndRun("EPOCH " + i3 + ": Training Y classifiers", arrayList9, this.numberOfThreads);
                }
            }
            makeSingleZClassifier(initializeZDataset, zClassifierFactory);
            String makeEpochPath = makeEpochPath(i3);
            if (makeEpochPath != null) {
                try {
                    save(makeEpochPath);
                } catch (IOException e5) {
                    logger.err(new Object[]{Redwood.Util.RED, "WARNING: could not save model of epoch " + i3 + " to path: " + makeEpochPath});
                    logger.err(new Object[]{Redwood.Util.RED, "Exception message: " + e5.getMessage()});
                }
            }
            initializeYDatasets = initializeYDatasets();
            Redwood.Util.endTrack("M-STEP");
            i3++;
        }
        Redwood.Util.endTrack("EM");
        makeSingleZClassifier(initializeZDataset(i2, initializeZLabels, kBPDataset.getDataArray()), zClassifierFactory);
        Redwood.Util.startTrack(new Object[]{"Computing Statistics"});
        try {
            for (int i15 = 0; i15 < pairArr.length; i15++) {
                try {
                    if (pairArr != null && pairArr[i15] != null) {
                        Pair pair = pairArr[i15];
                        Counter[] counterArr2 = (Counter[]) pair.first;
                        double[] dArr2 = (double[]) pair.second;
                        for (int i16 = 0; i16 < counterArr2.length; i16++) {
                            empty.addInPlace(new TrainingStatistics.SentenceKey(kBPDataset.sentenceGlossKeys[i15][i16]), new TrainingStatistics.SentenceStatistics(Counters.exp(counterArr2[i16]), Math.exp(dArr2[i16])));
                        }
                    }
                } catch (OutOfMemoryError e6) {
                    Redwood.Util.err(new Object[]{e6});
                    empty = TrainingStatistics.empty();
                    Redwood.Util.endTrack("Computing Statistics");
                }
            }
            Redwood.Util.endTrack("Computing Statistics");
            this.statistics = Maybe.Just(empty);
            return empty;
        } catch (Throwable th) {
            Redwood.Util.endTrack("Computing Statistics");
            throw th;
        }
    }

    /**
     * Shuffles the groups in place with a seeded Fisher-Yates shuffle and reports where each group came from.
     *
     * <p>{@code iArr} and {@code maybeArr} are permuted in lock-step so that position {@code k} of both arrays
     * still refers to the same group after the call. When {@code Props.HACKS_SQUASHRANDOM} is set, nothing is
     * shuffled and the identity permutation is returned.
     *
     * @param iArr per-group arrays; shuffled in place
     * @param maybeArr per-group values; shuffled in place, must have the same length as {@code iArr}
     * @param i seed for the {@link Random} that drives the shuffle
     * @return {@code perm} such that the group now at position {@code k} was originally at {@code perm[k]}
     */
    private int[] randomizeGroup(int[][] iArr, Maybe<String>[] maybeArr, int i) {
        assert iArr.length == maybeArr.length;
        Random random = new Random(i);
        int[] originalPosition = new int[iArr.length];
        for (int k = 0; k < originalPosition.length; k++) {
            originalPosition[k] = k;
        }
        if (Props.HACKS_SQUASHRANDOM) {
            return originalPosition;
        }
        for (int last = iArr.length - 1; last > 0; last--) {
            // Draw from [0, last] inclusive. Drawing from [0, last) (Sattolo's algorithm) can only produce
            // single-cycle permutations, i.e. no group would ever keep its position.
            int pick = random.nextInt(last + 1);

            int[] labels = iArr[pick];
            iArr[pick] = iArr[last];
            iArr[last] = labels;

            Maybe<String> unknown = maybeArr[pick];
            maybeArr[pick] = maybeArr[last];
            maybeArr[last] = unknown;

            int origin = originalPosition[pick];
            originalPosition[pick] = originalPosition[last];
            originalPosition[last] = origin;
        }
        return originalPosition;
    }

    /**
     * Scores predicted Z labels against gold labels per group, logs the result and returns it.
     *
     * <p>For each group the predicted label set is the set of distinct Z indices in {@code iArr[g]} other than
     * {@code _NR}. Label precision/recall are micro-averaged over all groups. Group accuracy is the fraction of
     * groups whose predicted set equals the gold set exactly.
     *
     * @param str tag included in the log lines
     * @param iArr per-group predicted Z label indices (one entry per sentence)
     * @param setArr per-group gold label indices, compared directly with the indices in {@code iArr}
     * @return (precision, recall, group accuracy); any ratio with a zero denominator is reported as 0
     * @throws IllegalStateException if the Z label index has no {@code _NR} entry
     */
    Triple<Double, Double, Double> computeYScore(String str, int[][] iArr, Set<Integer>[] setArr) {
        int nrIndex = this.zLabelIndex.indexOf("_NR");
        if (nrIndex < 0) {
            throw new IllegalStateException("Unknown relation: _NR");
        }
        int correct = 0;
        int predicted = 0;
        int gold = 0;
        int exactGroups = 0;
        int groups = 0;
        for (int g = 0; g < setArr.length; g++) {
            Set<Integer> predictedLabels = new HashSet<>();
            for (int z : iArr[g]) {
                if (z != nrIndex) {
                    predictedLabels.add(z);
                }
            }
            Set<Integer> goldLabels = setArr[g];
            predicted += predictedLabels.size();
            gold += goldLabels.size();
            for (Integer label : predictedLabels) {
                if (goldLabels.contains(label)) {
                    correct++;
                }
            }
            groups++;
            if (predictedLabels.size() == goldLabels.size() && goldLabels.containsAll(predictedLabels)) {
                exactGroups++;
            }
        }
        // Floating-point division: the previous int/int division truncated every score to 0 or 1
        // and threw ArithmeticException when a denominator was zero.
        double precision = predicted == 0 ? 0.0d : (double) correct / predicted;
        double recall = gold == 0 ? 0.0d : (double) correct / gold;
        double f1 = (precision == 0.0d || recall == 0.0d) ? 0.0d : ((2.0d * precision) * recall) / (precision + recall);
        double groupAccuracy = groups == 0 ? 0.0d : (double) exactGroups / groups;
        logger.log(new Object[]{Redwood.Util.BLUE, Redwood.Util.BOLD, "LABEL SCORE for " + str + ": P " + precision + " R " + recall + " F1 " + f1});
        logger.log(new Object[]{Redwood.Util.BLUE, Redwood.Util.BOLD, "GROUP SCORE for " + str + ": A " + groupAccuracy});
        return Triple.makeTriple(Double.valueOf(precision), Double.valueOf(recall), Double.valueOf(groupAccuracy));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v4, types: [int[], int[][]] */
    public Triple<Double, Double, Double> trainingAccuracy(KBPDataset<String, String> kBPDataset) {
        int[][][] dataArray = kBPDataset.getDataArray();
        ?? r0 = new int[dataArray.length];
        for (int i = 0; i < dataArray.length; i++) {
            r0[i] = new int[dataArray[i].length];
            List<Datum<String, String>> datumGroup = kBPDataset.getDatumGroup(i);
            for (int i2 = 0; i2 < datumGroup.size(); i2++) {
                r0[i][i2] = this.zLabelIndex.indexOf(Counters.argmax(classifyLocally(datumGroup.get(i2).asFeatures())));
            }
        }
        Set<Integer>[] positiveLabelsArray = kBPDataset.getPositiveLabelsArray();
        Set[] setArr = (Set[]) ErasureUtils.uncheckedCast(new Set[positiveLabelsArray.length]);
        for (int i3 = 0; i3 < positiveLabelsArray.length; i3++) {
            setArr[i3] = new HashSet();
            Iterator<Integer> it = positiveLabelsArray[i3].iterator();
            while (it.hasNext()) {
                setArr[i3].add(Integer.valueOf(this.zLabelIndex.indexOf(kBPDataset.labelIndex().get(it.next().intValue()))));
            }
        }
        computeConfusionMatrixForCounts("TRAIN", r0, setArr);
        return computeYScore("Z ONLY", r0, setArr);
    }

    /**
     * Logs two histograms of "how many sentences in a group predicted label L".
     *
     * <p>For each group, every non-{@code _NR} predicted label contributes its per-group sentence count to the
     * POS histogram when it is among the group's gold labels, and to the NEG histogram otherwise.
     *
     * @param str tag included in the log output
     * @param iArr per-group predicted Z label indices
     * @param setArr per-group gold label indices
     * @throws IllegalStateException if the Z label index has no {@code _NR} entry
     * @deprecated superseded by {@code computeConfusionMatrixForCounts}
     */
    @Deprecated
    void computeConfusionMatrixForCounts_(String str, int[][] iArr, Set<Integer>[] setArr) {
        ClassicCounter<Integer> positiveHistogram = new ClassicCounter<>();
        ClassicCounter<Integer> negativeHistogram = new ClassicCounter<>();
        final int nrIndex = this.zLabelIndex.indexOf("_NR");
        if (nrIndex < 0) {
            throw new IllegalStateException("Unknown relation: _NR");
        }
        for (int group = 0; group < iArr.length; group++) {
            ClassicCounter<Integer> sentencesPerLabel = new ClassicCounter<>();
            for (int z : iArr[group]) {
                if (z == nrIndex) {
                    continue;
                }
                sentencesPerLabel.incrementCount(Integer.valueOf(z));
            }
            Set<Integer> goldLabels = setArr[group];
            for (Integer label : sentencesPerLabel.keySet()) {
                int occurrences = (int) sentencesPerLabel.getCount(label);
                ClassicCounter<Integer> bucket = goldLabels.contains(label) ? positiveHistogram : negativeHistogram;
                bucket.incrementCount(Integer.valueOf(occurrences));
            }
        }
        logger.log(new Object[]{"CONFUSION MATRIX for " + str});
        logger.log(new Object[]{"CONFUSION MATRIX POS: " + positiveHistogram});
        logger.log(new Object[]{"CONFUSION MATRIX NEG: " + negativeHistogram});
    }

    /**
     * Builds and logs a gold-by-predicted confusion matrix over Z labels.
     *
     * <p>Per group, labels present in both the predicted set and the gold set add 1 to their diagonal cell and
     * are removed from both sets. Each remaining predicted label then spreads 1 unit of mass evenly over the
     * remaining gold labels: row = gold label, column = predicted label.
     *
     * @param str tag used in the Redwood track name
     * @param iArr per-group predicted Z label indices
     * @param setArr per-group gold label indices (copied, not modified)
     * @throws IllegalStateException if the Z label index has no {@code _NR} entry
     */
    void computeConfusionMatrixForCounts(String str, int[][] iArr, Set<Integer>[] setArr) {
        final int labelCount = this.zLabelIndex.size();
        final Double[][] matrix = new Double[labelCount][labelCount];
        for (Double[] row : matrix) {
            Arrays.fill(row, Double.valueOf(0.0d));
        }
        if (this.zLabelIndex.indexOf("_NR") < 0) {
            throw new IllegalStateException("Unknown relation: _NR");
        }
        for (int group = 0; group < iArr.length; group++) {
            Set<Integer> unmatchedPredicted = new HashSet<>();
            for (int z : iArr[group]) {
                unmatchedPredicted.add(Integer.valueOf(z));
            }
            Set<Integer> unmatchedGold = new HashSet<>(setArr[group]);
            // First pass: exact hits go on the diagonal and leave both sets.
            for (Iterator<Integer> cursor = unmatchedPredicted.iterator(); cursor.hasNext(); ) {
                int label = cursor.next().intValue();
                if (unmatchedGold.remove(Integer.valueOf(label))) {
                    matrix[label][label] = Double.valueOf(matrix[label][label].doubleValue() + 1.0d);
                    cursor.remove();
                }
            }
            // Second pass: each leftover prediction is blamed evenly on the leftover gold labels.
            for (Integer predictedLabel : unmatchedPredicted) {
                int column = predictedLabel.intValue();
                for (Integer goldLabel : unmatchedGold) {
                    Double[] row = matrix[goldLabel.intValue()];
                    row[column] = Double.valueOf(row[column].doubleValue() + (1.0d / unmatchedGold.size()));
                }
            }
        }
        final boolean ownsTrack = !(this.partOfEnsemble && Props.TRAIN_JOINTBAYES_MULTITHREAD);
        if (ownsTrack) {
            Redwood.Util.startTrack(new Object[]{"confusion matrix (" + str + ")"});
        }
        logger.log(new Object[]{StringUtils.makeTextTable(matrix, this.zLabelIndex.objectsList().toArray(), this.zLabelIndex.objectsList().toArray(), 0, 1, true)});
        if (ownsTrack) {
            Redwood.endTrack("confusion matrix (" + str + ")");
        }
    }

    /**
     * Debug dump of one group.
     *
     * <p>Writes the group's Z labels ("ZS:") and gold Y labels ("YS:") to stderr. It then lists the Y labels that
     * no Z label matches by name ("MISSED n:"). Line breaks are emitted via {@code logger.log(new Object[0])}.
     *
     * @param iArr Z label indices of the group's sentences
     * @param set gold Y label indices of the group
     */
    void printGroup(int[] iArr, Set<Integer> set) {
        System.err.print("ZS:");
        for (int zLabel : iArr) {
            System.err.print(" " + ((String) this.zLabelIndex.get(zLabel)));
        }
        logger.log(new Object[0]);

        Set<String> missed = new HashSet<>();
        System.err.print("YS:");
        for (Integer yIndex : set) {
            String yName = (String) this.yLabelIndex.get(yIndex.intValue());
            System.err.print(" " + yName);
            boolean covered = false;
            for (int k = 0; k < iArr.length && !covered; k++) {
                covered = ((String) this.zLabelIndex.get(iArr[k])).equals(yName);
            }
            if (!covered) {
                missed.add(yName);
            }
        }
        logger.log(new Object[0]);

        if (!missed.isEmpty()) {
            System.err.print("MISSED " + missed.size() + ":");
            for (String name : missed) {
                System.err.print(" " + name);
            }
        }
        logger.log(new Object[0]);
        logger.log(new Object[]{"END GROUP"});
    }

    /**
     * Appends one Y-classifier training datum built from {@code extractYFeatures(str, iArr)}.
     *
     * @param rVFDataset dataset to append to
     * @param str the Y label the features are extracted for
     * @param iArr Z labels of the group
     * @param counterArr unused by this implementation
     * @param z whether the datum is positive; if so it is labelled {@code str}, otherwise {@code _NR}
     */
    private void addYDatum(RVFDataset<String, String> rVFDataset, String str, int[] iArr, Counter<String>[] counterArr, boolean z) {
        String datumLabel;
        if (z) {
            datumLabel = str;
        } else {
            datumLabel = "_NR";
        }
        RVFDatum datum = new RVFDatum(extractYFeatures(str, iArr), datumLabel);
        rVFDataset.add(datum);
    }

    /**
     * Computes where the model for epoch {@code i} should be saved.
     *
     * <p>The epoch tag {@code _EPOCH<i>} is inserted before a trailing {@code Props.SER_EXT} if present, and
     * appended otherwise.
     *
     * @param i zero-based epoch number
     * @return the path, or {@code null} if {@code i >= numberOfTrainEpochs} or no serialized model path is set
     */
    private String makeEpochPath(int i) {
        if (i >= this.numberOfTrainEpochs || this.serializedModelPath == null) {
            return null;
        }
        final String modelPath = this.serializedModelPath;
        if (!modelPath.endsWith(Props.SER_EXT)) {
            return modelPath + "_EPOCH" + i;
        }
        final String stem = modelPath.substring(0, modelPath.length() - Props.SER_EXT.length());
        return stem + "_EPOCH" + i + Props.SER_EXT;
    }

    /**
     * (Re)trains {@code zSingleClassifier} on {@code dataset} when running in SINGLE_MODEL mode, warm-starting
     * from the current single classifier. In any other mode the single classifier is cleared to {@code null}.
     *
     * @param dataset flat Z training data
     * @param linearClassifierFactory factory used for training
     */
    private void makeSingleZClassifier(Dataset<String, String> dataset, LinearClassifierFactory<String, String> linearClassifierFactory) {
        boolean singleModel = localClassificationMode == LOCAL_CLASSIFICATION_MODE.SINGLE_MODEL;
        if (!singleModel) {
            this.zSingleClassifier = null;
            return;
        }
        logger.log(new Object[]{"Training the final Z classifier..."});
        LinearClassifier<String, String> warmStart = this.zSingleClassifier;
        this.zSingleClassifier = linearClassifierFactory.trainClassifierWithInitialWeights(dataset, warmStart);
    }

    /**
     * Concatenates the rows of {@code iArr} into a new array of length {@code i}.
     *
     * @param iArr rows to concatenate, in order
     * @param i length of the result; positions past the concatenated data stay 0
     * @return the concatenated values
     * @throws ArrayIndexOutOfBoundsException if the rows hold more than {@code i} values in total
     */
    private int[] flatten(int[][] iArr, int i) {
        int[] flat = new int[i];
        int cursor = 0;
        for (int[] row : iArr) {
            System.arraycopy(row, 0, flat, cursor, row.length);
            cursor += row.length;
        }
        return flat;
    }

    /* JADX WARN: Type inference failed for: r0v7, types: [int[], int[][]] */
    private static Dataset<String, String> makeLocalData(int[][][] iArr, Set<Integer>[] setArr, Set<Integer>[] setArr2, Index<String> index, Index<String> index2, LocalFilter localFilter, int i) {
        int i2 = 0;
        int i3 = 0;
        int i4 = 0;
        for (int i5 = 0; i5 < iArr.length; i5++) {
            if (localFilter.filterZ(iArr[i5], setArr[i5])) {
                if (setArr[i5].size() != 0 || setArr2[i5].size() <= 0) {
                    i2 += iArr[i5].length * setArr[i5].size();
                    i3++;
                } else {
                    i2 += iArr[i5].length;
                    i4++;
                }
            }
        }
        logger.log(new Object[]{"Explored " + i3 + " positive groups and " + i4 + " negative groups, yielding " + i2 + " flat/local datums."});
        ?? r0 = new int[i2];
        int[] iArr2 = new int[i2];
        float[] fArr = new float[i2];
        int i6 = 0;
        int i7 = 0;
        HashSet hashSet = new HashSet();
        int indexOf = index.indexOf("_NR");
        if (indexOf < 0) {
            throw new IllegalStateException("Unknown relation: _NR");
        }
        hashSet.add(Integer.valueOf(indexOf));
        for (int i8 = 0; i8 < iArr.length; i8++) {
            if (localFilter.filterZ(iArr[i8], setArr[i8])) {
                int[][] iArr3 = iArr[i8];
                Set<Integer> set = setArr[i8];
                if (set.size() == 0) {
                    set = hashSet;
                }
                float size = 1.0f / set.size();
                for (Integer num : set) {
                    for (int[] iArr4 : iArr3) {
                        r0[i6] = iArr4;
                        iArr2[i6] = num.intValue();
                        fArr[i6] = size;
                        if (num.intValue() != indexOf) {
                            i7++;
                        }
                        i6++;
                        if (i6 >= i2) {
                            break;
                        }
                    }
                    if (i6 >= i2) {
                        break;
                    }
                }
                if (i6 >= i2) {
                    break;
                }
            }
        }
        WeightedDataset weightedDataset = new WeightedDataset(index, iArr2, index2, (int[][]) r0, r0.length, fArr);
        logger.log(new Object[]{"Fold #" + i + ": Constructed a dataset with " + r0.length + " datums, out of which " + i7 + " are positive."});
        if (i7 == 0) {
            throw new RuntimeException("ERROR: cannot handle a dataset with 0 positive examples!");
        }
        return weightedDataset;
    }

    /**
     * Returns the training portion of {@code iArr} for fold {@code i}: every element outside
     * {@code [foldStart, foldEnd)}, in original order.
     *
     * @param iArr per-group values
     * @param i fold number
     * @return a new array without the held-out fold
     */
    private int[] makeTrainLabelArrayForFold(int[] iArr, int i) {
        final int heldOutBegin = foldStart(i, iArr.length);
        final int heldOutEnd = foldEnd(i, iArr.length);
        final int tailLength = iArr.length - heldOutEnd;
        int[] train = new int[heldOutBegin + tailLength];
        System.arraycopy(iArr, 0, train, 0, heldOutBegin);
        System.arraycopy(iArr, heldOutEnd, train, heldOutBegin, tailLength);
        return train;
    }

    /* JADX WARN: Type inference failed for: r0v8, types: [int[], int[][]] */
    private int[][] makeTrainDataArrayForFold(int[][] iArr, int i) {
        int foldStart = foldStart(i, iArr.length);
        int foldEnd = foldEnd(i, iArr.length);
        ?? r0 = new int[(iArr.length - foldEnd) + foldStart];
        int i2 = 0;
        for (int i3 = 0; i3 < foldStart; i3++) {
            r0[i2] = iArr[i3];
            i2++;
        }
        for (int i4 = foldEnd; i4 < iArr.length; i4++) {
            r0[i2] = iArr[i4];
            i2++;
        }
        return r0;
    }

    /**
     * Splits grouped data into the training part (outside {@code [foldStart, foldEnd)}) and the held-out part
     * (inside it) for fold {@code i}. Order is preserved and groups are shared, not copied.
     *
     * @param iArr per-group sentence feature arrays
     * @param i fold number
     * @return (training groups, held-out groups)
     */
    private Pair<int[][][], int[][][]> makeDataArraysForFold(int[][][] iArr, int i) {
        int foldStart = foldStart(i, iArr.length);
        int foldEnd = foldEnd(i, iArr.length);
        int heldOutSize = foldEnd - foldStart;
        // Each element is a whole group (int[][]), so the split arrays must be int[][][]; the decompiled int[][]
        // declarations did not type-check.
        int[][][] train = new int[iArr.length - heldOutSize][][];
        int[][][] heldOut = new int[heldOutSize][][];
        System.arraycopy(iArr, 0, train, 0, foldStart);
        System.arraycopy(iArr, foldStart, heldOut, 0, heldOutSize);
        System.arraycopy(iArr, foldEnd, train, foldStart, iArr.length - foldEnd);
        return new Pair<>(train, heldOut);
    }

    /**
     * Splits per-group label sets into the training part (outside {@code [foldStart, foldEnd)}) and the
     * held-out part (inside it) for fold {@code i}. Order is preserved and the sets are shared, not copied.
     *
     * @param setArr per-group label sets
     * @param i fold number
     * @return (training label sets, held-out label sets)
     */
    private Pair<Set<Integer>[], Set<Integer>[]> makeLabelSetsForFold(Set<Integer>[] setArr, int i) {
        int foldStart = foldStart(i, setArr.length);
        int foldEnd = foldEnd(i, setArr.length);
        int heldOutSize = foldEnd - foldStart;
        // Element type must be Set, not HashSet: the inputs are arbitrary Set<Integer> instances.
        Set<Integer>[] train = ErasureUtils.uncheckedCast(new Set[setArr.length - heldOutSize]);
        Set<Integer>[] heldOut = ErasureUtils.uncheckedCast(new Set[heldOutSize]);
        System.arraycopy(setArr, 0, train, 0, foldStart);
        System.arraycopy(setArr, foldStart, heldOut, 0, heldOutSize);
        System.arraycopy(setArr, foldEnd, train, foldStart, setArr.length - foldEnd);
        return new Pair<>(train, heldOut);
    }

    /**
     * Creates a fresh factory for Z classifiers.
     *
     * <p>The factory is built with tolerance 1e-4, {@code false} as the second constructor flag, and sigma
     * {@code zSigma}. For the SGD and SGDTOQN minimizer types it is switched to in-place SGD (75 passes) or the
     * SGD-then-QN hybrid (10 SGD passes). Any other type keeps the factory's default minimizer.
     *
     * <p>NOTE(review): the second argument to both SGD setters is {@code Props.MAX_DISTANCE_BETWEEN_ENTITY_AND_SLOT},
     * which looks unrelated to SGD tuning. It may be a decompiler constant substitution — confirm the intended value.
     *
     * @return the configured factory
     */
    private LinearClassifierFactory<String, String> getZClassifierFactory() {
        final LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>(1.0E-4d, false, this.zSigma);
        switch (this.zClassifierMinimizerType) {
            case SGD:
                factory.useInPlaceStochasticGradientDescent(75, Props.MAX_DISTANCE_BETWEEN_ENTITY_AND_SLOT, this.zSigma);
                break;
            case SGDTOQN:
                factory.useHybridMinimizerWithInPlaceSGD(10, Props.MAX_DISTANCE_BETWEEN_ENTITY_AND_SLOT, this.zSigma);
                break;
            default:
                // Keep the factory's default minimizer.
                break;
        }
        return factory;
    }

    /**
     * Builds a task that trains the local Z classifier for fold {@code i} and stores it in
     * {@code linearClassifierArr[i]}.
     *
     * <p>The task trains on all groups outside fold {@code i} (via {@code makeLocalData}). It then logs the
     * micro-averaged label precision/recall/F1 of that classifier on the groups inside fold {@code i}: each
     * sentence's top prediction counts if it is not {@code _NR}.
     *
     * @param linearClassifierArr output array; slot {@code i} is written under {@code this.lock}
     * @param kBPDataset source of data and positive/negative label sets
     * @param index feature index
     * @param index2 Z label index (must contain {@code _NR})
     * @param i fold number
     * @return the task
     */
    private Runnable createLocalZClassifierInitializer(LinearClassifier<String, String>[] linearClassifierArr, KBPDataset<String, String> kBPDataset, Index<String> index, Index<String> index2, int i) {
        int[][][] dataArray = kBPDataset.getDataArray();
        Set<Integer>[] positiveLabelsArray = kBPDataset.getPositiveLabelsArray();
        Set<Integer>[] negativeLabelsArray = kBPDataset.getNegativeLabelsArray();
        return () -> {
            String trackName = "Initializing local Z classifier for fold #" + i;
            boolean ownsTrack = !this.partOfEnsemble || !Props.TRAIN_JOINTBAYES_MULTITHREAD;
            if (ownsTrack) {
                Redwood.Util.startTrack(new Object[]{trackName});
            }
            try {
                logger.log(new Object[]{"Constructing dataset for the local model in fold #" + i + "..."});
                Pair<int[][][], int[][][]> dataSplit = makeDataArraysForFold(dataArray, i);
                Pair<Set<Integer>[], Set<Integer>[]> positiveSplit = makeLabelSetsForFold(positiveLabelsArray, i);
                Pair<Set<Integer>[], Set<Integer>[]> negativeSplit = makeLabelSetsForFold(negativeLabelsArray, i);
                int[][][] trainData = dataSplit.first();
                Set<Integer>[] trainPositive = positiveSplit.first();
                Set<Integer>[] trainNegative = negativeSplit.first();
                int[][][] heldOutData = dataSplit.second();
                Set<Integer>[] heldOutPositive = positiveSplit.second();

                Dataset<String, String> localData = makeLocalData(trainData, trainPositive, trainNegative, index2, index, this.localDataFilter, i);
                logger.log(new Object[]{"Fold #" + i + ": Training local model..."});
                LinearClassifier<String, String> classifier = getZClassifierFactory().trainClassifier(localData);
                logger.log(new Object[]{"Fold #" + i + ": Training of the local classifier completed."});

                int nrIndex = index2.indexOf("_NR");
                if (nrIndex < 0) {
                    throw new IllegalStateException("Unknown relation: _NR");
                }
                logger.log(new Object[]{"Fold #" + i + ": Evaluating the local classifier on the hierarchical dataset..."});
                int gold = 0;
                int predicted = 0;
                int correct = 0;
                for (int g = 0; g < heldOutData.length; g++) {
                    Set<Integer> goldLabels = heldOutPositive[g];
                    Set<Integer> predictedLabels = new HashSet<>();
                    for (int[] sentence : heldOutData[g]) {
                        List<Pair<String, Double>> ranked = sortPredictions(classifier.scoresOf(sentence));
                        String best = ranked.get(0).first();
                        int bestIndex = index2.indexOf(best);
                        if (bestIndex < 0) {
                            throw new IllegalStateException("Unknown relation: " + best);
                        }
                        if (bestIndex != nrIndex) {
                            predictedLabels.add(bestIndex);
                        }
                    }
                    gold += goldLabels.size();
                    predicted += predictedLabels.size();
                    for (Integer label : predictedLabels) {
                        if (goldLabels.contains(label)) {
                            correct++;
                        }
                    }
                }
                // Floating-point division: int/int truncated P/R to 0 or 1 and threw on a zero denominator.
                double precision = predicted == 0 ? 0.0d : (double) correct / predicted;
                double recall = gold == 0 ? 0.0d : (double) correct / gold;
                double f1 = (precision == 0.0d || recall == 0.0d) ? 0.0d : ((2.0d * precision) * recall) / (precision + recall);
                logger.log(new Object[]{"Fold #" + i + ": Training score on the hierarchical dataset: P " + precision + " R " + recall + " F1 " + f1});
                logger.log(new Object[]{"Fold #" + i + ": Created the Z classifier with " + index2.size() + " labels and " + index.size() + " features."});
                synchronized (this.lock) {
                    linearClassifierArr[i] = classifier;
                }
            } finally {
                // Close the track even if training fails, so Redwood's track nesting stays balanced.
                if (ownsTrack) {
                    Redwood.endTrack(trackName);
                }
            }
        };
    }

    /**
     * Creates one initial Z classifier per fold, as selected by {@code Props.TRAIN_JOINTBAYES_INITIALIZATION}.
     *
     * <ul>
     *   <li>SUPERVISED: if annotated sentences are enabled, one {@code SupervisedExtractor} classifier is shared
     *       by every fold. Otherwise a warning is logged and DISTSUP is used.</li>
     *   <li>DISTSUP: one local classifier per fold is trained in parallel via
     *       {@code createLocalZClassifierInitializer}.</li>
     * </ul>
     *
     * @param kBPDataset training data
     * @param index feature index
     * @param index2 Z label index
     * @return an array of {@code numberOfFolds} classifiers
     * @throws IllegalStateException for an unknown initialization method
     */
    private LinearClassifier<String, String>[] initializeZClassifierLocally(KBPDataset<String, String> kBPDataset, Index<String> index, Index<String> index2) {
        @SuppressWarnings("unchecked")
        LinearClassifier<String, String>[] linearClassifierArr = new LinearClassifier[this.numberOfFolds];
        List<Runnable> initializers = new ArrayList<>();
        switch (Props.TRAIN_JOINTBAYES_INITIALIZATION) {
            case SUPERVISED:
                if (Props.TRAIN_ANNOTATED_SENTENCES_DO) {
                    logger.log(new Object[]{"initializing from SUPERVISED classifier"});
                    LinearClassifier<String, String> asClassifier = new SupervisedExtractor(new Properties()).setIndices(index2, index).asClassifier(Maybe.Just(kBPDataset));
                    Arrays.fill(linearClassifierArr, asClassifier);
                    assert linearClassifierArr.length > 0;
                    assert linearClassifierArr[0] != null;
                    return linearClassifierArr;
                }
                // Without annotated sentences there is no supervised model; the decompiled code "broke" out of
                // the switch with no return. Fall back to distant supervision instead.
                logger.err(new Object[]{"SUPERVISED initialization requested but annotated sentences are disabled; falling back to DISTANT SUPERVISION"});
                // fall through
            case DISTSUP:
                logger.log(new Object[]{"initializing from DISTANT SUPERVISION classifier"});
                for (int fold = 0; fold < this.numberOfFolds; fold++) {
                    initializers.add(createLocalZClassifierInitializer(linearClassifierArr, kBPDataset, index, index2, fold));
                }
                Redwood.Util.threadAndRun("Initialize local Z classifiers", initializers, this.numberOfThreads);
                assert linearClassifierArr.length > 0;
                assert linearClassifierArr[0] != null;
                return linearClassifierArr;
            default:
                throw new IllegalStateException("Unknown initialization method: " + Props.TRAIN_JOINTBAYES_INITIALIZATION);
        }
    }

    /**
     * Loads indices and initial Z/Y classifiers written by {@code saveInitialModels}.
     *
     * <p>Read order: feature index, Z label index, Y label index, fold count, one Z classifier per fold, Y
     * classifier count, then (label, classifier) pairs.
     *
     * <p>NOTE(review): this uses Java native deserialization; only load files from trusted sources.
     *
     * @param str path of the serialized models
     * @throws IllegalStateException if a model is already loaded (checked before the file is opened)
     * @throws IOException on read failure
     * @throws ClassNotFoundException if a serialized class is missing
     */
    private void loadInitialModels(String str) throws IOException, ClassNotFoundException {
        if (this.featureIndex != null) {
            throw new IllegalStateException("Loading over a trained model!");
        }
        // try-with-resources on both streams: previously any exception (including the check above, which ran
        // after opening) leaked the file handle.
        try (FileInputStream fileIn = new FileInputStream(str);
             ObjectInputStream objectInputStream = new ObjectInputStream(fileIn)) {
            this.featureIndex = (Index) ErasureUtils.uncheckedCast(objectInputStream.readObject());
            this.zLabelIndex = (Index) ErasureUtils.uncheckedCast(objectInputStream.readObject());
            this.yLabelIndex = (Index) ErasureUtils.uncheckedCast(objectInputStream.readObject());
            this.numberOfFolds = objectInputStream.readInt();
            this.zClassifiers = new LinearClassifier[this.numberOfFolds];
            for (int fold = 0; fold < this.numberOfFolds; fold++) {
                this.zClassifiers[fold] = (LinearClassifier) ErasureUtils.uncheckedCast(objectInputStream.readObject());
                logger.log(new Object[]{"Loaded Z classifier for fold #" + fold + ": " + this.zClassifiers[fold]});
            }
            int yCount = objectInputStream.readInt();
            this.yClassifiers = new HashMap();
            for (int k = 0; k < yCount; k++) {
                // Key is serialized before its classifier; read them in that order explicitly.
                String yLabel = (String) ErasureUtils.uncheckedCast(objectInputStream.readObject());
                LinearClassifier yClassifier = (LinearClassifier) ErasureUtils.uncheckedCast(objectInputStream.readObject());
                this.yClassifiers.put(yLabel, yClassifier);
            }
        }
    }

    /**
     * Serializes the indices and initial Z/Y classifiers in the order read back by {@code loadInitialModels}.
     *
     * @param str output path
     * @throws IOException on write failure (the stream is closed in all cases)
     */
    private void saveInitialModels(String str) throws IOException {
        try (FileOutputStream fileOut = new FileOutputStream(str);
             ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOut)) {
            objectOutputStream.writeObject(this.featureIndex);
            objectOutputStream.writeObject(this.zLabelIndex);
            objectOutputStream.writeObject(this.yLabelIndex);
            objectOutputStream.writeInt(this.zClassifiers.length);
            for (LinearClassifier<String, String> zClassifier : this.zClassifiers) {
                objectOutputStream.writeObject(zClassifier);
            }
            objectOutputStream.writeInt(this.yClassifiers.keySet().size());
            for (String yLabel : this.yClassifiers.keySet()) {
                objectOutputStream.writeObject(yLabel);
                objectOutputStream.writeObject(this.yClassifiers.get(yLabel));
            }
        }
    }

    /**
     * Builds one hand-initialized binary Y classifier per relation in {@code index}.
     *
     * <p>Each classifier's feature index contains {@code Y_FEATURES_FOR_INITIAL_MODEL}. Its label
     * index is {@code [relation, "_NR"]}. Its weights start at zero and are then set by
     * {@code setYWeightsForAtLeastOnce}.
     *
     * @param index relation labels to create classifiers for
     * @return map from relation label to its initial classifier
     */
    private Map<String, LinearClassifier<String, String>> initializeYClassifiersWithAtLeastOnce(Index<String> index) {
        Map<String, LinearClassifier<String, String>> classifiers = new HashMap<>();
        for (String relation : index) {
            Index<String> features = new HashIndex<>();
            features.addAll(Y_FEATURES_FOR_INITIAL_MODEL);

            Index<String> labels = new HashIndex<>();
            labels.add(relation);
            labels.add("_NR");

            double[][] weights = initializeWeights(features.size(), labels.size());
            setYWeightsForAtLeastOnce(weights, features, labels);
            classifiers.put(relation, new LinearClassifier<>(weights, features, labels));
            logger.log("Created the classifier for Y=" + relation + " with " + features.size() + " features");
        }
        return classifiers;
    }

    /**
     * Writes the fixed initial weights of a binary Y classifier into {@code dArr}
     * (rows = features, columns = labels).
     *
     * <p>The "_NR" label (case-insensitive) is the negative column. Any other label is the
     * positive column. Weights of 10.0 are placed on:
     * <ul>
     *   <li>(ATLEASTONCE_FEAT, positive), if that feature is indexed;</li>
     *   <li>(SIGMOID_FEAT, positive), if that feature is indexed;</li>
     *   <li>(NONE_FEAT, negative), unconditionally.</li>
     * </ul>
     *
     * @param dArr   weight matrix to modify in place
     * @param index  feature index
     * @param index2 label index
     */
    private static void setYWeightsForAtLeastOnce(double[][] dArr, Index<String> index, Index<String> index2) {
        int positiveColumn = -1;
        int negativeColumn = -1;
        for (String label : index2) {
            boolean isNoRelation = label.equalsIgnoreCase("_NR");
            if (!isNoRelation) {
                Redwood.Util.debug("posLabel = " + label);
            }
            int column = index2.indexOf(label);
            if (column < 0) {
                throw new IllegalStateException("Unknown relation: " + label);
            }
            if (isNoRelation) {
                negativeColumn = column;
            } else {
                positiveColumn = column;
            }
        }
        if (!$assertionsDisabled && positiveColumn == -1) {
            throw new AssertionError();
        }
        if (!$assertionsDisabled && negativeColumn == -1) {
            throw new AssertionError();
        }

        int atLeastOnceRow = index.indexOf(ATLEASTONCE_FEAT);
        int noneRow = index.indexOf(NONE_FEAT);
        int sigmoidRow = index.indexOf(SIGMOID_FEAT);
        if (atLeastOnceRow >= 0) {
            dArr[atLeastOnceRow][positiveColumn] = 10.0d;
        }
        if (sigmoidRow >= 0) {
            dArr[sigmoidRow][positiveColumn] = 10.0d;
        }
        dArr[noneRow][negativeColumn] = 10.0d;
        Redwood.Util.debug("posLabel = " + positiveColumn + ", negLabel = " + negativeColumn + ", atLeastOnceIndex = " + atLeastOnceRow);
    }

    /**
     * Allocates a {@code rows x cols} weight matrix with every entry equal to 0.0.
     *
     * @param rows number of rows
     * @param cols number of columns
     * @return a fresh zero matrix
     */
    private static double[][] initializeWeights(int rows, int cols) {
        // New double arrays are zero-initialized by the JVM, so no explicit fill is needed.
        return new double[rows][cols];
    }

    /* JADX WARN: Type inference failed for: r0v1, types: [int[], int[][]] */
    private Dataset<String, String> initializeZDataset(int i, int[][] iArr, int[][][] iArr2) {
        ?? r0 = new int[i];
        int i2 = 0;
        for (int i3 = 0; i3 < iArr2.length; i3++) {
            for (int i4 = 0; i4 < iArr2[i3].length; i4++) {
                int i5 = i2;
                i2++;
                r0[i5] = iArr2[i3][i4];
            }
        }
        int[] flatten = flatten(iArr, i);
        logger.log(new Object[]{"Created the Z dataset with " + flatten.length + " datums."});
        return new Dataset<>(this.zLabelIndex, flatten, this.featureIndex, (int[][]) r0);
    }

    /**
     * Creates one empty real-valued dataset per label in {@code yLabelIndex}.
     *
     * @return map from Y label to a new, empty {@code RVFDataset}
     */
    private Map<String, RVFDataset<String, String>> initializeYDatasets() {
        Map<String, RVFDataset<String, String>> datasets = new HashMap<>();
        for (Object label : this.yLabelIndex.objectsList()) {
            datasets.put((String) label, new RVFDataset<>());
        }
        return datasets;
    }

    /**
     * Predicts a Z label for every mention with {@code linearClassifier}, then scores the
     * requested Y labels against those predictions.
     *
     * @param linearClassifier Z classifier used for the per-mention argmax
     * @param iArr             per-mention feature-index arrays
     * @param set              Y label indices to score
     * @return map from Y label index to its log-probability given the predicted Z labels
     */
    protected Counter<Integer> computeYLogProbs(LinearClassifier<String, String> linearClassifier, int[][] iArr, Set<Integer> set) {
        int[] predictedZ = new int[iArr.length];
        predictZLabels(iArr, predictedZ, linearClassifier);
        return computeYLogProbs(predictedZ, set);
    }

    /**
     * Scores each requested Y label with its own Y classifier.
     *
     * <p>Each classifier sees features extracted from the given Z assignment. The value stored
     * is the log-probability of the positive class (the relation name itself).
     *
     * @param iArr Z label index per mention
     * @param set  Y label indices to score
     * @return map from Y label index to log P(y = relation | z)
     */
    protected Counter<Integer> computeYLogProbs(int[] iArr, Set<Integer> set) {
        Counter<Integer> yLogProbs = new ClassicCounter<>();
        for (int yLabel : set) {
            String relation = (String) this.yLabelIndex.get(yLabel);
            RVFDatum<String, String> datum = new RVFDatum<>(extractYFeatures(relation, iArr), "");
            double logProb = this.yClassifiers.get(relation).logProbabilityOf(datum).getCount(relation);
            yLogProbs.setCount(yLabel, logProb);
        }
        return yLogProbs;
    }

    /**
     * Writes, for each mention, the index (in {@code zLabelIndex}) of the label with the highest
     * log-probability under {@code linearClassifier}.
     *
     * @param iArr             per-mention feature-index arrays
     * @param iArr2            output array, same length as {@code iArr}
     * @param linearClassifier Z classifier
     */
    private void predictZLabels(int[][] iArr, int[] iArr2, LinearClassifier<String, String> linearClassifier) {
        int mention = 0;
        for (int[] features : iArr) {
            Counter<String> scores = linearClassifier.logProbabilityOf(features);
            String best = Counters.argmax(scores);
            iArr2[mention++] = this.zLabelIndex.indexOf(best);
        }
    }

    /**
     * Fills {@code counterArr[m]} with a log-distribution over Z labels for every mention m.
     *
     * <p>Let {@code trust = Props.TRAIN_JOINTBAYES_TURKERTRUST}.
     * <ul>
     *   <li>If mention m has no externally supplied label ({@code maybeArr[m]} undefined), or if
     *       {@code trust <= 0}, the classifier's log-probabilities are used as-is.</li>
     *   <li>Otherwise, when {@code trust < 1}, the supplied label gets probability {@code trust}.
     *       The classifier's distribution over the remaining labels is renormalized and scaled to
     *       share {@code 1 - trust}.</li>
     *   <li>When {@code trust >= 1}, all mass goes to the supplied label.</li>
     * </ul>
     * The supplied label is presumably a crowd-sourced annotation, judging by the property
     * name (unconfirmed).
     *
     * @param iArr             per-mention feature-index arrays
     * @param counterArr       output array, one log-distribution per mention
     * @param maybeArr         optional externally supplied label per mention
     * @param linearClassifier Z classifier
     * @param i                unused
     */
    private void computeZLogProbs(int[][] iArr, Counter<String>[] counterArr, Maybe<String>[] maybeArr, LinearClassifier<String, String> linearClassifier, int i) {
        final double trust = Props.TRAIN_JOINTBAYES_TURKERTRUST;
        for (int m = 0; m < iArr.length; m++) {
            boolean classifierOnly = !maybeArr[m].isDefined() || trust <= 0.0d;
            if (classifierOnly) {
                counterArr[m] = linearClassifier.logProbabilityOf(iArr[m]);
            } else {
                String supplied = maybeArr[m].get();
                Counter<String> dist;
                if (trust < 1.0d) {
                    dist = linearClassifier.logProbabilityOf(iArr[m]);
                    counterArr[m] = dist;
                    // Back to probability space, drop the supplied label, and give the rest 1 - trust.
                    Counters.expInPlace(dist);
                    dist.remove(supplied);
                    Counters.normalize(dist);
                    Counters.multiplyInPlace(dist, 1.0d - trust);
                } else {
                    dist = new ClassicCounter<>();
                    counterArr[m] = dist;
                }
                dist.setCount(supplied, trust);
                Counters.normalize(dist);
                Counters.logInPlace(dist);
                if (trust == 1.0d) {
                    // log(1) should already be ~0; pin it to exactly 0.
                    if (!$assertionsDisabled && Math.abs(dist.getCount(supplied)) >= 1.0E-4d) {
                        throw new AssertionError();
                    }
                    dist.setCount(supplied, 0.0d);
                }
            }
            if (!$assertionsDisabled && Math.abs(Counters.exp(counterArr[m]).totalCount() - 1.0d) >= 0.001d) {
                throw new AssertionError();
            }
        }
    }

    /**
     * Makes a single left-to-right pass that re-assigns each mention's Z label to the one that
     * maximizes the joint score. The joint score is:
     * <pre>
     *   local log P(z) + sum over positive y of log P(y | z) + sum over negative y of log P(_NR | z)
     * </pre>
     * The Y terms are computed with the other mentions at their current labels. A mention's new
     * label is visible when scoring the mentions after it.
     *
     * @param iArr             per-mention feature-index arrays
     * @param set              positive Y label indices
     * @param set2             negative Y label indices
     * @param iArr2            current Z label per mention; updated in place
     * @param maybeArr         optional supplied label per mention, forwarded to computeZLogProbs
     * @param counterArr       output: log-normalized joint scores per mention
     * @param linearClassifier Z classifier used for the local scores
     * @param i                forwarded to computeZLogProbs (which does not use it)
     * @return best joint score found for each mention
     */
    private double[] inferZLabelsStable(int[][] iArr, Set<Integer> set, Set<Integer> set2, int[] iArr2, Maybe<String>[] maybeArr, Counter<String>[] counterArr, LinearClassifier<String, String> linearClassifier, int i) {
        if (!$assertionsDisabled && maybeArr.length != iArr2.length) {
            throw new AssertionError();
        }
        if (!$assertionsDisabled && iArr.length != iArr2.length) {
            throw new AssertionError();
        }
        if (!$assertionsDisabled && counterArr.length != iArr2.length) {
            throw new AssertionError();
        }
        int numMentions = iArr.length;
        // Local scores go in their own array so counterArr can receive the joint scores.
        Counter<String>[] localLogProbs = ErasureUtils.uncheckedCast(new Counter[numMentions]);
        computeZLogProbs(iArr, localLogProbs, maybeArr, linearClassifier, i);
        double[] bestScores = new double[numMentions];
        for (int mention = 0; mention < numMentions; mention++) {
            Counter<String> local = localLogProbs[mention];
            if (!$assertionsDisabled && Math.abs(Counters.exp(local).totalCount() - 1.0d) >= 1.0E-5d) {
                throw new AssertionError();
            }
            Counter<String> joint = new ClassicCounter<>();
            int previousLabel = iArr2[mention];
            int bestLabel = -1;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (String candidate : local.keySet()) {
                // Tentatively assign the candidate so extractYFeatures sees it.
                iArr2[mention] = this.zLabelIndex.indexOf(candidate);
                double score = local.getCount(candidate);
                for (int yLabel : set) {
                    String relation = (String) this.yLabelIndex.get(yLabel);
                    RVFDatum<String, String> yDatum = new RVFDatum<>(extractYFeatures(relation, iArr2), "");
                    score += this.yClassifiers.get(relation).logProbabilityOf(yDatum).getCount(relation);
                }
                for (int yLabel : set2) {
                    String relation = (String) this.yLabelIndex.get(yLabel);
                    RVFDatum<String, String> yDatum = new RVFDatum<>(extractYFeatures(relation, iArr2), "");
                    score += this.yClassifiers.get(relation).logProbabilityOf(yDatum).getCount("_NR");
                }
                joint.setCount(candidate, score);
                if (score > bestScore) {
                    bestScore = score;
                    bestLabel = iArr2[mention];
                }
            }
            bestScores[mention] = bestScore;
            boolean changed = bestLabel != -1 && bestLabel != previousLabel;
            iArr2[mention] = changed ? bestLabel : previousLabel;
            if (changed) {
                this.zUpdatesInOneEpoch.getAndIncrement();
            }
            Counters.logNormalizeInPlace(joint);
            counterArr[mention] = joint;
        }
        return bestScores;
    }

    /**
     * Greedy hill-climbing over the mentions' Z labels. The joint score for candidate label z of
     * mention m is:
     * <pre>
     *   local log P(z) + sum over positive y of log P(y | z) + sum over negative y of log P(_NR | z)
     * </pre>
     * The Y terms are computed with every other mention at its current label.
     *
     * <p>Each round scores every not-yet-frozen mention. Among the mentions whose best label
     * differs from the current one, and whose candidate scores are not all identical, it picks
     * the one with the highest best score. That mention is flipped to its best label and frozen.
     * The search stops when a round finds nothing to flip.
     *
     * <p>The local scores are kept in a private array. Earlier code stored them in
     * {@code counterArr} and then overwrote that array with joint scores, so later rounds
     * counted the Y factors again.
     *
     * @param iArr             per-mention feature-index arrays
     * @param set              positive Y label indices
     * @param set2             negative Y label indices
     * @param iArr2            current Z label per mention; updated in place
     * @param maybeArr         optional supplied label per mention, forwarded to computeZLogProbs
     * @param counterArr       output: log-normalized joint scores from each mention's last evaluation
     * @param linearClassifier Z classifier used for the local scores
     * @param i                forwarded to computeZLogProbs (which does not use it)
     * @return each mention's best joint score from its last evaluation
     */
    private double[] inferZLabels(int[][] iArr, Set<Integer> set, Set<Integer> set2, int[] iArr2, Maybe<String>[] maybeArr, Counter<String>[] counterArr, LinearClassifier<String, String> linearClassifier, int i) {
        if (!$assertionsDisabled && maybeArr.length != iArr2.length) {
            throw new AssertionError();
        }
        if (!$assertionsDisabled && iArr.length != iArr2.length) {
            throw new AssertionError();
        }
        if (!$assertionsDisabled && counterArr.length != iArr2.length) {
            throw new AssertionError();
        }
        int numMentions = iArr.length;
        // Local log-probabilities never change during the search; keep them apart from counterArr.
        Counter<String>[] localLogProbs = ErasureUtils.uncheckedCast(new Counter[numMentions]);
        computeZLogProbs(iArr, localLogProbs, maybeArr, linearClassifier, i);
        Set<Integer> frozen = new HashSet<>();
        double[] bestScores = new double[numMentions];
        while (true) {
            double flipScore = Double.NEGATIVE_INFINITY;
            int flipLabel = -1;
            int flipMention = -1;
            for (int mention = 0; mention < numMentions; mention++) {
                if (frozen.contains(mention)) {
                    continue;
                }
                Counter<String> local = localLogProbs[mention];
                Counter<String> joint = new ClassicCounter<>();
                int currentLabel = iArr2[mention];
                double bestScore = Double.NEGATIVE_INFINITY;
                int bestLabel = -1;
                // Candidate order decides ties, because only strictly greater scores replace the best.
                for (String candidate : Utils.sortRelationsByPrior(local.keySet())) {
                    iArr2[mention] = this.zLabelIndex.indexOf(candidate);
                    double score = local.getCount(candidate);
                    for (int yLabel : set) {
                        String relation = (String) this.yLabelIndex.get(yLabel);
                        RVFDatum<String, String> yDatum = new RVFDatum<>(extractYFeatures(relation, iArr2), "");
                        score += this.yClassifiers.get(relation).logProbabilityOf(yDatum).getCount(relation);
                    }
                    for (int yLabel : set2) {
                        String relation = (String) this.yLabelIndex.get(yLabel);
                        RVFDatum<String, String> yDatum = new RVFDatum<>(extractYFeatures(relation, iArr2), "");
                        score += this.yClassifiers.get(relation).logProbabilityOf(yDatum).getCount("_NR");
                    }
                    joint.setCount(candidate, score);
                    if (score > bestScore) {
                        bestScore = score;
                        bestLabel = iArr2[mention];
                    }
                }
                // Restore the current assignment before scoring the next mention.
                iArr2[mention] = currentLabel;
                bestScores[mention] = bestScore;
                if (bestLabel != -1 && bestLabel != currentLabel && !uniformDistribution(joint) && bestScore > flipScore) {
                    flipScore = bestScore;
                    flipLabel = bestLabel;
                    flipMention = mention;
                }
                Counters.logNormalizeInPlace(joint);
                counterArr[mention] = joint;
            }
            if (flipLabel == -1) {
                return bestScores;
            }
            if (!$assertionsDisabled && flipMention == -1) {
                throw new AssertionError();
            }
            iArr2[flipMention] = flipLabel;
            this.zUpdatesInOneEpoch.getAndIncrement();
            frozen.add(flipMention);
        }
    }

    /**
     * Reports whether every key of {@code counter} carries exactly the same count.
     *
     * <p>Counters with fewer than two keys are not considered uniform. Counts are compared with
     * {@code ==}, so a NaN count makes the counter non-uniform.
     *
     * @param counter scores to inspect
     * @return true iff there are at least two keys and all counts are equal
     */
    private static boolean uniformDistribution(Counter<String> counter) {
        Set<String> keys = counter.keySet();
        if (keys.size() < 2) {
            return false;
        }
        Iterator<String> it = keys.iterator();
        double reference = counter.getCount(it.next());
        while (it.hasNext()) {
            if (counter.getCount(it.next()) != reference) {
                return false;
            }
        }
        return true;
    }

    /**
     * Extracts the Y-classifier features for relation {@code str} from a Z-label assignment.
     * This is the three-argument overload with no per-mention scores.
     *
     * @param str  relation label the features are computed for
     * @param iArr Z label index per mention
     * @return feature counter for the Y classifier of {@code str}
     */
    private Counter<String> extractYFeatures(String str, int[] iArr) {
        Counter<String>[] unusedScores = null;
        return extractYFeatures(str, iArr, unusedScores);
    }

    /**
     * Builds the feature counter seen by the Y classifier of relation {@code str}, given one
     * Z label per mention.
     *
     * <p>Let {@code matches} be the number of mentions labelled {@code str}. If it is 0, the only
     * feature is NONE_FEAT. Otherwise the following are added, each only if its class is enabled
     * in {@code Props.TRAIN_JOINTBAYES_YFEATURES}:
     * <ul>
     *   <li>ATLEASTONCE_FEAT;</li>
     *   <li>one co-occurrence feature per other non-"_NR" label on a mention;</li>
     *   <li>UNIQUE_FEAT, when no such other label exists;</li>
     *   <li>{@code "atleast_" + matches};</li>
     *   <li>SIGMOID_FEAT = {@code 1 / (1 + exp(-10 * (matches / numMentions - 1/3)))}.</li>
     * </ul>
     * The fraction {@code matches / numMentions} is computed in floating point. Earlier code used
     * integer division, which collapsed it to 0 unless every mention matched.
     *
     * @param str        relation label the features are computed for
     * @param iArr       Z label index per mention
     * @param counterArr must be null (asserted); not otherwise used
     * @return feature counter for the Y classifier of {@code str}
     */
    private Counter<String> extractYFeatures(String str, int[] iArr, Counter<String>[] counterArr) {
        if (!$assertionsDisabled && counterArr != null) {
            throw new AssertionError();
        }
        int matches = 0;
        // otherLabels[k] is the mention's label when it is neither str nor "_NR", else null.
        String[] otherLabels = new String[iArr.length];
        for (int k = 0; k < iArr.length; k++) {
            String zLabel = (String) this.zLabelIndex.get(iArr[k]);
            if (zLabel.equals(str)) {
                matches++;
            } else if (!zLabel.equals("_NR")) {
                otherLabels[k] = zLabel;
            }
        }
        Counter<String> features = new ClassicCounter<>();
        if (matches == 0) {
            // Every other feature requires at least one supporting mention.
            features.setCount(NONE_FEAT, 1.0d);
            return features;
        }
        if (Props.TRAIN_JOINTBAYES_YFEATURES.contains(Props.Y_FEATURE_CLASS.ATLEAST_ONCE)) {
            features.setCount(ATLEASTONCE_FEAT, 1.0d);
        }
        if (Props.TRAIN_JOINTBAYES_YFEATURES.contains(Props.Y_FEATURE_CLASS.COOC)) {
            for (String other : otherLabels) {
                if (other != null) {
                    features.setCount(makeCoocurrenceFeature(str, other), 1.0d);
                }
            }
        }
        if (Props.TRAIN_JOINTBAYES_YFEATURES.contains(Props.Y_FEATURE_CLASS.UNIQUE)) {
            boolean unique = true;
            for (String other : otherLabels) {
                if (other != null) {
                    unique = false;
                    break;
                }
            }
            if (unique) {
                features.setCount(UNIQUE_FEAT, 1.0d);
            }
        }
        if (Props.TRAIN_JOINTBAYES_YFEATURES.contains(Props.Y_FEATURE_CLASS.ATLEAST_N)) {
            features.setCount("atleast_" + matches, 1.0d);
        }
        if (Props.TRAIN_JOINTBAYES_YFEATURES.contains(Props.Y_FEATURE_CLASS.SIGMOID)) {
            // Cast before dividing; iArr.length > 0 here because matches > 0.
            double fraction = (double) matches / iArr.length;
            features.setCount(SIGMOID_FEAT, 1.0d / (1.0d + Math.exp(-10.0d * (fraction - 1.0d / 3.0d))));
        }
        return features;
    }

    /**
     * Converts {@code counter} into a list of (label, score) pairs ordered by the private
     * {@code sortPredictions(List)} overload: highest score first, ties broken by label.
     *
     * @param counter label scores
     * @return a new mutable list of (label, score) pairs
     */
    public static List<Pair<String, Double>> sortPredictions(Counter<String> counter) {
        List<Pair<String, Double>> ranked = new ArrayList<>();
        for (String label : counter.keySet()) {
            double score = counter.getCount(label);
            ranked.add(Pair.makePair(label, score));
        }
        sortPredictions(ranked);
        return ranked;
    }

    /**
     * Sorts (label, score) pairs in place: highest score first, ties broken by ascending label.
     *
     * <p>This is a total order, as {@link java.util.Comparator} requires. NaN scores sort after
     * every real score, and -0.0 ranks just below 0.0. The previous comparator answered 1 for
     * both orders of a NaN score (and of 0.0 vs -0.0). That contract violation can make the sort
     * throw "Comparison method violates its general contract!".
     *
     * @param list pairs to sort; modified in place
     */
    private static void sortPredictions(List<Pair<String, Double>> list) {
        list.sort((left, right) -> {
            double leftScore = left.second();
            double rightScore = right.second();
            boolean leftNaN = Double.isNaN(leftScore);
            boolean rightNaN = Double.isNaN(rightScore);
            if (leftNaN != rightNaN) {
                // Treat NaN as the worst score so it never becomes the top prediction.
                return leftNaN ? 1 : -1;
            }
            int byScore = Double.compare(rightScore, leftScore); // descending
            if (byScore != 0) {
                return byScore;
            }
            return left.first().compareTo(right.first());
        });
    }

    /**
     * Computes the Z-label probability distribution for a single mention.
     *
     * <p>Under WEIGHTED_VOTE the per-fold Z classifiers' probabilities are summed and divided by
     * {@code numberOfFolds}. Under SINGLE_MODEL {@code zSingleClassifier} is used directly.
     *
     * @param collection the mention's feature strings
     * @return label probabilities
     * @throws RuntimeException for any other classification mode
     */
    private Counter<String> classifyLocally(Collection<String> collection) {
        Datum<String, String> datum = new BasicDatum<>(collection);
        if (localClassificationMode == LOCAL_CLASSIFICATION_MODE.WEIGHTED_VOTE) {
            ClassicCounter<String> average = new ClassicCounter<>();
            for (int fold = 0; fold < this.numberOfFolds; fold++) {
                average.addAll(this.zClassifiers[fold].probabilityOf(datum));
            }
            for (String label : average.keySet()) {
                average.setCount(label, average.getCount(label) / this.numberOfFolds);
            }
            return average;
        }
        if (localClassificationMode == LOCAL_CLASSIFICATION_MODE.SINGLE_MODEL) {
            return this.zSingleClassifier.probabilityOf(datum);
        }
        throw new RuntimeException("ERROR: classification mode " + localClassificationMode + " not supported!");
    }

    /**
     * Scores relations by how highly they rank in each mention's local prediction.
     *
     * <p>Mentions whose local distribution is uniform are skipped. For every other mention, up to
     * the top three predictions are examined, stopping at the first whose score trails the top
     * score by more than 0.99. Each non-"_NR" relation among them records weight
     * {@code 1 / (1 + rank)} at its rank. A relation's final score is the sum of its recorded
     * weights divided by the number of mentions.
     *
     * <p>NOTE(review): {@code set} is never read. Also, weights are stored with {@code setCount}
     * keyed by rank, so repeated hits at the same rank across mentions do not accumulate.
     * Confirm whether accumulation ({@code incrementCount}) was intended.
     *
     * @param list per-mention feature collections
     * @param set  unused
     * @return relation scores
     */
    public Counter<String> classifyOracleMentions(List<Collection<String>> list, Set<String> set) {
        Map<String, Counter<Integer>> rankWeightsByRelation = new HashMap<>();
        for (Collection<String> mentionFeatures : list) {
            Counter<String> local = classifyLocally(mentionFeatures);
            if (uniformDistribution(local)) {
                continue;
            }
            List<Pair<String, Double>> ranked = sortPredictions(local);
            double topScore = ranked.get(0).second();
            int depth = Math.min(ranked.size(), 3);
            for (int rank = 0; rank < depth; rank++) {
                Pair<String, Double> prediction = ranked.get(rank);
                String relation = prediction.first();
                if (prediction.second() + 0.99d < topScore) {
                    break;
                }
                if (relation.equals("_NR")) {
                    continue;
                }
                rankWeightsByRelation.computeIfAbsent(relation, ignored -> new ClassicCounter<>())
                        .setCount(rank, 1.0d / (1.0d + rank));
            }
        }

        Counter<String> oracleScores = new ClassicCounter<>();
        for (Map.Entry<String, Counter<Integer>> entry : rankWeightsByRelation.entrySet()) {
            Counter<Integer> weights = entry.getValue();
            double total = 0.0d;
            for (Integer rank : weights.keySet()) {
                total += weights.getCount(rank);
            }
            double rankScore = total / list.size();
            logger.log("RANK = " + rankScore);
            if (rankScore >= 0.0d) {
                oracleScores.setCount(entry.getKey(), rankScore);
            }
        }
        return oracleScores;
    }

    /**
     * Returns the Y_GIVEN_ZSTAR score and provenance for one relation type.
     *
     * @param sentenceGroup mentions of the entity pair
     * @param relationType  relation whose score is requested (matched on its canonical name)
     * @param maybe         forwarded to classifyRelations
     * @return (score, provenance) for the relation, or (0.0, Nothing) if it was not scored
     */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public Pair<Double, Maybe<KBPRelationProvenance>> classifyRelation(SentenceGroup sentenceGroup, RelationType relationType, Maybe<CoreMap[]> maybe) {
        Counter<Pair<String, Maybe<KBPRelationProvenance>>> scored =
                classifyRelations(sentenceGroup, maybe, Props.TRAIN_JOINTBAYES_OUTDISTRIBUTION_TYPES.Y_GIVEN_ZSTAR);
        for (Map.Entry<Pair<String, Maybe<KBPRelationProvenance>>, Double> entry : scored.entrySet()) {
            Pair<String, Maybe<KBPRelationProvenance>> key = entry.getKey();
            if (key.first.equals(relationType.canonicalName)) {
                return Pair.makePair(entry.getValue(), key.second);
            }
        }
        return Pair.makePair(Double.valueOf(0.0d), Maybe.Nothing());
    }

    /**
     * Scores every relation for the sentence group using the output distribution configured in
     * {@code Props.TRAIN_JOINTBAYES_OUTDISTRIBUTION}.
     *
     * @param sentenceGroup mentions of the entity pair
     * @param maybe         forwarded to the private overload
     * @return scores keyed by (relation, provenance)
     */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public Counter<Pair<String, Maybe<KBPRelationProvenance>>> classifyRelations(SentenceGroup sentenceGroup, Maybe<CoreMap[]> maybe) {
        Props.TRAIN_JOINTBAYES_OUTDISTRIBUTION_TYPES configured = Props.TRAIN_JOINTBAYES_OUTDISTRIBUTION;
        return classifyRelations(sentenceGroup, maybe, configured);
    }

    /**
     * Scores relations for a sentence group with the requested output distribution.
     *
     * <p>Step 1: each mention is classified locally and its top label is kept. For every non-"_NR"
     * top label the method tracks:
     * <ul>
     *   <li>the noisy-or score {@code 1 - prod(1 - p)} over the mentions with that top label;</li>
     *   <li>a provenance from an official mention. A new best score stores that mention's
     *       provenance rewritten with the score. Otherwise the first official provenance seen is
     *       kept if none was stored yet.</li>
     * </ul>
     *
     * <p>Step 2: each Y classifier scores its relation against the top labels, as
     * {@code P(rel) / (P(rel) + P(_NR))}. Exceptions from this step are logged and that
     * relation is skipped.
     *
     * <p>Step 3: the output depends on {@code train_jointbayes_outdistribution_types}:
     * <ul>
     *   <li>Y_GIVEN_ZSTAR uses the Y probabilities, normalized to sum to 1.</li>
     *   <li>NOISY_OR uses the noisy-or scores that exceed the relation's threshold.</li>
     *   <li>Y_THEN_NOISY_OR uses the noisy-or scores of relations whose Y probability exceeded
     *       the threshold.</li>
     * </ul>
     * Normalization and the error message now follow the requested type. Previously they read the
     * global property, so {@code classifyRelation} returned unnormalized Y_GIVEN_ZSTAR scores
     * under other configurations.
     *
     * @param sentenceGroup mentions of the entity pair
     * @param maybe         unused
     * @param train_jointbayes_outdistribution_types output distribution to produce
     * @return scores keyed by (relation, provenance)
     * @throws IllegalStateException if a local top label is not in {@code zLabelIndex}, or the
     *                               output type is unknown and at least one relation was scored
     */
    private Counter<Pair<String, Maybe<KBPRelationProvenance>>> classifyRelations(SentenceGroup sentenceGroup, Maybe<CoreMap[]> maybe, Props.TRAIN_JOINTBAYES_OUTDISTRIBUTION_TYPES train_jointbayes_outdistribution_types) {
        List<Collection<String>> mentionFeatures = RelationClassifier.tupleToFeatureList(sentenceGroup);
        int numMentions = mentionFeatures.size();
        String[] localBest = new String[numMentions];
        Counter<String> bestLocalScore = new ClassicCounter<>();
        Map<String, KBPRelationProvenance> provenance = new HashMap<>();
        // Holds prod(1 - p) until converted to 1 - prod(1 - p) below.
        Counter<String> noisyOr = new ClassicCounter<>();

        for (int m = 0; m < numMentions; m++) {
            Counter<String> local = classifyLocally(mentionFeatures.get(m));
            Pair<String, Double> top = sortPredictions(local).get(0);
            String label = top.first();
            double score = top.second();
            localBest[m] = label;
            if (label.equals("_NR")) {
                continue;
            }
            if (!bestLocalScore.containsKey(label) || score > bestLocalScore.getCount(label)) {
                bestLocalScore.setCount(label, score);
                if (sentenceGroup.getProvenance(m).isOfficial()) {
                    provenance.put(label, sentenceGroup.getProvenance(m).rewrite(score));
                }
            }
            if (!provenance.containsKey(label) && sentenceGroup.getProvenance(m).isOfficial()) {
                provenance.put(label, sentenceGroup.getProvenance(m));
            }
            double complement = noisyOr.containsKey(label) ? noisyOr.getCount(label) : 1.0d;
            noisyOr.setCount(label, complement * (1.0d - score));
        }

        int[] zLabels = new int[numMentions];
        for (int m = 0; m < numMentions; m++) {
            zLabels[m] = this.zLabelIndex.indexOf(localBest[m]);
            if (zLabels[m] < 0) {
                throw new IllegalStateException("Unknown label: " + localBest[m]);
            }
        }
        Counters.multiplyInPlace(noisyOr, -1.0d);
        Counters.addInPlace(noisyOr, 1.0d);

        Counter<String> yGivenZStar = new ClassicCounter<>();
        Counter<String> yAboveThreshold = new ClassicCounter<>();
        for (String relation : this.yClassifiers.keySet()) {
            try {
                RVFDatum<String, String> datum = new RVFDatum<>(extractYFeatures(relation, zLabels), "");
                Counter<String> yProbs = this.yClassifiers.get(relation).probabilityOf(datum);
                double positive = yProbs.getCount(relation);
                double prob = positive / (positive + yProbs.getCount("_NR"));
                yGivenZStar.setCount(relation, prob);
                if (exceedsJointBayesThreshold(relation, prob)) {
                    yAboveThreshold.incrementCount(relation, prob);
                }
            } catch (Exception e) {
                // Best effort: one failing Y classifier should not prevent scoring the others.
                logger.err(e);
            }
        }

        Counter<Pair<String, Maybe<KBPRelationProvenance>>> output = new ClassicCounter<>();
        for (String relation : yGivenZStar.keySet()) {
            double noisyOrScore = noisyOr.getCount(relation);
            switch (train_jointbayes_outdistribution_types) {
                case Y_GIVEN_ZSTAR:
                    output.setCount(Pair.makePair(relation, Maybe.fromNull(provenance.get(relation))), yGivenZStar.getCount(relation));
                    break;
                case NOISY_OR:
                    if (exceedsJointBayesThreshold(relation, noisyOrScore)) {
                        output.setCount(Pair.makePair(relation, Maybe.fromNull(provenance.get(relation))), noisyOrScore);
                    }
                    break;
                case Y_THEN_NOISY_OR:
                    if (yAboveThreshold.containsKey(relation)) {
                        output.setCount(Pair.makePair(relation, Maybe.fromNull(provenance.get(relation))), noisyOrScore);
                    }
                    break;
                default:
                    throw new IllegalStateException("Unknown output type: " + train_jointbayes_outdistribution_types);
            }
        }
        if (train_jointbayes_outdistribution_types == Props.TRAIN_JOINTBAYES_OUTDISTRIBUTION_TYPES.Y_GIVEN_ZSTAR) {
            Counters.normalize(output);
        }
        return output;
    }

    /**
     * Returns whether {@code score} is strictly above the test threshold for {@code relation}.
     * The per-relation threshold is used when one is configured, otherwise the default.
     */
    private static boolean exceedsJointBayesThreshold(String relation, double score) {
        double threshold = Props.TEST_THRESHOLD_JOINTBAYES_PERRELATION.containsKey(relation)
                ? Props.TEST_THRESHOLD_JOINTBAYES_PERRELATION.get(relation).doubleValue()
                : Props.TEST_THRESHOLD_JOINTBAYES_DEFAULT;
        return score > threshold;
    }

    /**
     * Writes the model to {@code objectOutputStream} in the order {@code load} reads it:
     * <ol>
     *   <li>known dependencies;</li>
     *   <li>Z-label index;</li>
     *   <li>fold count followed by each fold's Z classifier;</li>
     *   <li>the single Z classifier;</li>
     *   <li>Y classifier count followed by (label, classifier) pairs.</li>
     * </ol>
     * The stream is not closed.
     *
     * @param objectOutputStream destination stream
     * @throws IOException if writing fails
     */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public void save(ObjectOutputStream objectOutputStream) throws IOException {
        ObjectOutputStream out = objectOutputStream;
        out.writeObject(this.knownDependencies);
        out.writeObject(this.zLabelIndex);
        out.writeInt(this.zClassifiers.length);
        for (Object foldClassifier : this.zClassifiers) {
            out.writeObject(foldClassifier);
        }
        out.writeObject(this.zSingleClassifier);
        out.writeInt(this.yClassifiers.size());
        for (Map.Entry<String, LinearClassifier<String, String>> entry : this.yClassifiers.entrySet()) {
            out.writeObject(entry.getKey());
            out.writeObject(entry.getValue());
        }
    }

    /**
     * Reads a model written by {@code save}, in the same order:
     * <ol>
     *   <li>known dependencies;</li>
     *   <li>Z-label index;</li>
     *   <li>fold count followed by each fold's Z classifier;</li>
     *   <li>the single Z classifier;</li>
     *   <li>Y classifier count followed by (label, classifier) pairs.</li>
     * </ol>
     *
     * <p>The Redwood track opened for logging is closed in a {@code finally} block, so a failed
     * deserialization no longer leaves it open. The stream itself is not closed.
     *
     * @param objectInputStream source stream
     * @throws IOException if reading fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public void load(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        final String track = "Loading Joint Bayes relation extractor (from input stream)";
        Redwood.Util.forceTrack(track);
        try {
            this.knownDependencies = ErasureUtils.uncheckedCast(objectInputStream.readObject());
            this.zLabelIndex = ErasureUtils.uncheckedCast(objectInputStream.readObject());
            this.numberOfFolds = objectInputStream.readInt();
            this.zClassifiers = ErasureUtils.uncheckedCast(new LinearClassifier[this.numberOfFolds]);
            for (int fold = 0; fold < this.numberOfFolds; fold++) {
                this.zClassifiers[fold] = ErasureUtils.uncheckedCast(objectInputStream.readObject());
            }
            this.zSingleClassifier = ErasureUtils.uncheckedCast(objectInputStream.readObject());
            int numYClassifiers = objectInputStream.readInt();
            this.yClassifiers = new HashMap<>();
            for (int k = 0; k < numYClassifiers; k++) {
                String relation = ErasureUtils.uncheckedCast(objectInputStream.readObject());
                LinearClassifier<String, String> yClassifier = ErasureUtils.uncheckedCast(objectInputStream.readObject());
                this.yClassifiers.put(relation, yClassifier);
                logger.debug("Loaded Y classifier for label " + relation + ": " + yClassifier.toAllWeightsString());
            }
        } finally {
            Redwood.Util.endTrack(track);
        }
    }

    /**
     * Compares two extractors by configuration only.
     *
     * <p>Two extractors are equal when they have the same {@code numberOfFolds},
     * {@code numberOfTrainEpochs}, {@code onlyLocalTraining}, {@code trainY},
     * {@code initialModelPath} and {@code serializedModelPath}. The two paths are compared
     * null-safely. Learned classifiers are not compared.
     *
     * <p>Uses exactly the fields that {@link #hashCode()} hashes, keeping the two consistent.
     *
     * @param obj object to compare with; may be {@code null}
     * @return {@code true} if {@code obj} is a {@code JointBayesRelationExtractor} with the same
     *     configuration
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof JointBayesRelationExtractor)) {
            return false;
        }
        JointBayesRelationExtractor other = (JointBayesRelationExtractor) obj;
        return this.numberOfFolds == other.numberOfFolds
            && this.numberOfTrainEpochs == other.numberOfTrainEpochs
            && this.onlyLocalTraining == other.onlyLocalTraining
            && this.trainY == other.trainY
            && (this.initialModelPath != null
                ? this.initialModelPath.equals(other.initialModelPath)
                : other.initialModelPath == null)
            && (this.serializedModelPath != null
                ? this.serializedModelPath.equals(other.serializedModelPath)
                : other.serializedModelPath == null);
    }

    /**
     * Hashes the same configuration fields that {@link #equals(Object)} compares.
     *
     * <p>The value equals the original nested expression: fields are folded with multiplier 31
     * in this order: {@code numberOfTrainEpochs}, {@code numberOfFolds},
     * {@code onlyLocalTraining}, {@code initialModelPath}, {@code trainY},
     * {@code serializedModelPath}. Booleans contribute 1/0 and {@code null} paths contribute 0.
     *
     * @return hash code consistent with {@link #equals(Object)}
     */
    @Override
    public int hashCode() {
        int result = this.numberOfTrainEpochs;
        result = 31 * result + this.numberOfFolds;
        result = 31 * result + (this.onlyLocalTraining ? 1 : 0);
        result = 31 * result + (this.initialModelPath != null ? this.initialModelPath.hashCode() : 0);
        result = 31 * result + (this.trainY ? 1 : 0);
        result = 31 * result + (this.serializedModelPath != null ? this.serializedModelPath.hashCode() : 0);
        return result;
    }

    /**
     * Loads a serialized {@code JointBayesRelationExtractor}.
     *
     * <p>Delegates to {@code RelationClassifier.load} with this class as the expected type, then
     * casts the result.
     *
     * @param source where the serialized model lives (presumably a file path; confirm against
     *     {@code RelationClassifier.load})
     * @param properties properties forwarded unchanged to {@code RelationClassifier.load}
     * @return the loaded extractor
     * @throws IOException propagated from {@code RelationClassifier.load}
     * @throws ClassNotFoundException propagated from {@code RelationClassifier.load}
     */
    public static JointBayesRelationExtractor load(String source, Properties properties) throws IOException, ClassNotFoundException {
        final Class<JointBayesRelationExtractor> expectedType = JointBayesRelationExtractor.class;
        final Object loaded = RelationClassifier.load(source, properties, expectedType);
        return (JointBayesRelationExtractor) loaded;
    }

    // Class initializer: sets up the logger, the default local classification mode, the Y-feature
    // name constants, and the list of Y features enabled for the initial model.
    static {
        // NOTE(review): `$assertionsDisabled` looks like a compiler-synthesized field (it backs
        // `assert` statements) that surfaced through decompilation; confirm before hand-editing.
        $assertionsDisabled = !JointBayesRelationExtractor.class.desiredAssertionStatus();
        // Redwood channel named "MIML-RE"; load(ObjectInputStream) logs loaded Y classifiers through it.
        logger = Redwood.channels(new Object[]{"MIML-RE"});
        // Initial value of the class-wide local classification mode.
        localClassificationMode = LOCAL_CLASSIFICATION_MODE.WEIGHTED_VOTE;
        // String names of the Y-level features (presumably used as feature names by the Y
        // classifiers; confirm at the use sites).
        ATLEASTONCE_FEAT = "atleastonce";
        NONE_FEAT = "none";
        UNIQUE_FEAT = "unique";
        SIGMOID_FEAT = "sigmoid";
        // NONE_FEAT is always present. The other features are appended only when their
        // Props.Y_FEATURE_CLASS value is enabled in Props.TRAIN_JOINTBAYES_YFEATURES.
        // Resulting order: none, atleastonce, unique, sigmoid (minus any disabled ones).
        Y_FEATURES_FOR_INITIAL_MODEL = new ArrayList();
        Y_FEATURES_FOR_INITIAL_MODEL.add(NONE_FEAT);
        if (Props.TRAIN_JOINTBAYES_YFEATURES.contains(Props.Y_FEATURE_CLASS.ATLEAST_ONCE)) {
            Y_FEATURES_FOR_INITIAL_MODEL.add(ATLEASTONCE_FEAT);
        }
        if (Props.TRAIN_JOINTBAYES_YFEATURES.contains(Props.Y_FEATURE_CLASS.UNIQUE)) {
            Y_FEATURES_FOR_INITIAL_MODEL.add(UNIQUE_FEAT);
        }
        if (Props.TRAIN_JOINTBAYES_YFEATURES.contains(Props.Y_FEATURE_CLASS.SIGMOID)) {
            Y_FEATURES_FOR_INITIAL_MODEL.add(SIGMOID_FEAT);
        }
    }
}
