package edu.stanford.nlp.kbp.slotfilling.process;

import edu.stanford.nlp.kbp.common.CoreMapUtils;
import edu.stanford.nlp.kbp.common.Maybe;
import edu.stanford.nlp.kbp.common.NERTag;
import edu.stanford.nlp.kbp.common.Props;
import edu.stanford.nlp.kbp.common.RelationType;
import edu.stanford.nlp.kbp.common.SentenceGroup;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.MetaClass;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/process/RelationFilter.class */
public class RelationFilter {
    /** Redwood logging channel named "RelFilter"; also used by the nested filter components. */
    protected static final Redwood.RedwoodChannels logger = Redwood.channels(new Object[]{"RelFilter"});
    /**
     * Scoring function invoked once per sentence group by {@code predictLabels}; the keys of the
     * returned counter are treated as relation labels and the counts as their scores.
     */
    private Function<Pair<SentenceGroup, Maybe<CoreMap[]>>, Counter<String>> classifier;
    /** Filter components that {@code apply} runs one after another, in list order. */
    private List<FilterComponent> filterComponents;

    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/process/RelationFilter$CorefFilterComponent.class */
    /**
     * Filter component that makes sentence groups sharing a slot value compete on score, but only
     * those groups whose entity also occurs among the entity-specific sentence groups. Groups for
     * any other entity pass through untouched.
     */
    public static class CorefFilterComponent extends FilterComponent {
        /**
         * Keeps, for every slot value, all groups whose entity is not one of the entities found in
         * {@code list2}, plus the top-scoring group(s) among those whose entity is.
         *
         * <p>The winners are collected per slot value and appended once all competitors have been
         * seen, so the outcome does not depend on the order of {@code list} (previously a
         * lower-scoring group that happened to precede a higher-scoring one was kept as well).
         *
         * @param list  all candidate sentence groups
         * @param map   predicted label per sentence group (not used by this component)
         * @param map2  predicted score per sentence group; must contain every competing group
         * @param list2 sentence groups whose entities define which groups have to compete
         * @return the surviving sentence groups
         * @throws NullPointerException if a competing group has no score in {@code map2}
         */
        @Override // edu.stanford.nlp.kbp.slotfilling.process.RelationFilter.FilterComponent
        public List<SentenceGroup> filter(List<SentenceGroup> list, Map<SentenceGroup, String> map, Map<SentenceGroup, Double> map2, List<SentenceGroup> list2) {
            List<SentenceGroup> kept = new ArrayList<>();

            // Entities of the entity-specific groups; only groups for these entities compete.
            HashSet<Object> competingEntities = new HashSet<>();
            for (SentenceGroup entityGroup : list2) {
                competingEntities.add(entityGroup.key.getEntity());
            }

            for (List<SentenceGroup> sameSlotValue : groupBySlotValue(list).values()) {
                List<SentenceGroup> winners = new ArrayList<>();
                double bestScore = Double.NEGATIVE_INFINITY;
                for (SentenceGroup candidate : sameSlotValue) {
                    if (!competingEntities.contains(candidate.key.getEntity())) {
                        // Not one of the competing entities: always survives.
                        kept.add(candidate);
                        continue;
                    }
                    double score = map2.get(candidate).doubleValue();
                    if (score == bestScore) {
                        winners.add(candidate);
                    } else if (score > bestScore) {
                        // Strictly better: everything collected so far loses.
                        bestScore = score;
                        winners = new ArrayList<>();
                        winners.add(candidate);
                    }
                }
                // Flush only after every competitor for this slot value has been compared.
                kept.addAll(winners);
            }

            if (Props.KBP_VERBOSE) {
                RelationFilter.logger.debug(new Object[]{"Entered CorefFilterComponent with " + list.size() + " sentence groups."});
                RelationFilter.logger.debug(new Object[]{"Exiting CorefFilterComponent with " + kept.size() + " sentence groups."});
                if (list.size() == kept.size()) {
                    RelationFilter.logger.debug(new Object[]{"CorefFilterComponent had no effect."});
                } else {
                    RelationFilter.logger.debug(new Object[]{"CorefFilterComponent caused reduction."});
                }
            }
            return kept;
        }
    }

    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/process/RelationFilter$CrossRelTypeCompetitionFilterComponent.class */
    /**
     * Filter component in which all sentence groups that share a slot value compete with each
     * other, whatever their predicted label; only the top-scoring group(s) survive.
     */
    public static class CrossRelTypeCompetitionFilterComponent extends FilterComponent {
        /**
         * For each slot value, retains the group(s) with the maximal score in {@code map2}
         * (ties are all retained).
         *
         * @param list  all candidate sentence groups
         * @param map   predicted label per sentence group (not used by this component)
         * @param map2  predicted score per sentence group; must contain every group in {@code list}
         * @param list2 entity-specific sentence groups (not used by this component)
         * @return the surviving sentence groups
         */
        @Override // edu.stanford.nlp.kbp.slotfilling.process.RelationFilter.FilterComponent
        public List<SentenceGroup> filter(List<SentenceGroup> list, Map<SentenceGroup, String> map, Map<SentenceGroup, Double> map2, List<SentenceGroup> list2) {
            List<SentenceGroup> survivors = new ArrayList<>();
            for (List<SentenceGroup> competitors : groupBySlotValue(list).values()) {
                survivors.addAll(topScoring(competitors, map2));
            }

            if (Props.KBP_VERBOSE) {
                int sizeBefore = list.size();
                int sizeAfter = survivors.size();
                logger.debug("Entered CrossRelTypeCompetitionFilterComponent with " + sizeBefore + " sentence groups.");
                logger.debug("Exiting CrossRelTypeCompetitionFilterComponent with " + sizeAfter + " sentence groups.");
                String outcome;
                if (sizeBefore == sizeAfter) {
                    outcome = "CrossRelTypeCompetitionFilterComponent had no effect.";
                } else if (sizeAfter < sizeBefore) {
                    outcome = "CrossRelTypeCompetitionFilterComponent caused reduction.";
                } else {
                    outcome = "WARNING: CrossRelTypeCompetitionFilterComponent caused increase.";
                }
                logger.debug(outcome);
            }
            return survivors;
        }

