package edu.stanford.nlp.kbp.slotfilling.classify;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.kbp.common.Maybe;
import edu.stanford.nlp.kbp.common.RelationType;
import edu.stanford.nlp.kbp.common.SentenceGroup;
import edu.stanford.nlp.kbp.slotfilling.ir.KBPRelationProvenance;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.MetaClass;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/classify/RelationClassifier.class */
public abstract class RelationClassifier implements Serializable {
    /**
     * Training statistics for this classifier, if any; returned by {@link #getTrainingStatistics()}.
     * Starts as {@code Nothing}; presumably populated by subclasses when they train (TODO confirm).
     */
    protected Maybe<TrainingStatistics> statistics = Maybe.Nothing();

    /**
     * Scores every {@link RelationType} against a sentence group and returns a distribution over
     * the relations that received a strictly positive score.
     *
     * <p>The default implementation calls {@link #classifyRelation} once per relation type. The
     * default {@link #classifyRelation} in turn calls this method. A subclass must therefore
     * override at least one of the two, or the calls recurse until a {@link StackOverflowError}.
     *
     * @param sentenceGroup the sentences (evidence) to classify
     * @param maybe optional annotated sentences, forwarded unchanged to {@link #classifyRelation}
     * @return counter keyed by (canonical relation name, provenance), normalized with
     *     {@link Counters#normalize}; relations scoring {@code <= 0} are omitted
     */
    public Counter<Pair<String, Maybe<KBPRelationProvenance>>> classifyRelations(SentenceGroup sentenceGroup, Maybe<CoreMap[]> maybe) {
        ClassicCounter<Pair<String, Maybe<KBPRelationProvenance>>> scores =
                new ClassicCounter<Pair<String, Maybe<KBPRelationProvenance>>>();
        for (RelationType relationType : RelationType.values()) {
            Pair<Double, Maybe<KBPRelationProvenance>> result = classifyRelation(sentenceGroup, relationType, maybe);
            double score = result.first;
            if (score > 0.0d) {
                scores.setCount(Pair.makePair(relationType.canonicalName, result.second), score);
            }
        }
        Counters.normalize(scores);
        return scores;
    }

    /**
     * Scores a single relation type against a sentence group.
     *
     * <p>The default implementation runs {@link #classifyRelations} and looks up the entry whose
     * relation name equals {@code relationType.canonicalName}. See {@link #classifyRelations} for
     * the override requirement that avoids infinite mutual recursion.
     *
     * @param sentenceGroup the sentences (evidence) to classify
     * @param relationType the relation to score
     * @param maybe optional annotated sentences, forwarded unchanged to {@link #classifyRelations}
     * @return (score, provenance) for the relation, or (0.0, Nothing) when the best-scoring entry has
     *     zero mass or the relation is absent from the distribution
     */
    public Pair<Double, Maybe<KBPRelationProvenance>> classifyRelation(SentenceGroup sentenceGroup, RelationType relationType, Maybe<CoreMap[]> maybe) {
        Counter<Pair<String, Maybe<KBPRelationProvenance>>> scores = classifyRelations(sentenceGroup, maybe);
        // If even the arg-max entry has zero mass, no relation fired at all.
        if (scores.getCount(Counters.argmax(scores)) == 0.0d) {
            return noRelation();
        }
        for (Pair<String, Maybe<KBPRelationProvenance>> key : scores.keySet()) {
            if (key.first.equals(relationType.canonicalName)) {
                return Pair.makePair(Double.valueOf(scores.getCount(key)), key.second);
            }
        }
        return noRelation();
    }

    /** Returns the "no relation" result: a zero score with no provenance. */
    private static Pair<Double, Maybe<KBPRelationProvenance>> noRelation() {
        return Pair.makePair(Double.valueOf(0.0d), Maybe.<KBPRelationProvenance>Nothing());
    }

