package edu.stanford.nlp.kbp.slotfilling.classify;

import edu.stanford.nlp.kbp.common.Maybe;
import edu.stanford.nlp.kbp.common.Pointer;
import edu.stanford.nlp.kbp.common.Props;
import edu.stanford.nlp.kbp.common.SentenceGroup;
import edu.stanford.nlp.kbp.slotfilling.ir.KBPRelationProvenance;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.util.logging.RedwoodConfiguration;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Random;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/classify/EnsembleRelationExtractor.class */
/**
 * A {@link RelationClassifier} that trains several component classifiers and merges their
 * predictions.
 *
 * <p>At training time the dataset is turned into one sample per component according to an
 * {@link EnsembleMethod}, and each component is trained on its own sample; the jobs are run through
 * {@code Redwood.Util.threadAndRun}. At classification time every component scores the sentence
 * group independently and the per-relation scores are merged according to
 * {@code Props.TEST_ENSEMBLE_COMBINATION}.
 *
 * <p>Invariant: once trained or loaded, {@link #classifiers} and {@link #modelTypes} are parallel
 * lists of the same length; {@link #save} and {@link #load} depend on this.
 */
public class EnsembleRelationExtractor extends RelationClassifier {
    protected static final Redwood.RedwoodChannels logger = Redwood.channels("Ensemble");

    /** Component classifiers; parallel to {@link #modelTypes} once trained or loaded. */
    private List<RelationClassifier> classifiers;
    /** Properties handed to {@code ModelType.construct} when building components. */
    private Properties properties;
    /** How the training data is distributed among the components. */
    private EnsembleMethod method;
    /** Model type of each component; its size is the number of components. */
    private List<ModelType> modelTypes;

    /** Strategies for distributing the training data among the ensemble components. */
    public enum EnsembleMethod {
        /** Every component gets a full copy of the training data. */
        DEFAULT,
        /** Every component gets a seeded random sample, drawn with replacement, of up to
         *  {@code Props.TRAIN_ENSEMBLE_BAGSIZE} data. */
        BAGGING,
        /** The data are shuffled and dealt round-robin into partitions; each datum is placed in
         *  {@code Props.TRAIN_ENSEMBLE_SUBAGREDUNDANCY} consecutive partitions. */
        SUBAGGING
    }

    /**
     * Creates an ensemble configured entirely from {@code Props}: the ensemble method, the component
     * model type, and the number of components.
     *
     * @param properties properties passed to each component's constructor
     */
    public EnsembleRelationExtractor(Properties properties) {
        this(properties, Props.TRAIN_ENSEMBLE_METHOD, Props.TRAIN_ENSEMBLE_COMPONENT, Props.TRAIN_ENSEMBLE_NUMCOMPONENTS);
    }

    /**
     * Creates an ensemble of {@code numComponents} components that all share one model type.
     *
     * @param properties     properties passed to each component's constructor
     * @param ensembleMethod how to distribute the training data among components
     * @param modelType      model type of every component
     * @param numComponents  number of components; non-positive values yield an empty ensemble
     */
    public EnsembleRelationExtractor(Properties properties, EnsembleMethod ensembleMethod, ModelType modelType, int numComponents) {
        this(properties, ensembleMethod, Collections.nCopies(Math.max(0, numComponents), modelType));
    }

    /**
     * Creates an ensemble with one component per entry of {@code modelTypes}.
     *
     * @param properties     properties passed to each component's constructor
     * @param ensembleMethod how to distribute the training data among components
     * @param modelTypes     model type of each component; copied, so later changes by the caller
     *                       do not affect this ensemble
     */
    public EnsembleRelationExtractor(Properties properties, EnsembleMethod ensembleMethod, List<ModelType> modelTypes) {
        this.classifiers = new ArrayList<>();
        this.properties = properties;
        this.method = ensembleMethod;
        this.modelTypes = new ArrayList<>(modelTypes);
    }