        /** Returns the members of {@code competitors} whose score equals the maximum score seen. */
        private static List<SentenceGroup> topScoring(List<SentenceGroup> competitors, Map<SentenceGroup, Double> scores) {
            List<SentenceGroup> best = new ArrayList<>();
            double bestScore = Double.NEGATIVE_INFINITY;
            for (SentenceGroup candidate : competitors) {
                double score = scores.get(candidate).doubleValue();
                if (score > bestScore) {
                    // New maximum: discard the previous leaders.
                    best = new ArrayList<>();
                    best.add(candidate);
                    bestScore = score;
                } else if (score == bestScore) {
                    best.add(candidate);
                }
            }
            return best;
        }
    }

    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/process/RelationFilter$FilterComponent.class */
    /** Base class of the individual filtering steps that a {@code RelationFilter} chains together. */
    public static abstract class FilterComponent {
        /**
         * Selects the sentence groups that survive this filtering step.
         *
         * @param list  candidate sentence groups to filter
         * @param map   predicted label per sentence group
         * @param map2  predicted score per sentence group
         * @param list2 sentence groups of the entity the filter is being applied for
         * @return the sentence groups retained by this component
         */
        public abstract List<SentenceGroup> filter(List<SentenceGroup> list, Map<SentenceGroup, String> map, Map<SentenceGroup, Double> map2, List<SentenceGroup> list2);

