package edu.stanford.nlp.kbp.slotfilling.evaluate.inference;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.kbp.common.CollectionUtils;
import edu.stanford.nlp.kbp.common.KBPEntity;
import edu.stanford.nlp.kbp.common.KBPSlotFill;
import edu.stanford.nlp.kbp.common.Maybe;
import edu.stanford.nlp.kbp.common.NERTag;
import edu.stanford.nlp.kbp.common.Props;
import edu.stanford.nlp.kbp.common.RelationType;
import edu.stanford.nlp.kbp.slotfilling.evaluate.EntityGraph;
import edu.stanford.nlp.kbp.slotfilling.evaluate.GoldResponseSet;
import edu.stanford.nlp.kbp.slotfilling.evaluate.inference.MLNText;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/evaluate/inference/ProbabilisticGraphInferenceEngine.class */
public abstract class ProbabilisticGraphInferenceEngine extends GraphInferenceEngine {
    protected final GoldResponseSet goldResponses;
    protected final MLNText candidateRules;
    protected static final Redwood.RedwoodChannels logger = Redwood.channels(new Object[]{"ProbabilisticInference"});
    protected static final Set<MLNText.Predicate> kbpPredicates = new HashSet();

    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/evaluate/inference/ProbabilisticGraphInferenceEngine$RulesMode.class */
    public enum RulesMode {
        KBP_ONLY,
        REVERB_STRICT,
        REVERB
    }

    public ProbabilisticGraphInferenceEngine() throws IOException {
        this(GoldResponseSet.empty(), Props.TEST_GRAPH_INFERENCE_RULES_FILES, Props.TEST_GRAPH_INFERENCE_RULES_CUTOFF, Props.TEST_GRAPH_INFERENCE_DEPTH);
    }

    public ProbabilisticGraphInferenceEngine(GoldResponseSet goldResponseSet) throws IOException {
        this(goldResponseSet, Props.TEST_GRAPH_INFERENCE_RULES_FILES, Props.TEST_GRAPH_INFERENCE_RULES_CUTOFF, Props.TEST_GRAPH_INFERENCE_DEPTH);
    }

    public ProbabilisticGraphInferenceEngine(File file, double d, int i) {
        this((List<File>) Collections.singletonList(file), d, i);
    }

    public ProbabilisticGraphInferenceEngine(List<File> list, double d, int i) {
        this(GoldResponseSet.empty(), list, d, i);
    }

