package edu.stanford.nlp.kbp.slotfilling.scripts;

import edu.stanford.nlp.kbp.common.KBPEntity;
import edu.stanford.nlp.kbp.common.KBPNew;
import edu.stanford.nlp.kbp.common.KBPOfficialEntity;
import edu.stanford.nlp.kbp.common.KBPSlotFill;
import edu.stanford.nlp.kbp.common.KBTriple;
import edu.stanford.nlp.kbp.common.Lazy;
import edu.stanford.nlp.kbp.common.Maybe;
import edu.stanford.nlp.kbp.common.PostgresUtils;
import edu.stanford.nlp.kbp.common.Props;
import edu.stanford.nlp.kbp.common.RelationType;
import edu.stanford.nlp.kbp.common.Utils;
import edu.stanford.nlp.kbp.entitylinking.WikidictEntityLinker;
import edu.stanford.nlp.kbp.slotfilling.SlotfillingSystem;
import edu.stanford.nlp.kbp.slotfilling.evaluate.EntityGraph;
import edu.stanford.nlp.kbp.slotfilling.evaluate.GoldResponseSet;
import edu.stanford.nlp.kbp.slotfilling.evaluate.GraphConsistencyPostProcessors;
import edu.stanford.nlp.kbp.slotfilling.evaluate.InferentialSlotFiller;
import edu.stanford.nlp.kbp.slotfilling.evaluate.SlotfillPostProcessor;
import edu.stanford.nlp.kbp.slotfilling.ir.KBPIR;
import edu.stanford.nlp.kbp.slotfilling.ir.KnowledgeBase;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.Execution;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.IOException;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Function;

/**
 * Mines candidate inference rules ("formulas") from relation graphs built around KB entities.
 *
 * <p>{@link #main} collects every distinct subject entity in the knowledge base, sorts them, and
 * for each entity in the window {@code [begin, begin + count)}:
 * <ol>
 *   <li>extracts a relation graph;</li>
 *   <li>reconciles it with the knowledge base ({@link #enforceKBInGraph});</li>
 *   <li>enumerates its paths ({@link #extractAllFormulas});</li>
 *   <li>adds the resulting counts to the {@code Props.DB_TABLE_MINED_FORMULAS} table.</li>
 * </ol>
 *
 * <p>Document ids that the extraction's document filter accepts are recorded in the
 * {@code mined_documents} set. Ids already in that set are rejected by the filter. Presumably
 * this keeps re-runs from re-mining the same documents (assumes {@code extractRelationGraph}
 * skips rejected documents; confirm).
 */
public class MineInferentialPaths {
    /** Log channel for this script. */
    protected static final Redwood.RedwoodChannels logger = Redwood.channels("Miner");

    @Execution.Option(name = "mine-inferential-paths.begin", required = true, gloss = "The query entity to begin with")
    private static int begin = 0;

    @Execution.Option(name = "mine-inferential-paths.count", required = true, gloss = "The number of entities to query over")
    private static int count = Integer.MAX_VALUE;

    @Execution.Option(name = "mine-inferential-paths.cutoff", gloss = "Only take KBP relations not in the kB above this threshold")
    private static double cutoff = 0.0;

    /** The direction in which a trie edge traverses the underlying KB relation. */
    public enum Direction {
        /** From the relation's entity to its slot value. */
        FORWARD,
        /** From the relation's slot value back to its entity. */
        BACKWARD
    }

    /**
     * A node in a trie of paths through an {@link EntityGraph}.
     *
     * <p>The root holds a start vertex. Every other node records the edge taken from its parent
     * (relation name and {@link Direction}) and the vertex reached. Each non-root node therefore
     * stands for the path from the root down to it.
     */
    public static class Trie {
        /** The vertex this node ends at. */
        public final KBPEntity entry;
        /** Children keyed by (relation, direction, vertex reached). */
        public final Map<Triple<String, Direction, KBPEntity>, Trie> children;
        /** The relation and direction of the edge from the parent; {@code null} at the root. */
        public final Pair<String, Direction> relationFromParent;
        /** The parent node; {@code null} at the root. */
        public final Trie parent;

        /** Creates a root node starting at {@code root}. */
        public Trie(KBPEntity root) {
            this.entry = root;
            this.children = new HashMap<>();
            this.relationFromParent = null;
            this.parent = null;
        }