    /**
     * Wraps already-constructed classifiers into an ensemble. The ensemble method comes from
     * {@code Props.TRAIN_ENSEMBLE_METHOD}; each classifier's model type is the first
     * {@link ModelType} whose {@code modelClass} is assignable from the classifier's class, or
     * {@code ModelType.GOLD} (with a warning) if none matches.
     *
     * @param relationClassifierArr the component classifiers, in order
     */
    public EnsembleRelationExtractor(RelationClassifier... relationClassifierArr) {
        this.classifiers = new ArrayList<>(Arrays.asList(relationClassifierArr));
        this.method = Props.TRAIN_ENSEMBLE_METHOD;
        this.modelTypes = new ArrayList<>(this.classifiers.size());
        for (RelationClassifier classifier : relationClassifierArr) {
            // Resolve each slot in place so modelTypes stays parallel to classifiers.
            this.modelTypes.add(modelTypeOf(classifier));
        }
        this.properties = new Properties();
    }

    /** Returns the first ModelType whose model class accepts {@code classifier}, else GOLD. */
    private static ModelType modelTypeOf(RelationClassifier classifier) {
        for (ModelType candidate : ModelType.values()) {
            if (candidate.modelClass.isAssignableFrom(classifier.getClass())) {
                return candidate;
            }
        }
        logger.warn("Could not find some model class");
        return ModelType.GOLD;
    }

    /** Returns the number of components in this ensemble. */
    public int numSamples() {
        return this.modelTypes.size();
    }

    /**
     * Trains every component on its own sample of {@code kBPDataset} and merges the components'
     * training statistics.
     *
     * <p>If the ensemble already holds exactly one classifier per model type, those classifiers
     * are retrained; otherwise fresh components are constructed. Either way, afterwards
     * {@link #classifiers} holds exactly one classifier per model type.
     *
     * @param kBPDataset the full training set
     * @return the merged training statistics of all components (also stored in {@code statistics})
     * @throws IllegalStateException if the ensemble has no components
     */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public TrainingStatistics train(KBPDataset<String, String> kBPDataset) {
        final int numComponents = this.modelTypes.size();
        if (numComponents == 0) {
            throw new IllegalStateException("Cannot train an ensemble with no components");
        }
        final boolean reuseExisting = this.classifiers != null && numComponents == this.classifiers.size();

        Redwood.Util.startTrack("Generating samples");
        final List<KBPDataset<String, String>> samples;
        try {
            samples = generateSamples(kBPDataset, numComponents);
            logger.log("applying feature count thresholds (" + Props.FEATURE_COUNT_THRESHOLD + ")...");
            for (KBPDataset<String, String> sample : samples) {
                sample.applyFeatureCountThreshold(Props.FEATURE_COUNT_THRESHOLD);
            }
        } finally {
            Redwood.Util.endTrack("Generating samples");
        }

        final Pointer<TrainingStatistics> merged = new Pointer<>();
        List<RelationClassifier> components = new ArrayList<>(numComponents);
        List<Runnable> trainingJobs = new ArrayList<>(numComponents);
        for (int i = 0; i < numComponents; i++) {
            final RelationClassifier component = componentFor(i, reuseExisting);
            final KBPDataset<String, String> sample = samples.get(i);
            trainingJobs.add(() -> {
                TrainingStatistics componentStats = component.train(sample);
                // Jobs may run concurrently; serialize access to the shared accumulator.
                synchronized (this) {
                    if (merged.dereference().isDefined()) {
                        merged.dereference().get().merge(componentStats);
                    } else {
                        merged.set(componentStats);
                    }
                }
            });
            components.add(component);
        }
        // Replace (never append to) the component list so it stays parallel to modelTypes.
        this.classifiers = components;

        // stderr capture is lifted while several components train at once and re-applied afterwards
        // (presumably because captured output does not interleave well across threads).
        final boolean releaseStderr = numComponents > 1
                && (Props.TRAIN_ENSEMBLE_COMPONENT != ModelType.JOINT_BAYES || Props.TRAIN_JOINTBAYES_MULTITHREAD);
        if (releaseStderr) {
            RedwoodConfiguration.current().restore(System.err).apply();
        }
        try {
            Redwood.Util.threadAndRun(trainingJobs, numComponents);
        } finally {
            if (releaseStderr) {
                RedwoodConfiguration.current().capture(System.err).apply();
            }
        }

        TrainingStatistics stats = merged.dereference().get();
        this.statistics = Maybe.Just(stats);
        return stats;
    }