        /**
         * Buckets sentence groups by {@code key.slotValue}, preserving the input order inside
         * each bucket.
         *
         * @param list the sentence groups to bucket
         * @return a new map from slot value to the groups carrying that slot value
         */
        protected static Map<String, List<SentenceGroup>> groupBySlotValue(List<SentenceGroup> list) {
            Map<String, List<SentenceGroup>> bySlotValue = new HashMap<>();
            for (SentenceGroup group : list) {
                bySlotValue.computeIfAbsent(group.key.slotValue, unused -> new ArrayList<>()).add(group);
            }
            return bySlotValue;
        }
    }

    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/process/RelationFilter$PerRelTypeCompetitionFilterComponent.class */
    /**
     * Filter component in which sentence groups compete only against groups that share both
     * their slot value and their predicted label; the top-scoring group(s) of each such bucket
     * survive.
     */
    public static class PerRelTypeCompetitionFilterComponent extends FilterComponent {
        /**
         * Buckets the groups by slot value and then by predicted label, and keeps from every
         * bucket the group(s) with the maximal score (ties are all kept).
         *
         * @param list  all candidate sentence groups
         * @param map   predicted label per sentence group
         * @param map2  predicted score per sentence group; must contain every group in {@code list}
         * @param list2 entity-specific sentence groups (not used by this component)
         * @return the surviving sentence groups
         */
        @Override // edu.stanford.nlp.kbp.slotfilling.process.RelationFilter.FilterComponent
        public List<SentenceGroup> filter(List<SentenceGroup> list, Map<SentenceGroup, String> map, Map<SentenceGroup, Double> map2, List<SentenceGroup> list2) {
            List<SentenceGroup> survivors = new ArrayList<>();
            for (List<SentenceGroup> sameSlotValue : groupBySlotValue(list).values()) {
                // Second-level bucketing: predicted label within one slot value.
                Map<String, List<SentenceGroup>> byLabel = new HashMap<>();
                for (SentenceGroup group : sameSlotValue) {
                    byLabel.computeIfAbsent(map.get(group), unused -> new ArrayList<>()).add(group);
                }

                for (List<SentenceGroup> sameLabel : byLabel.values()) {
                    List<SentenceGroup> best = new ArrayList<>();
                    double bestScore = Double.NEGATIVE_INFINITY;
                    for (SentenceGroup candidate : sameLabel) {
                        double score = map2.get(candidate).doubleValue();
                        if (score > bestScore) {
                            // New maximum: discard the previous leaders.
                            best = new ArrayList<>();
                            best.add(candidate);
                            bestScore = score;
                        } else if (score == bestScore) {
                            best.add(candidate);
                        }
                    }
                    survivors.addAll(best);
                }
            }

            if (Props.KBP_VERBOSE) {
                logger.debug("Entered PerRelTypeCompetitionFilterComponent with " + list.size() + " sentence groups.");
                logger.debug("Exiting PerRelTypeCompetitionFilterComponent with " + survivors.size() + " sentence groups.");
                boolean unchanged = list.size() == survivors.size();
                logger.debug(unchanged
                        ? "PerRelTypeCompetitionFilterComponent had no effect."
                        : "PerRelTypeCompetitionFilterComponent caused reduction.");
            }
            return survivors;
        }
    }

    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/process/RelationFilter$RelationFilterBuilder.class */
    /**
     * Builder for {@code RelationFilter}: takes the classifier up front, collects filter
     * components, and creates the filter via {@link #make()}.
     */
    public static class RelationFilterBuilder {
        /** Scoring function handed to the filter that {@link #make()} creates. */
        private Function<Pair<SentenceGroup, Maybe<CoreMap[]>>, Counter<String>> classifier;
        /** Components added so far, in the order they were added. */
        final List<FilterComponent> filterComponents = new ArrayList<>();

        /**
         * @param function the classifier to use; asserted to be non-null when assertions are
         *                 enabled, and checked again in {@link #make()}
         */
        public RelationFilterBuilder(Function<Pair<SentenceGroup, Maybe<CoreMap[]>>, Counter<String>> function) {
            assert function != null;
            this.classifier = function;
        }