        /**
         * Creates a child node.
         *
         * @param step the (relation, direction, vertex reached) triple of the edge from {@code parent}
         * @param parent the parent node
         */
        public Trie(Triple<String, Direction, KBPEntity> step, Trie parent) {
            this.entry = step.third;
            this.children = new HashMap<>();
            this.relationFromParent = Pair.makePair(step.first, step.second);
            this.parent = parent;
        }

        /** Returns the number of edges between this node and the root (0 for the root). */
        public int depth() {
            return parent == null ? 0 : 1 + parent.depth();
        }

        /** Returns the start vertex of the path this node represents. */
        public KBPEntity root() {
            return parent == null ? entry : parent.root();
        }

        /** Returns true if this is a non-root node whose path ends back at the root vertex. */
        public boolean isLoop() {
            return parent != null && root().equals(entry);
        }

        /**
         * Returns true if {@code vertex} is already visited by a non-root node on the path from the
         * root to this node. Revisiting such a vertex would create a loop that does not close at
         * the root.
         */
        public boolean danglingLoop(KBPEntity vertex) {
            return parent != null && (entry.equals(vertex) || parent.danglingLoop(vertex));
        }

        /**
         * Returns true if stepping to {@code target} via {@code relation} would re-walk the edge
         * that led from the parent to this node. {@code parentDirection} is the direction that
         * parent edge must have for the two to coincide.
         */
        private boolean undoesParentEdge(KBPEntity target, String relation, Direction parentDirection) {
            return parent != null
                    && parent.entry.equals(target)
                    && relationFromParent.first.equals(relation)
                    && relationFromParent.second == parentDirection;
        }

        /**
         * Tries to extend the path ending here with {@code edge}, followed in whichever direction
         * leaves this node's vertex.
         *
         * @param edge a KB triple touching this node's vertex (by equality or by name); its slot
         *     value must be an entity
         * @return the new child, or nothing if the extension is rejected. An extension is rejected
         *     when it revisits an intermediate vertex, re-walks the edge just taken, duplicates an
         *     existing child, closes a loop back to the root (such a child is recorded but not
         *     returned for expansion), or exceeds {@code Props.TEST_GRAPH_INFERENCE_DEPTH}.
         * @throws IllegalStateException if neither endpoint of {@code edge} matches this node's vertex
         */
        public Maybe<Trie> extend(KBTriple edge) {
            KBPEntity subject = edge.getEntity();
            KBPEntity object = edge.getSlotEntity().orCrash();
            Triple<String, Direction, KBPEntity> step;
            boolean leavesFromSubject = subject.equals(entry)
                    || (!object.equals(entry) && subject.name.equals(entry.name));
            if (leavesFromSubject) {
                // Walk the relation forwards: subject -> object.
                if (danglingLoop(object) || undoesParentEdge(object, edge.relationName, Direction.BACKWARD)) {
                    return Maybe.Nothing();
                }
                step = Triple.makeTriple(edge.relationName, Direction.FORWARD, object);
            } else {
                if (!object.equals(entry) && !object.name.equals(entry.name)) {
                    throw new IllegalStateException("Cannot add edge to Trie: " + edge + "; Trie ends at " + entry);
                }
                // Walk the relation backwards: object -> subject.
                if (danglingLoop(subject) || undoesParentEdge(subject, edge.relationName, Direction.FORWARD)) {
                    return Maybe.Nothing();
                }
                step = Triple.makeTriple(edge.relationName, Direction.BACKWARD, subject);
            }
            if (children.containsKey(step)) {
                return Maybe.Nothing();
            }
            Trie child = new Trie(step, this);
            if (child.isLoop()) {
                // Keep paths that close a loop at the root, but never expand them further.
                children.put(step, child);
                return Maybe.Nothing();
            }
            if (child.depth() > Props.TEST_GRAPH_INFERENCE_DEPTH) {
                return Maybe.Nothing();
            }
            children.put(step, child);
            return Maybe.Just(child);
        }