    /**
     * Returns the component for slot {@code index}: the existing classifier when
     * {@code reuseExisting} is set, otherwise a new one built from {@code modelTypes.get(index)}.
     * {@code Props.KBP_MODEL_DIR} temporarily points at {@code <KBP_MODEL_DIR>/sample<index>}
     * (created if needed) while this runs and is always restored afterwards.
     */
    private RelationClassifier componentFor(int index, boolean reuseExisting) {
        File originalModelDir = Props.KBP_MODEL_DIR;
        File sampleModelDir = new File(originalModelDir.getPath() + File.separatorChar + "sample" + index);
        sampleModelDir.mkdirs();
        Props.KBP_MODEL_DIR = sampleModelDir;
        Redwood.Util.forceTrack("Creating classifier on subsample #" + index);
        try {
            return reuseExisting
                    ? this.classifiers.get(index)
                    : this.modelTypes.get(index).construct(this.properties);
        } finally {
            Props.KBP_MODEL_DIR = originalModelDir;
            Redwood.Util.endTrack("Creating classifier on subsample #" + index);
        }
    }

    /** Dispatches to the sampling strategy selected by {@link #method}. */
    private List<KBPDataset<String, String>> generateSamples(KBPDataset<String, String> kBPDataset, int i) {
        switch (this.method) {
            case DEFAULT:
                return cloneData(kBPDataset, i);
            case BAGGING:
                return sampleData(kBPDataset, i);
            case SUBAGGING:
                return partitionData(kBPDataset, i);
            default:
                throw new RuntimeException("Unsupported ensemble method " + this.method.name() + ".");
        }
    }

    /** Returns an empty dataset whose feature and label indices are copies of {@code source}'s. */
    private static KBPDataset<String, String> emptyCopyOf(KBPDataset<String, String> source) {
        return new KBPDataset<>(new HashIndex<>(source.featureIndex()), new HashIndex<>(source.labelIndex()));
    }

    /** Returns {@code numSamples} full copies of {@code data}. */
    private List<KBPDataset<String, String>> cloneData(KBPDataset<String, String> data, int numSamples) {
        Set<Integer>[] positive = data.getPositiveLabelsArray();
        Set<Integer>[] negative = data.getNegativeLabelsArray();
        Set<Integer>[] unknown = data.getUnknownLabelsArray();
        int[][][] features = data.getDataArray();
        List<KBPDataset<String, String>> copies = new ArrayList<>(Math.max(0, numSamples));
        for (int s = 0; s < numSamples; s++) {
            KBPDataset<String, String> copy = emptyCopyOf(data);
            for (int d = 0; d < data.size(); d++) {
                copy.addDatum(positive[d], negative[d], unknown[d], features[d],
                        data.getSentenceGlossKey(d), data.getAnnotatedLabels(d));
            }
            copies.add(copy);
        }
        return copies;
    }

    /**
     * Shuffles {@code data} (unseeded) and deals it round-robin into {@code numSamples} partitions.
     * Each datum also goes to the following partitions, for a total of
     * {@code Props.TRAIN_ENSEMBLE_SUBAGREDUNDANCY} partitions, capped at {@code numSamples} so that
     * no partition receives the same datum twice.
     */
    private List<KBPDataset<String, String>> partitionData(KBPDataset<String, String> data, int numSamples) {
        logger.log("numSamples: " + numSamples);
        List<KBPDataset<String, String>> partitions = new ArrayList<>(Math.max(0, numSamples));
        for (int s = 0; s < numSamples; s++) {
            partitions.add(emptyCopyOf(data));
        }
        if (numSamples <= 0) {
            return partitions;
        }
        Set<Integer>[] positive = data.getPositiveLabelsArray();
        Set<Integer>[] negative = data.getNegativeLabelsArray();
        Set<Integer>[] unknown = data.getUnknownLabelsArray();
        int[][][] features = data.getDataArray();

        int size = data.size();
        List<Integer> order = new ArrayList<>(size);
        for (int d = 0; d < size; d++) {
            order.add(d);
        }
        Collections.shuffle(order);

        int redundancy = Math.min(Props.TRAIN_ENSEMBLE_SUBAGREDUNDANCY, numSamples);
        for (int rank = 0; rank < size; rank++) {
            int d = order.get(rank);
            int home = rank % numSamples;
            for (int r = 0; r < redundancy; r++) {
                partitions.get((home + r) % numSamples).addDatum(positive[d], negative[d], unknown[d], features[d],
                        data.getSentenceGlossKey(d), data.getAnnotatedLabels(d));
            }
        }
        return partitions;
    }