        /**
         * Adds a filter component identified by a short name or by a class name.
         *
         * <p>The short names {@code "coref"}, {@code "perRelTypeCompetition"} and
         * {@code "crossRelTypeCompetition"} map to the corresponding nested component classes;
         * any other value is passed to {@link Class#forName(String)} as is.
         *
         * @param str short name or class name of the component
         * @throws RuntimeException if the class cannot be found or is not a {@code FilterComponent}
         * @throws NullPointerException if {@code str} is null
         */
        public void addFilterComponentByName(String str) {
            String className;
            switch (str) {
                case "coref":
                    className = CorefFilterComponent.class.getName();
                    break;
                case "perRelTypeCompetition":
                    className = PerRelTypeCompetitionFilterComponent.class.getName();
                    break;
                case "crossRelTypeCompetition":
                    className = CrossRelTypeCompetitionFilterComponent.class.getName();
                    break;
                default:
                    className = str;
                    break;
            }
            try {
                // asSubclass fails fast with ClassCastException for classes that are not filter
                // components, instead of letting a foreign object slip into the list.
                Class<? extends FilterComponent> resolved = Class.forName(className).asSubclass(FilterComponent.class);
                @SuppressWarnings("unchecked") // only used to instantiate; any subclass is acceptable
                Class<FilterComponent> componentClass = (Class<FilterComponent>) resolved;
                addFilterComponent(componentClass);
            } catch (ClassCastException e) {
                throw new RuntimeException("[RelationFilterBuilder.addFilter] Not a relation filter: " + className, e);
            } catch (ClassNotFoundException e) {
                throw new RuntimeException("[RelationFilterBuilder.addFilter] Unknown name: " + className, e);
            }
        }

        /**
         * Instantiates the given component class through {@code MetaClass} with no constructor
         * arguments and appends the instance to the component list.
         *
         * @param cls the component class; asserted to be non-null when assertions are enabled
         */
        public void addFilterComponent(Class<FilterComponent> cls) {
            assert cls != null;
            FilterComponent component = MetaClass.create(cls).createInstance();
            this.filterComponents.add(component);
        }

        /**
         * Creates the filter from the classifier and the components added so far.
         *
         * @return a new {@code RelationFilter}
         * @throws RuntimeException if no classifier is set or no component has been added
         */
        public RelationFilter make() {
            if (this.classifier == null) {
                throw new RuntimeException("[RelationFilterBuilder.make()] Classifier not set.");
            }
            if (this.filterComponents.size() == 0) {
                throw new RuntimeException("[RelationFilterBuilder.make()] Must add at least one filter component.");
            }
            return new RelationFilter(this.classifier, this.filterComponents);
        }
    }

    /**
     * Stores the classifier and the component list as given (the list is not copied).
     * Private: instances are created through {@code RelationFilterBuilder.make()}.
     *
     * @param function the classifier used to predict a label and score per sentence group
     * @param list     the filter components to run, in order
     */
    private RelationFilter(Function<Pair<SentenceGroup, Maybe<CoreMap[]>>, Counter<String>> function, List<FilterComponent> list) {
        this.classifier = function;
        this.filterComponents = list;
    }