    /**
     * Same as {@link #classifyRelations}, but keyed by relation name only, with provenance dropped.
     *
     * @param sentenceGroup the sentences (evidence) to classify
     * @param maybe optional annotated sentences, forwarded unchanged
     * @return relation name to score; if several keys share a name, the one iterated last wins
     */
    public Counter<String> classifyRelationsNoProvenance(SentenceGroup sentenceGroup, Maybe<CoreMap[]> maybe) {
        Counter<Pair<String, Maybe<KBPRelationProvenance>>> scores = classifyRelations(sentenceGroup, maybe);
        ClassicCounter<String> byName = new ClassicCounter<String>();
        for (Map.Entry<Pair<String, Maybe<KBPRelationProvenance>>, Double> entry : scores.entrySet()) {
            byName.setCount(entry.getKey().first, entry.getValue());
        }
        return byName;
    }

    /**
     * Trains this classifier on the given dataset.
     *
     * @param kBPDataset the training data
     * @return statistics describing the training run
     */
    public abstract TrainingStatistics train(KBPDataset<String, String> kBPDataset);

    /**
     * Restores this classifier's model from a stream written by {@link #save(ObjectOutputStream)}.
     * {@link #load(String, Properties, Class)} may pass {@code null} here for heuristic extractors
     * whose model file could not be opened.
     */
    public abstract void load(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException;

    /** Writes this classifier's model to the given stream. */
    public abstract void save(ObjectOutputStream objectOutputStream) throws IOException;

    /**
     * Saves this classifier's model to a file, creating missing parent directories first.
     *
     * @param str destination file path
     * @throws IOException if the file cannot be opened or the model cannot be written
     */
    public void save(String str) throws IOException {
        File parent = new File(str).getParentFile();
        if (parent != null && !parent.exists()) {
            // Best effort: if directory creation fails, opening the file below reports the error.
            parent.mkdirs();
        }
        FileOutputStream fileStream = new FileOutputStream(str);
        try {
            ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileStream);
            save(objectOutputStream);
            // Closing the object stream flushes its buffer and closes the file.
            objectOutputStream.close();
        } finally {
            // No-op after the successful close above; releases the file handle if writing failed.
            fileStream.close();
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Computes one softmax probability with scores scaled by {@code d2}:
     * {@code exp(d2 * d) / sum_j exp(d2 * list[j])}. The denominator is computed in log space
     * via {@link ArrayMath#logSum}.
     *
     * @param d the score whose probability is requested
     * @param list all competing scores (normally including {@code d})
     * @param d2 multiplicative scale applied to every score
     * @return the softmax probability of {@code d}
     */
    public final double softmax(double d, List<Double> list, double d2) {
        double[] scaled = new double[list.size()];
        for (int i = 0; i < scaled.length; i++) {
            scaled[i] = d2 * list.get(i);
        }
        return Math.exp((d2 * d) - ArrayMath.logSum(scaled));
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Collects the feature collection of each datum in the sentence group, in order.
     * Logs a warning when the group is empty.
     *
     * @param sentenceGroup the group whose features to extract
     * @return one feature collection per element of the group
     */
    public static List<Collection<String>> tupleToFeatureList(SentenceGroup sentenceGroup) {
        // Checked before the loop: inside it, size() is necessarily positive and the warning could never fire.
        if (sentenceGroup.size() == 0) {
            Redwood.Util.warn(new Object[]{Redwood.Util.YELLOW, "Classifying a relation with no supporting evidence"});
        }
        List<Collection<String>> features = new ArrayList<Collection<String>>(sentenceGroup.size());
        for (int i = 0; i < sentenceGroup.size(); i++) {
            features.add(sentenceGroup.get(i).asFeatures());
        }
        return features;
    }

    /**
     * Converts a counter keyed by relation name into one keyed by {@link RelationType}.
     * Fails through {@code orCrash} for any name {@link RelationType#fromString} does not recognize.
     */
    protected static Counter<RelationType> asRelationTypeCounter(Counter<String> counter) {
        ClassicCounter<RelationType> byType = new ClassicCounter<RelationType>();
        for (String name : counter.keySet()) {
            byType.setCount(RelationType.fromString(name).orCrash(name), counter.getCount(name));
        }
        return byType;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Instantiates a classifier of type {@code cls} via its {@code (Properties)} constructor and loads
     * its model from {@code str}. The path is resolved as a URL, classpath resource, or filesystem path.
     *
     * <p>A {@code NOOPClassifier} request returns a fresh instance without touching {@code str}. For
     * a {@code HeuristicRelationExtractor}, failure to open the model is tolerated and the instance's
     * {@code load} receives {@code null}.
     *
     * @param str model location
     * @param properties passed to the classifier's constructor
     * @param cls concrete classifier type to construct
     * @return the constructed and loaded classifier
     * @throws IOException if the model cannot be opened (non-heuristic types) or read
     * @throws ClassNotFoundException if deserialization encounters an unknown class
     */
    public static <E extends RelationClassifier> E load(String str, Properties properties, Class<E> cls) throws IOException, ClassNotFoundException {
        if (NOOPClassifier.class.isAssignableFrom(cls)) {
            @SuppressWarnings("unchecked")  // cls is NOOPClassifier (or a subtype), so this matches the caller's request
            E noop = (E) new NOOPClassifier();
            return noop;
        }
        Redwood.Util.startTrack(new Object[]{"Loading model [" + cls.getSimpleName() + "] from " + str});
        try {
            Redwood.Util.log(new Object[]{"opening input streams..."});
            ObjectInputStream objectInputStream = openModelStream(str, cls);
            try {
                Redwood.Util.log(new Object[]{"constructing class via reflection..."});
                @SuppressWarnings("unchecked")  // MetaClass constructs an instance of cls, i.e. an E
                E classifier = (E) new MetaClass(cls).createInstance(new Object[]{properties});
                Redwood.Util.log(new Object[]{"calling Class' load() method..."});
                classifier.load(objectInputStream);
                return classifier;
            } finally {
                // Closing the object stream also closes the underlying input stream.
                if (objectInputStream != null) {
                    objectInputStream.close();
                }
            }
        } finally {
            // Always balance startTrack, even when loading fails.
            Redwood.Util.endTrack("Loading model [" + cls.getSimpleName() + "] from " + str);
        }
    }

    /**
     * Opens {@code path} as an {@link ObjectInputStream}. On failure the raw stream is closed and
     * the exception is rethrown, unless {@code cls} is a {@code HeuristicRelationExtractor}; then
     * {@code null} is returned (presumably because heuristic extractors can run without a model —
     * TODO confirm).
     */
    private static ObjectInputStream openModelStream(String path, Class<?> cls) throws IOException {
        InputStream rawStream = null;
        try {
            rawStream = IOUtils.getInputStreamFromURLOrClasspathOrFileSystem(path);
            return new ObjectInputStream(rawStream);
        } catch (IOException e) {
            if (rawStream != null) {
                try {
                    rawStream.close();
                } catch (IOException ignored) {
                    // Keep reporting the original failure rather than the close failure.
                }
            }
            if (!HeuristicRelationExtractor.class.isAssignableFrom(cls)) {
                throw e;
            }
            return null;
        }
    }

    /** Returns this classifier's training statistics, or {@code Nothing} if none are recorded. */
    public Maybe<TrainingStatistics> getTrainingStatistics() {
        return this.statistics;
    }

    /**
     * Attaches the provenance of the group's first sentence to every relation in {@code counter}.
     *
     * @param counter relation name to score
     * @param sentenceGroup group whose provenance at index 0 is used
     * @return counter keyed by (relation name, provenance of sentence 0), with the same scores
     */
    public static Counter<Pair<String, Maybe<KBPRelationProvenance>>> firstProvenance(Counter<String> counter, SentenceGroup sentenceGroup) {
        ClassicCounter<Pair<String, Maybe<KBPRelationProvenance>>> withProvenance =
                new ClassicCounter<Pair<String, Maybe<KBPRelationProvenance>>>();
        // Keep provenance lookup lazy: an empty counter must not touch sentenceGroup.
        if (counter.keySet().isEmpty()) {
            return withProvenance;
        }
        Maybe<KBPRelationProvenance> provenance = Maybe.<KBPRelationProvenance>Just(sentenceGroup.getProvenance(0));
        for (Map.Entry<String, Double> entry : counter.entrySet()) {
            withProvenance.setCount(Pair.makePair(entry.getKey(), provenance), entry.getValue());
        }
        return withProvenance;
    }
}