    /**
     * Builds {@code numSamples} bags. Bag {@code s} holds
     * {@code min(Props.TRAIN_ENSEMBLE_BAGSIZE, data.size())} data drawn with replacement from the
     * whole of {@code data}, using {@code new Random(s)} so that bags are reproducible.
     */
    private List<KBPDataset<String, String>> sampleData(KBPDataset<String, String> data, int numSamples) {
        Set<Integer>[] positive = data.getPositiveLabelsArray();
        Set<Integer>[] negative = data.getNegativeLabelsArray();
        Set<Integer>[] unknown = data.getUnknownLabelsArray();
        int[][][] features = data.getDataArray();
        List<KBPDataset<String, String>> bags = new ArrayList<>(Math.max(0, numSamples));
        logger.log("numSamples: " + numSamples);
        int population = data.size();
        int bagSize = Math.min(Props.TRAIN_ENSEMBLE_BAGSIZE, population);
        for (int s = 0; s < numSamples; s++) {
            Random random = new Random(s);
            KBPDataset<String, String> bag = emptyCopyOf(data);
            // bagSize > 0 implies population > 0, so nextInt(population) is always valid here.
            for (int draw = 0; draw < bagSize; draw++) {
                int d = random.nextInt(population);
                bag.addDatum(positive[d], negative[d], unknown[d], features[d],
                        data.getSentenceGlossKey(d), data.getAnnotatedLabels(d));
            }
            bags.add(bag);
        }
        return bags;
    }

    /**
     * Scores the sentence group with every component and merges the results.
     *
     * <p>Within one component, scores of entries that share a relation name are summed. A
     * component "votes" for a relation if it returns any score for it. Each output key carries the
     * provenance from the single highest-scoring prediction of that relation that had one. The
     * merge rule is {@code Props.TEST_ENSEMBLE_COMBINATION}:
     * <ul>
     *   <li>AGREE_ANY: {@code 1 - prod(1 - p)} over voting components, if at least one votes;</li>
     *   <li>AGREE_ALL: same score, only if every component votes;</li>
     *   <li>AGREE_MOST: same score, only if at least half the components vote;</li>
     *   <li>AGREE_TWO: {@code 1 - (1 - p1)(1 - p2)} over the two highest votes, if at least two vote;</li>
     *   <li>AGREE_FIRST: the first component's score, if it votes.</li>
     * </ul>
     *
     * @throws IllegalStateException for an unrecognized combination method
     */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public Counter<Pair<String, Maybe<KBPRelationProvenance>>> classifyRelations(SentenceGroup sentenceGroup, Maybe<CoreMap[]> maybe) {
        List<Counter<String>> perComponent = new ArrayList<>(this.classifiers.size());
        Set<String> relations = new HashSet<>();
        Map<String, KBPRelationProvenance> bestProvenance = new HashMap<>();
        Counter<String> bestScore = new ClassicCounter<>();
        for (RelationClassifier classifier : this.classifiers) {
            Counter<String> scores = new ClassicCounter<>();
            for (Map.Entry<Pair<String, Maybe<KBPRelationProvenance>>, Double> entry
                    : classifier.classifyRelations(sentenceGroup, maybe).entrySet()) {
                String relation = entry.getKey().first;
                Maybe<KBPRelationProvenance> provenance = entry.getKey().second;
                double score = entry.getValue();
                scores.incrementCount(relation, score);
                relations.add(relation);
                if (provenance.isDefined() && bestScore.getCount(relation) < score) {
                    bestProvenance.put(relation, provenance.get());
                }
                bestScore.setCount(relation, Math.max(bestScore.getCount(relation), score));
            }
            perComponent.add(scores);
        }

        Counter<Pair<String, Maybe<KBPRelationProvenance>>> combined = new ClassicCounter<>();
        int numComponents = this.classifiers.size();
        for (String relation : relations) {
            int votes = 0;
            double allMissProbability = 1.0;  // running product of (1 - p) over voting components
            double best = 0.0;
            double secondBest = 0.0;
            for (Counter<String> scores : perComponent) {
                if (!scores.containsKey(relation)) {
                    continue;
                }
                double p = scores.getCount(relation);
                votes++;
                allMissProbability *= 1.0 - p;
                if (p > best) {
                    secondBest = best;
                    best = p;
                } else if (p > secondBest) {
                    secondBest = p;
                }
            }
            // relations is non-empty only if some component ran, so perComponent has an element 0.
            Counter<String> first = perComponent.get(0);
            Pair<String, Maybe<KBPRelationProvenance>> key =
                    Pair.makePair(relation, Maybe.fromNull(bestProvenance.get(relation)));
            switch (Props.TEST_ENSEMBLE_COMBINATION) {
                case AGREE_ANY:
                    if (votes > 0) {
                        combined.setCount(key, 1.0 - allMissProbability);
                    }
                    break;
                case AGREE_ALL:
                    if (votes >= numComponents) {
                        combined.setCount(key, 1.0 - allMissProbability);
                    }
                    break;
                case AGREE_MOST:
                    // At least half, rounded up: avoids integer truncation making 1 of 3 a "majority".
                    if (2 * votes >= numComponents) {
                        combined.setCount(key, 1.0 - allMissProbability);
                    }
                    break;
                case AGREE_TWO:
                    if (votes >= 2) {
                        combined.setCount(key, 1.0 - ((1.0 - best) * (1.0 - secondBest)));
                    }
                    break;
                case AGREE_FIRST:
                    if (first.containsKey(relation)) {
                        combined.setCount(key, first.getCount(relation));
                    }
                    break;
                default:
                    throw new IllegalStateException("Unknown combination method for ensemble model: " + Props.TEST_ENSEMBLE_COMBINATION);
            }
        }
        return combined;
    }