    /**
     * Runs the classifier over all candidate sentence groups, passes them through every filter
     * component in order, and returns the entity-specific groups that survived.
     *
     * <p>The Redwood track opened at the start is closed in a {@code finally} block, so it is
     * also ended when prediction or a filter component throws.
     *
     * @param list    sentence groups for the entity; the result is the subset of these (in their
     *                original order) that is still present after filtering
     * @param list2   sentence groups for all pairs; these are what the components filter
     * @param coreMap the sentence handed to the classifier together with each group
     * @return the members of {@code list} that survived all filter components
     */
    public List<SentenceGroup> apply(List<SentenceGroup> list, List<SentenceGroup> list2, CoreMap coreMap) {
        Redwood.startTrack("Applying filter...");
        try {
            if (Props.KBP_VERBOSE) {
                logger.debug("Applying filter...\n\tnSentenceGroupsForEntity = " + list.size() + "\n\tnSentenceGroupsForAllPairs = " + list2.size() + "\n\tsentence = " + CoreMapUtils.sentenceToMinimalString(coreMap));
            }

            Pair<Map<SentenceGroup, String>, Map<SentenceGroup, Double>> predictions = predictLabels(list2, coreMap);
            Map<SentenceGroup, String> labels = predictions.first();
            Map<SentenceGroup, Double> scores = predictions.second();
            if (Props.KBP_VERBOSE) {
                logger.debug("Label predictions...");
                for (SentenceGroup group : list2) {
                    logger.debug("\tsentenceGroupKey = " + group.key + "\n\t\tprediction = " + labels.get(group) + "\n\t\tscore = " + scores.get(group));
                }
            }

            // Each component sees the output of the previous one.
            List<SentenceGroup> surviving = list2;
            for (FilterComponent component : this.filterComponents) {
                surviving = component.filter(surviving, labels, scores, list);
            }

            // Keep the entity-specific groups that are still among the survivors.
            HashSet<SentenceGroup> survivingSet = new HashSet<>(surviving);
            List<SentenceGroup> filtered = new ArrayList<>();
            for (SentenceGroup group : list) {
                if (survivingSet.contains(group)) {
                    filtered.add(group);
                }
            }

            if (Props.KBP_VERBOSE) {
                logger.debug("SentenceGroupsForEntity...");
                logger.debug("\tBefore Filtering:");
                for (SentenceGroup group : list) {
                    logger.debug("\t\tkey = " + group.key);
                }
                logger.debug("\t\tTotal: " + list.size());
                logger.debug("\tAfter Filtering:");
                for (SentenceGroup group : filtered) {
                    logger.debug("\t\tkey = " + group.key);
                }
                logger.debug("\t\tTotal: " + filtered.size());
                if (list.size() == filtered.size()) {
                    logger.debug("RelationFilter had no effect.");
                } else {
                    logger.debug("RelationFilter caused reduction.");
                }
            }
            return filtered;
        } finally {
            Redwood.endTrack("Applying filter...");
        }
    }

    /**
     * Classifies every sentence group and records its hard prediction together with the count
     * the classifier assigned to that prediction.
     *
     * @param list    the sentence groups to classify
     * @param coreMap the sentence passed to the classifier alongside each group
     * @return a pair of maps keyed by sentence group: predicted label, and that label's count
     */
    private Pair<Map<SentenceGroup, String>, Map<SentenceGroup, Double>> predictLabels(List<SentenceGroup> list, CoreMap coreMap) {
        Map<SentenceGroup, String> labelByGroup = new HashMap<>();
        Map<SentenceGroup, Double> scoreByGroup = new HashMap<>();
        for (SentenceGroup group : list) {
            Maybe<CoreMap[]> sentence = Maybe.Just(new CoreMap[]{coreMap});
            Counter<String> distribution = this.classifier.apply(Pair.makePair(group, sentence));
            String label = getHardPrediction(distribution, group.key.entityType, group.key.slotType);
            labelByGroup.put(group, label);
            scoreByGroup.put(group, Double.valueOf(distribution.getCount(label)));
        }
        return Pair.makePair(labelByGroup, scoreByGroup);
    }

    /**
     * Picks the first label, in the order produced by {@code Counters.toSortedList}, that is
     * compatible with the given entity type and, when present, slot type.
     *
     * <p>A label is compatible when its {@code RelationType} has exactly {@code nERTag} as entity
     * type and every value held by {@code maybe} is among its {@code validNamedEntityLabels}.
     * Labels that cannot be resolved to a {@code RelationType} trigger {@code orCrash()}.
     *
     * @param counter the classifier output; its keys are candidate labels
     * @param nERTag  the entity type the relation must have
     * @param maybe   the slot type, if known
     * @return the first compatible label, or {@code "_NR"} when there is none
     */
    private String getHardPrediction(Counter<String> counter, NERTag nERTag, Maybe<NERTag> maybe) {
        List<String> rankedLabels = Counters.toSortedList(counter);
        for (String candidate : rankedLabels) {
            RelationType relation = RelationType.fromString(candidate).orCrash();
            if (nERTag != relation.entityType) {
                continue;
            }
            boolean slotTypeAllowed = true;
            for (NERTag slotType : maybe) {
                if (!relation.validNamedEntityLabels.contains(slotType)) {
                    slotTypeAllowed = false;
                }
            }
            if (slotTypeAllowed) {
                return candidate;
            }
        }
        return "_NR";
    }
}