        /**
         * Returns the path from the root to this node as a fresh, mutable list of KB triples.
         * Every triple is oriented as a real relation (entity -> slot value), regardless of the
         * direction in which the trie walked it.
         */
        public List<KBTriple> asPath() {
            if (parent == null) {
                return new ArrayList<>();
            }
            List<KBTriple> path = parent.asPath();
            boolean forward = relationFromParent.second == Direction.FORWARD;
            KBPEntity subject = forward ? parent.entry : entry;
            KBPEntity object = forward ? entry : parent.entry;
            path.add(KBPNew.from(subject).slotValue(object).rel(relationFromParent.first).KBTriple());
            return path;
        }

        /** Returns the path of every non-root node in the subtree rooted here, in pre-order. */
        public Collection<List<KBTriple>> allPathsInTrie() {
            List<List<KBTriple>> paths = new ArrayList<>();
            if (parent != null) {
                paths.add(asPath());
            }
            for (Trie child : children.values()) {
                paths.addAll(child.allPathsInTrie());
            }
            return paths;
        }

        /**
         * Two nodes are equal if they reach the same vertex via the same edge from equal parents.
         * This is null-safe for root nodes.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof Trie)) {
                return false;
            }
            Trie other = (Trie) obj;
            return entry.equals(other.entry)
                    && Objects.equals(parent, other.parent)
                    && Objects.equals(relationFromParent, other.relationFromParent);
        }

        /** Consistent with {@link #equals}; null-safe for root nodes. */
        @Override
        public int hashCode() {
            return 31 * entry.hashCode() + Objects.hashCode(relationFromParent);
        }

        /** Renders this subtree, one edge per line, with children indented under {@code indent}. */
        private String toString(String indent) {
            StringBuilder out = new StringBuilder();
            out.append(entry).append('\n');
            for (Trie child : children.values()) {
                String relation = child.relationFromParent.first;
                String arrow = child.relationFromParent.second == Direction.FORWARD
                        ? "  -[" + relation + "]-> "
                        : "  <-[" + relation + "]- ";
                out.append(indent).append(arrow).append(child.toString(indent + "    "));
            }
            return out.toString();
        }

        @Override
        public String toString() {
            return toString("");
        }
    }