    /**
     * Reads an ensemble written by {@link #save}: the component count, the ensemble method, one
     * {@link ModelType} per component, then each component's own serialized state. Components are
     * constructed with this ensemble's {@link #properties}. The fields are assigned only after
     * everything has been read.
     *
     * @throws IOException if the stream is unreadable or declares a negative component count
     */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public void load(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        int numComponents = objectInputStream.readInt();
        if (numComponents < 0) {
            throw new IOException("Corrupt ensemble model: negative component count " + numComponents);
        }
        EnsembleMethod loadedMethod = (EnsembleMethod) objectInputStream.readObject();
        List<ModelType> loadedTypes = new ArrayList<>(numComponents);
        for (int i = 0; i < numComponents; i++) {
            loadedTypes.add((ModelType) objectInputStream.readObject());
        }
        List<RelationClassifier> loadedClassifiers = new ArrayList<>(numComponents);
        for (ModelType type : loadedTypes) {
            RelationClassifier component = type.construct(this.properties);
            component.load(objectInputStream);
            loadedClassifiers.add(component);
        }
        this.method = loadedMethod;
        this.modelTypes = loadedTypes;
        this.classifiers = loadedClassifiers;
    }

    /**
     * Writes the component count, the ensemble method, each component's {@link ModelType}, and
     * then each component's own state, in the order {@link #load} reads them.
     *
     * @throws IllegalStateException if the components and model types are not parallel (for
     *                               example, before training), since {@link #load} could not read
     *                               the result back
     */
    @Override // edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier
    public void save(ObjectOutputStream objectOutputStream) throws IOException {
        if (this.classifiers.size() != this.modelTypes.size()) {
            throw new IllegalStateException("Cannot save ensemble: " + this.classifiers.size()
                    + " classifiers for " + this.modelTypes.size() + " model types");
        }
        objectOutputStream.writeInt(this.modelTypes.size());
        objectOutputStream.writeObject(this.method);
        for (ModelType type : this.modelTypes) {
            objectOutputStream.writeObject(type);
        }
        for (RelationClassifier component : this.classifiers) {
            component.save(objectOutputStream);
        }
    }
}