    protected ProbabilisticGraphInferenceEngine(GoldResponseSet goldResponseSet, List<File> list, double d, int i) {
        this.candidateRules = new MLNText();
        this.goldResponses = goldResponseSet;
        Redwood.Util.startTrack(new Object[]{"Loading candidate rules"});
        this.candidateRules.predicates.addAll(kbpPredicates);
        for (File file : list) {
            try {
                Redwood.Util.logf("Adding rules from %s", new Object[]{file.getPath()});
                this.candidateRules.mergeIn(MLNReader.parse(IOUtils.getBufferedReaderFromClasspathOrFileSystem(file.getPath())));
                Redwood.Util.logf("Currently have %d rules", new Object[]{Integer.valueOf(this.candidateRules.rules.size())});
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        ListIterator<MLNText.Rule> listIterator = this.candidateRules.rules.listIterator();
        while (listIterator.hasNext()) {
            MLNText.Rule next = listIterator.next();
            if (Math.abs(next.weight) < d) {
                listIterator.remove();
            } else if (next.literals.size() > i + 1) {
                listIterator.remove();
            } else if (Props.TEST_GRAPH_INFERENCE_HACKS_NO_SPOUSE) {
                if (untypedRelation(next.consequent().name).equals(RelationType.PER_SPOUSE.canonicalName) && next.antecedents().size() > 0) {
                    Redwood.Util.debug(new Object[]{"Skipping rule " + next + " because of spouse hack."});
                    listIterator.remove();
                }
            } else if (Props.TEST_GRAPH_INFERENCE_HACKS_NO_NEGATIVE_TRANSLATIONS && next.literals.size() == 2 && next.weight < 0.0d) {
                Redwood.Util.debug(new Object[]{"Skipping rule " + next + " because of negative translation hack."});
                listIterator.remove();
            }
        }
        Redwood.Util.endTrack("Loading candidate rules");
    }

    public Pair<MLNText, Map<String, KBPEntity>> graphToMLN(EntityGraph entityGraph, MLNText mLNText, KBPEntity kBPEntity) {
        HashMap hashMap = new HashMap();
        HashMap hashMap2 = new HashMap();
        for (MLNText.Predicate predicate : mLNText.predicates) {
            if (!predicate.closed) {
                hashMap2.put(predicate.name, Pair.makePair(NERTag.fromString(predicate.type1).get(), NERTag.fromString(predicate.type2).get()));
            }
        }
        MLNText mLNText2 = new MLNText();
        for (KBPSlotFill kBPSlotFill : entityGraph.getAllEdges()) {
            KBPEntity entity = kBPSlotFill.key.getEntity();
            KBPEntity kBPEntity2 = kBPSlotFill.key.getSlotEntity().get();
            String cleanEntity = cleanEntity(entity);
            String cleanEntity2 = cleanEntity(kBPEntity2);
            hashMap.put(cleanEntity, entity);
            hashMap.put(cleanEntity2, kBPEntity2);
            HashMap hashMap3 = new HashMap(hashMap2);
            String cleanRelation = cleanRelation(kBPSlotFill.key.relationName, kBPSlotFill.key.entityType, kBPSlotFill.key.slotType.getOrElse(NERTag.MISC));
            hashMap3.remove(cleanRelation);
            if (!mLNText.getPredicateByName(cleanRelation).isNothing()) {
                MLNText.Predicate predicate2 = mLNText.getPredicateByName(cleanRelation).get();
                MLNText.Literal literal = new MLNText.Literal(true, predicate2.name, cleanEntity, cleanEntity2);
                mLNText2.predicates.add(predicate2);
                if (!predicate2.closed) {
                    mLNText2.predicates.add(predicate2);
                    if (kBPSlotFill.score.getOrElse(Double.valueOf(Double.POSITIVE_INFINITY)).isInfinite()) {
                        mLNText2.add(new MLNText.Rule(Double.POSITIVE_INFINITY, literal));
                    } else {
                        double doubleValue = kBPSlotFill.score.getOrElse(Double.valueOf(0.0d)).doubleValue();
                        double d = Props.TEST_GRAPH_INFERENCE_HACKS_ALWAYS_TRUE ? 1.0d : Props.TEST_GRAPH_INFERENCE_HACKS_SOFT_PRIORS ? 0.5d + (0.4d * doubleValue) : (1.0d + doubleValue) / 2.0d;
                        mLNText2.add(new MLNText.Rule(Math.log(d / (1.0d - d)), literal));
                    }
                } else if (Props.TEST_GRAPH_INFERENCE_HACKS_SOFT_PRIORS) {
                    mLNText2.add(new MLNText.Rule(kBPSlotFill.score.getOrElse(Double.valueOf(Double.POSITIVE_INFINITY)).doubleValue(), literal));
                } else {
                    mLNText2.add(new MLNText.Rule(Double.POSITIVE_INFINITY, literal));
                }
            }
        }
        return Pair.makePair(mLNText2, hashMap);
    }

    public List<MLNText.Literal> getQueryTerms(MLNText mLNText, KBPEntity kBPEntity) {
        HashSet hashSet = new HashSet(CollectionUtils.filter(kbpPredicates, predicate -> {
            return Boolean.valueOf(predicate.type1.equals(kBPEntity.type.name));
        }));
        hashSet.retainAll(mLNText.predicates);
        return CollectionUtils.map(hashSet, predicate2 -> {
            return new MLNText.Literal(true, predicate2.name, cleanEntity(kBPEntity), predicate2.type2.toLowerCase() + "1");
        });
    }

    public MLNText getRules(EntityGraph entityGraph, KBPEntity kBPEntity) {
        Redwood.Util.startTrack(new Object[]{"Filtering rules for entity"});
        HashMap hashMap = new HashMap();
        for (MLNText.Predicate predicate : this.candidateRules.predicates) {
            hashMap.put(predicate.name, predicate);
        }
        HashSet<String> hashSet = new HashSet(CollectionUtils.map(entityGraph.getAllEdges(), kBPSlotFill -> {
            return cleanRelation(kBPSlotFill);
        }));
        if (Props.TEST_GRAPH_INFERENCE_RULES_MODE == RulesMode.KBP_ONLY) {
            hashSet.retainAll(CollectionUtils.map(kbpPredicates, predicate2 -> {
                return predicate2.name;
            }));
        }
        HashSet<String> hashSet2 = new HashSet(CollectionUtils.filterMap(kbpPredicates, predicate3 -> {
            return predicate3.type1.equals(kBPEntity.type.name) ? Maybe.Just(predicate3.name) : Maybe.Nothing();
        }));
        ArrayList arrayList = new ArrayList(this.candidateRules.rules);
        ArrayList arrayList2 = new ArrayList();
        int size = arrayList2.size();
        for (int i = 0; i < 10; i++) {
            Redwood.Util.logf("(%d) %d rules (%d remain), with |V_c| = %d and |V_a| = %d", new Object[]{Integer.valueOf(i), Integer.valueOf(arrayList2.size()), Integer.valueOf(arrayList.size()), Integer.valueOf(hashSet2.size()), Integer.valueOf(hashSet.size())});
            ListIterator listIterator = arrayList.listIterator();
            while (listIterator.hasNext()) {
                MLNText.Rule rule = (MLNText.Rule) listIterator.next();
                if (Props.TEST_GRAPH_INFERENCE_HACKS_NO_SPOUSE && untypedRelation(rule.consequent().name).equals(RelationType.PER_SPOUSE.canonicalName) && rule.antecedents().size() > 0) {
                    Redwood.Util.debug(new Object[]{"Skipping rule " + rule + " because of spouse hack."});
                } else if (Props.TEST_GRAPH_INFERENCE_HACKS_NO_NEGATIVE_TRANSLATIONS && rule.literals.size() == 2 && rule.weight < 0.0d) {
                    Redwood.Util.debug(new Object[]{"Skipping rule " + rule + " because of negative translation hack."});
                } else if (CollectionUtils.all(rule.literals, literal -> {
                    return !literal.truth ? Boolean.valueOf(hashSet.contains(literal.name)) : Boolean.valueOf(hashSet2.contains(literal.name));
                })) {
                    listIterator.remove();
                    arrayList2.add(rule);
                }
            }
            if (Props.TEST_GRAPH_INFERENCE_RULES_MODE != RulesMode.REVERB || arrayList2.size() == size) {
                break;
            }
            size = arrayList2.size();
            hashSet.addAll(hashSet2);
            Iterator it = arrayList2.iterator();
            while (it.hasNext()) {
                for (MLNText.Literal literal2 : ((MLNText.Rule) it.next()).literals) {
                    if (!literal2.truth) {
                        hashSet2.add(literal2.name);
                    }
                }
            }
        }
        Redwood.Util.logf("# %d rules (%d remain), with |V_c| = %d and |V_a| = %d", new Object[]{Integer.valueOf(arrayList2.size()), Integer.valueOf(arrayList.size()), Integer.valueOf(hashSet2.size()), Integer.valueOf(hashSet.size())});
        hashSet2.clear();
        hashSet.clear();
        Iterator it2 = arrayList2.iterator();
        while (it2.hasNext()) {
            for (MLNText.Literal literal3 : ((MLNText.Rule) it2.next()).literals) {
                if (literal3.truth) {
                    hashSet2.add(literal3.name);
                } else {
                    hashSet.add(literal3.name);
                }
            }
        }
        hashSet.addAll(CollectionUtils.map(entityGraph.getAllEdges(), kBPSlotFill2 -> {
            return cleanRelation(kBPSlotFill2);
        }));
        hashSet2.addAll(CollectionUtils.filterMap(entityGraph.getAllEdges(), kBPSlotFill3 -> {
            return (kBPSlotFill3.key.getEntity().equals(kBPEntity) && kBPSlotFill3.key.hasKBPRelation()) ? Maybe.Just(cleanRelation(kBPSlotFill3)) : Maybe.Nothing();
        }));
        MLNText mLNText = new MLNText();
        Redwood.Util.logf("# Creating %d rules with %d open and %d closed predicates.", new Object[]{Integer.valueOf(arrayList2.size()), Integer.valueOf(hashSet2.size()), Integer.valueOf(hashSet.size())});
        for (String str : hashSet2) {
            if (hashMap.get(str) != null) {
                mLNText.predicates.add(((MLNText.Predicate) hashMap.get(str)).asOpen());
            }
        }
        for (String str2 : hashSet) {
            if (!hashSet2.contains(str2) && hashMap.get(str2) != null) {
                mLNText.predicates.add(((MLNText.Predicate) hashMap.get(str2)).asClosed());
            }
        }
        mLNText.rules = arrayList2;
        Redwood.Util.endTrack("Filtering rules for entity");
        return mLNText;
    }

    @Override // edu.stanford.nlp.kbp.slotfilling.evaluate.inference.GraphInferenceEngine
    public abstract EntityGraph apply(EntityGraph entityGraph, KBPEntity kBPEntity);

    static {
        for (RelationType relationType : RelationType.values()) {
            for (NERTag nERTag : relationType.validNamedEntityLabels) {
                kbpPredicates.add(new MLNText.Predicate(cleanRelation(relationType.canonicalName, relationType.entityType, nERTag), relationType.entityType.name.toUpperCase(), nERTag.name.toUpperCase()));
            }
        }
    }
}