    /** Entry point; options (including {@code mine-inferential-paths.*}) are parsed by {@code SlotfillingSystem.exec}. */
    public static void main(String[] args) {
        SlotfillingSystem.exec((Function<Properties, Object>) properties -> {
            try {
                Props.ENTITYLINKING_LINKER = Lazy.from(new WikidictEntityLinker());
                SlotfillingSystem slotfillingSystem = new SlotfillingSystem(properties);
                final InferentialSlotFiller slotFiller = new InferentialSlotFiller(properties, slotfillingSystem.getIR(), slotfillingSystem.getProcess(), slotfillingSystem.getTrainedClassifier().get(), new GoldResponseSet());
                final KBPIR ir = slotfillingSystem.getIR();
                final KnowledgeBase knowledgeBase = slotfillingSystem.getIR().getKnowledgeBase();

                // Distinct subject entities of the KB, in sorted order.
                TreeSet<KBPOfficialEntity> distinctEntities = new TreeSet<>();
                for (KBTriple triple : knowledgeBase.triples()) {
                    distinctEntities.add(KBPNew.from(triple.getEntity()).KBPOfficialEntity());
                }
                final KBPOfficialEntity[] entities = distinctEntities.toArray(new KBPOfficialEntity[distinctEntities.size()]);
                // Use long arithmetic: begin + count overflows int when count is large
                // (its default is Integer.MAX_VALUE).
                final int end = (int) Math.min((long) begin + count, entities.length);
                logger.log(Redwood.Util.BLUE, "" + entities.length + " entities in KB");
                logger.log(Redwood.Util.BLUE, "querying [" + begin + ", " + end + "]");

                PostgresUtils.withSet("mined_documents", new PostgresUtils.SetCallback() {
                    @Override
                    public void apply(final Connection connection) throws SQLException {
                        try {
                            for (int i = begin; i < end; i++) {
                                mineEntity(entities[i], connection);
                            }
                        } catch (Exception e) {
                            logger.err(e);
                        }
                    }

                    /**
                     * Mines one entity inside its own log track. Failures are logged and the
                     * entity is skipped (deliberately best-effort).
                     */
                    private void mineEntity(KBPOfficialEntity entity, final Connection connection) {
                        String track = "Mining paths for " + entity;
                        Redwood.Util.forceTrack(track);
                        try {
                            final Set<String> newlyMined = new HashSet<>();
                            // Reject documents already in mined_documents; remember the rest.
                            // A failed lookup is logged and the document is treated as unseen.
                            Function<String, Boolean> documentFilter = docId -> {
                                try {
                                    if (contains(connection, "mined_documents", docId)) {
                                        return false;
                                    }
                                } catch (SQLException e) {
                                    logger.err(e);
                                }
                                newlyMined.add(docId);
                                return true;
                            };
                            EntityGraph graph = slotFiller.extractRelationGraph(entity, Props.TEST_SENTENCES_PER_ENTITY, Maybe.Just(documentFilter));
                            runOnGraph(enforceKBInGraph(graph, knowledgeBase), ir);
                            // Only record documents once the entity has been mined successfully.
                            for (String docId : newlyMined) {
                                add(connection, "mined_documents", docId);
                            }
                            flush(connection, "mined_documents");
                        } catch (Throwable t) {
                            logger.err(t);
                        } finally {
                            Redwood.Util.endTracksUntil(track);
                        }
                        Redwood.Util.endTrack(track);
                    }
                });
                return null;
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }, args);
    }

    /**
     * Reconciles an extracted relation graph with the knowledge base, modifying it in place.
     *
     * <ol>
     *   <li>Removes each edge whose KBP relation cannot plausibly co-occur with a relation the KB
     *       records for the same entity and slot value.</li>
     *   <li>If {@code cutoff} is positive, removes each edge scoring below it (unscored edges count
     *       as 1.0).</li>
     *   <li>Adds each KB fact, with score 1.0, when its entity is a vertex and its slot value names
     *       a vertex, unless it is already among the edges between them.</li>
     * </ol>
     *
     * @param entityGraph the graph to modify
     * @param knowledgeBase the knowledge base to reconcile against
     * @return {@code entityGraph}
     */
    public static EntityGraph enforceKBInGraph(EntityGraph entityGraph, KnowledgeBase knowledgeBase) {
        Redwood.Util.startTrack("Syncing with KB");
        try {
            removeImplausibleEdges(entityGraph, knowledgeBase);
            // Index vertices by name so KB slot values (strings) can be mapped back to vertices.
            Map<String, KBPEntity> verticesByName = new HashMap<>();
            for (KBPEntity vertex : entityGraph.getAllVertices()) {
                verticesByName.put(vertex.name, vertex);
            }
            if (cutoff > 0.0) {
                removeLowScoringEdges(entityGraph, cutoff);
            }
            addKnownRelations(entityGraph, knowledgeBase, verticesByName);
        } finally {
            Redwood.Util.endTrack("Syncing with KB");
        }
        return entityGraph;
    }

    /** Removes every edge that conflicts with a KB fact about the same entity and slot value. */
    private static void removeImplausibleEdges(EntityGraph entityGraph, KnowledgeBase knowledgeBase) {
        Iterator<KBPSlotFill> edges = entityGraph.edgeIterator();
        while (edges.hasNext()) {
            KBPSlotFill candidate = edges.next();
            KBPEntity subject = candidate.key.getEntity();
            if (!knowledgeBase.data.containsKey(subject)) {
                continue;
            }
            Pair<RelationType, RelationType> conflict = findImplausibleRelation(candidate, knowledgeBase.data.get(subject));
            if (conflict != null) {
                logger.log("Filtered impossible relation: " + conflict.first + " on account of " + conflict.second);
                // Exactly one remove() per edge: a second call would throw IllegalStateException.
                edges.remove();
            }
        }
    }

    /**
     * Returns the first (candidate relation, known KB relation) pair that cannot plausibly
     * co-occur on the candidate's slot value, or {@code null} if there is none.
     */
    private static Pair<RelationType, RelationType> findImplausibleRelation(KBPSlotFill candidate, Iterable<KBPSlotFill> knownFills) {
        for (KBPSlotFill known : knownFills) {
            if (!known.key.slotValue.equals(candidate.key.slotValue)) {
                continue;
            }
            for (RelationType candidateRelation : candidate.key.tryKbpRelation()) {
                for (RelationType knownRelation : known.key.tryKbpRelation()) {
                    if (!knownRelation.plausiblyCooccursWith(candidateRelation)) {
                        return Pair.makePair(candidateRelation, knownRelation);
                    }
                }
            }
        }
        return null;
    }

    /** Removes every edge whose score (1.0 if absent) is below {@code threshold}. */
    private static void removeLowScoringEdges(EntityGraph entityGraph, double threshold) {
        Iterator<KBPSlotFill> edges = entityGraph.edgeIterable().iterator();
        while (edges.hasNext()) {
            if (edges.next().score.getOrElse(1.0) < threshold) {
                edges.remove();
            }
        }
    }

    /** Adds KB facts (score 1.0) between graph vertices, skipping facts already present as edges. */
    private static void addKnownRelations(EntityGraph entityGraph, KnowledgeBase knowledgeBase, Map<String, KBPEntity> verticesByName) {
        for (KBPEntity subject : knowledgeBase.data.keySet()) {
            if (!entityGraph.containsVertex(subject)) {
                continue;
            }
            for (KBPSlotFill known : knowledgeBase.data.get(subject)) {
                KBPEntity object = verticesByName.get(known.key.slotValue);
                if (object == null) {
                    continue;
                }
                KBPSlotFill fill = KBPNew.from(known).slotValue(object).score(1.0).KBPSlotFill();
                if (entityGraph.containsVertex(object) && !entityGraph.getEdges(subject, object).contains(fill)) {
                    logger.log("Adding known relation: " + fill);
                    entityGraph.add(subject, object, fill);
                }
            }
        }
    }

    /**
     * Mines formulas from one graph and adds them to the formula table. It also logs them when
     * there are at most 100.
     *
     * @param entityGraph the (already KB-reconciled) graph; it is first passed through the unary
     *     consistency post-processor
     * @param kbpir forwarded to {@link #extractAllFormulas}
     */
    public static void runOnGraph(EntityGraph entityGraph, KBPIR kbpir) {
        EntityGraph consistent = new GraphConsistencyPostProcessors.UnaryConsistencyPostProcessor(SlotfillPostProcessor.unary).postProcess(entityGraph);
        Counter<List<KBTriple>> formulas = extractAllFormulas(consistent, kbpir);
        logger.log(Redwood.Util.GREEN, "" + formulas.size() + " formulas extracted");
        saveInferentialPaths(formulas);
        if (formulas.size() <= 100) {
            Redwood.Util.startTrack("Formulas");
            for (Map.Entry<List<KBTriple>, Double> entry : formulas.entrySet()) {
                logger.log(entry.getValue() + ": " + StringUtils.join(entry.getKey(), " ∧ "));
            }
            Redwood.Util.endTrack("Formulas");
        }
    }

    /**
     * Adds every formula count to the {@code Props.DB_TABLE_MINED_FORMULAS} table.
     *
     * <p>Assertions check that each looping formula survives a round trip through the table's key
     * encoding, and that its prefix (all but the last triple) also appears in {@code counter}.
     */
    private static void saveInferentialPaths(final Counter<List<KBTriple>> counter) {
        PostgresUtils.withCounter(Props.DB_TABLE_MINED_FORMULAS, new PostgresUtils.CNFFormulaCounterCallback() {
            @Override
            public void apply(Connection connection) throws SQLException {
                for (Map.Entry<List<KBTriple>, Double> entry : counter.entrySet()) {
                    List<KBTriple> formula = entry.getKey();
                    if (Utils.doesLoop(formula)) {
                        assert Utils.doesLoop(string2key(key2string(formula)));
                        assert counter.getCount(string2key(key2string(formula.subList(0, formula.size() - 1)))) > 0.0;
                    }
                    incrementCount(connection, Props.DB_TABLE_MINED_FORMULAS, formula, entry.getValue());
                }
                flush(connection, Props.DB_TABLE_MINED_FORMULAS);
            }
        });
    }

    /**
     * Enumerates every path of bounded length in the graph and counts their normalized forms.
     *
     * <p>A {@link Trie} is grown breadth-first from every vertex, following incoming and outgoing
     * edges ({@link Trie#extend} enforces the depth bound and loop rules). Every non-root trie
     * node contributes its path. Paths are normalized in parallel over {@code Execution.threads}
     * shards.
     *
     * @param entityGraph the graph to mine
     * @param kbpir currently unused
     * @return a counter from normalized formula to accumulated weight (0.5 per path)
     */
    public static Counter<List<KBTriple>> extractAllFormulas(EntityGraph entityGraph, KBPIR kbpir) {
        // 1. Grow one trie per vertex, breadth-first.
        List<Trie> roots = new ArrayList<>();
        Deque<Trie> frontier = new ArrayDeque<>();
        for (KBPEntity vertex : entityGraph.getAllVertices()) {
            Trie root = new Trie(vertex);
            roots.add(root);
            frontier.add(root);
        }
        while (!frontier.isEmpty()) {
            Trie node = frontier.poll();
            for (KBPSlotFill fill : entityGraph.outgoingEdgeIterable(node.entry)) {
                for (Trie child : node.extend(fill.key)) {
                    frontier.add(child);
                }
            }
            for (KBPSlotFill fill : entityGraph.incomingEdgeIterable(node.entry)) {
                for (Trie child : node.extend(fill.key)) {
                    frontier.add(child);
                }
            }
        }

        // 2. Collect all paths and deal them round-robin into one shard per thread.
        List<List<KBTriple>> paths = new ArrayList<>();
        for (Trie root : roots) {
            paths.addAll(root.allPathsInTrie());
        }
        final int threads = Math.max(1, Execution.threads);  // guard against a non-positive setting
        List<List<List<KBTriple>>> shards = new ArrayList<>(threads);
        for (int t = 0; t < threads; t++) {
            shards.add(new ArrayList<>());
        }
        for (int k = 0; k < paths.size(); k++) {
            shards.get(k % threads).add(paths.get(k));
        }

        // 3. Normalize and count each shard on its own thread, then merge.
        List<ClassicCounter<List<KBTriple>>> partialCounts = new ArrayList<>(threads);
        List<Runnable> tasks = new ArrayList<>(threads);
        for (int t = 0; t < threads; t++) {
            final List<List<KBTriple>> shard = shards.get(t);
            final ClassicCounter<List<KBTriple>> partial = new ClassicCounter<>();
            partialCounts.add(partial);
            tasks.add(() -> {
                for (List<KBTriple> path : shard) {
                    // Fixed weight per path. NOTE(review): presumably 0.5 because each path is
                    // found once from each endpoint's trie; numDocumentsContaining is unused and
                    // may once have normalized this weight. Confirm.
                    partial.incrementCount(normalizeFormula(path), 0.5);
                }
            });
        }
        Redwood.Util.threadAndRun(tasks);
        ClassicCounter<List<KBTriple>> total = new ClassicCounter<>();
        for (Counter<List<KBTriple>> partial : partialCounts) {
            Counters.addInPlace(total, partial);
        }
        return total;
    }

    /**
     * Returns the canonical form of a path as a fresh list.
     * Paths for which {@code Utils.doesLoop} holds use {@code Utils.normalizeEntailment}; all
     * others use {@code Utils.normalizeConjunction}.
     */
    private static List<KBTriple> normalizeFormula(List<KBTriple> path) {
        if (Utils.doesLoop(path)) {
            List<KBTriple> normalized = new ArrayList<>(Utils.normalizeEntailment(path));
            assert Utils.doesLoop(normalized);
            return normalized;
        }
        return new ArrayList<>(Utils.normalizeConjunction(path));
    }

    /**
     * Returns the IR hit count for the set of entity names and slot values mentioned in a formula.
     * The result is memoized in {@code cache} and is never below 1.0. Query failures are logged
     * and count as 1.0. Currently not called from this class.
     */
    private static double numDocumentsContaining(Collection<KBTriple> formula, KBPIR ir, Map<Set<String>, Double> cache) {
        Set<String> terms = new HashSet<>();
        for (KBTriple triple : formula) {
            terms.add(triple.entityName);
            terms.add(triple.slotValue);
        }
        Double cached = cache.get(terms);
        if (cached != null) {
            return cached;
        }
        double hits;
        try {
            hits = Math.max(1.0, ir.queryNumHits(terms));
        } catch (Exception e) {
            logger.err(e);
            hits = 1.0;
        }
        cache.put(terms, hits);
        return hits;
    }
}
