package edu.stanford.nlp.kbp.slotfilling.evaluate;

import edu.stanford.nlp.kbp.common.CollectionUtils;
import edu.stanford.nlp.kbp.common.KBPEntity;
import edu.stanford.nlp.kbp.common.KBPSlotFill;
import edu.stanford.nlp.kbp.common.Maybe;
import edu.stanford.nlp.kbp.common.Props;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.IdentityHashSet;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.function.Function;

/* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/evaluate/HeuristicSlotfillPostProcessor.class */
/**
 * A {@link SlotfillPostProcessor} that selects a consistent subset of slot fills using three kinds
 * of heuristics supplied by subclasses:
 *
 * <ul>
 *   <li>{@link #isValidSlotAndRewrite}: a per-fill check that may reject a fill or rewrite it;
 *   <li>{@link #pairwiseKeepLowerScoringFill}: a compatibility check on every pair of kept fills;
 *   <li>{@link #leaveOneOutKeepHeldOutSlot}: a check of each kept fill against all other kept fills.
 * </ul>
 *
 * <p>Fills are first validated/rewritten individually, then sorted and enabled greedily, each one only
 * if the resulting active set is still consistent. If {@code Props.TEST_CONSISTENCY_GIBBSOBJECTIVE}
 * is not {@code TOP}, the greedy pass is repeated over random orderings and the active set with the
 * highest objective value is kept.
 */
public abstract class HeuristicSlotfillPostProcessor extends SlotfillPostProcessor {

    /** A heuristic post-processor that accepts every fill unchanged and never rejects a combination. */
    public static class Default extends HeuristicSlotfillPostProcessor {
        @Override
        public Maybe<KBPSlotFill> isValidSlotAndRewrite(KBPEntity entity, KBPSlotFill candidate) {
            return Maybe.Just(candidate);
        }

        @Override
        public boolean pairwiseKeepLowerScoringFill(KBPEntity entity, KBPSlotFill first, KBPSlotFill second) {
            return true;
        }

        @Override
        public boolean leaveOneOutKeepHeldOutSlot(KBPEntity entity, IdentityHashSet<KBPSlotFill> others, KBPSlotFill heldOut) {
            return true;
        }
    }

    /**
     * A temporary modification of an activation array that must be undone exactly once.
     *
     * <p>On construction the entries at {@code toDeactivate}, {@code toActivate} and {@code toActivate2}
     * are saved, and then set to {@code false}, {@code true} and {@code true}, in that order. When all
     * three indices are equal, the entry therefore ends up {@code true}. {@link #restoreAndReturn}
     * writes the saved values back.
     */
    private static final class GibbsState {
        private final boolean[] slotsActive;
        private final boolean savedDeactivatedValue;
        private final boolean savedActivatedValue;
        private final boolean savedActivated2Value;
        private final int toDeactivate;
        private final int toActivate;
        private final int toActivate2;
        private boolean isRestored = false;

        GibbsState(boolean[] slotsActive, int toDeactivate, int toActivate, int toActivate2) {
            this.slotsActive = slotsActive;
            this.toDeactivate = toDeactivate;
            this.toActivate = toActivate;
            this.toActivate2 = toActivate2;
            this.savedDeactivatedValue = slotsActive[toDeactivate];
            this.savedActivatedValue = slotsActive[toActivate];
            this.savedActivated2Value = slotsActive[toActivate2];
            slotsActive[toDeactivate] = false;
            slotsActive[toActivate] = true;
            slotsActive[toActivate2] = true;
        }

        /**
         * Restores the saved entries of the activation array, then returns {@code value} unchanged.
         *
         * @throws IllegalStateException if this state has already been restored
         */
        <E> E restoreAndReturn(E value) {
            if (isRestored) {
                throw new IllegalStateException("Using a Gibbs state twice!");
            }
            slotsActive[toDeactivate] = savedDeactivatedValue;
            slotsActive[toActivate] = savedActivatedValue;
            slotsActive[toActivate2] = savedActivated2Value;
            isRestored = true;
            return value;
        }
    }

    /**
     * Checks whether the tentative assignment described by {@code state} is consistent. The activation
     * array is restored to its prior contents afterwards.
     */
    private boolean blockGibbsCanTransition(KBPEntity entity, KBPSlotFill[] fills, GibbsState state) {
        try {
            return isConsistent(entity, fills, state.slotsActive);
        } finally {
            // Always undo the tentative change, even if a subclass heuristic throws.
            state.restoreAndReturn(null);
        }
    }

    /**
     * Returns {@code true} if the fills flagged in {@code active} pass all three heuristic checks:
     * individual validity, pairwise compatibility for every pair (i &lt; j), and leave-one-out against
     * the remaining active fills.
     */
    private boolean isConsistent(KBPEntity entity, KBPSlotFill[] fills, boolean[] active) {
        // 1. Every active fill must be individually valid.
        for (int i = 0; i < fills.length; i++) {
            if (active[i] && !isValidSlotAndRewrite(entity, fills[i]).isDefined()) {
                return false;
            }
        }
        // 2. Every ordered pair (earlier, later) of active fills must be compatible.
        for (int i = 0; i < fills.length; i++) {
            if (!active[i]) {
                continue;
            }
            for (int j = i + 1; j < fills.length; j++) {
                if (active[j] && !pairwiseKeepLowerScoringFill(entity, fills[i], fills[j])) {
                    return false;
                }
            }
        }
        // 3. Every active fill must survive a check against all the *other* active fills.
        IdentityHashSet<KBPSlotFill> activeFills = new IdentityHashSet<>();
        for (int i = 0; i < fills.length; i++) {
            if (active[i]) {
                activeFills.add(fills[i]);
            }
        }
        for (int i = 0; i < fills.length; i++) {
            if (!active[i]) {
                continue;
            }
            activeFills.remove(fills[i]);
            if (!leaveOneOutKeepHeldOutSlot(entity, activeFills, fills[i])) {
                return false;
            }
            activeFills.add(fills[i]);
        }
        return true;
    }

    /**
     * Visits the fills in array order and enables each one if the active set stays consistent with it.
     *
     * @return the number of fills that were enabled
     */
    private int greedyEnableSlotsInPlace(KBPEntity entity, KBPSlotFill[] fills, boolean[] active) {
        int enabled = 0;
        for (int i = 0; i < fills.length; i++) {
            if (blockGibbsCanTransition(entity, fills, new GibbsState(active, i, i, i))) {
                active[i] = true;
                enabled++;
            } else {
                assert !active[i];
            }
        }
        return enabled;
    }

    /**
     * Runs one round of filtering. Each fill is validated and possibly rewritten, then a consistent
     * subset is chosen. Fills that are not kept are reported to {@code goldResponses} as inconsistent.
     */
    private List<KBPSlotFill> filterStep(KBPEntity entity, List<KBPSlotFill> fills, GoldResponseSet goldResponses) {
        KBPSlotFill[] candidates = rewriteValidFills(entity, fills, goldResponses);
        // NOTE(review): the greedy pass favours fills that sort first; presumably KBPSlotFill's natural
        // order is by descending score -- confirm.
        Arrays.sort(candidates);

        boolean[] active = new boolean[candidates.length];
        greedyEnableSlotsInPlace(entity, candidates, active);
        assert isConsistent(entity, candidates, active);

        if (Props.TEST_CONSISTENCY_GIBBSOBJECTIVE != Props.GibbsObjective.TOP) {
            Function<Pair<boolean[], KBPSlotFill[]>, Double> objective =
                    getObjective(Props.TEST_CONSISTENCY_GIBBSOBJECTIVE);
            active = searchForBetterAssignment(entity, candidates, active, objective);
        }
        assert isConsistent(entity, candidates, active);

        List<KBPSlotFill> kept = new ArrayList<>();
        for (int i = 0; i < candidates.length; i++) {
            if (active[i]) {
                kept.add(candidates[i]);
            } else {
                goldResponses.discardInconsistent(candidates[i]);
            }
        }
        return kept;
    }

    /**
     * Applies {@link #isValidSlotAndRewrite} to every fill. Rejected fills are reported to
     * {@code goldResponses} as inconsistent, and rewritten fills are recorded there as well.
     *
     * @return the surviving fills: rewritten if {@code Props.TEST_CONSISTENCY_REWRITE} is set,
     *     otherwise the originals
     */
    private KBPSlotFill[] rewriteValidFills(KBPEntity entity, List<KBPSlotFill> fills, GoldResponseSet goldResponses) {
        List<KBPSlotFill> candidates = new ArrayList<>();
        for (KBPSlotFill fill : fills) {
            Maybe<KBPSlotFill> rewritten = isValidSlotAndRewrite(entity, fill);
            Iterator<KBPSlotFill> it = rewritten.iterator();
            while (it.hasNext()) {
                KBPSlotFill next = it.next();
                if (!next.equals(fill)) {
                    // NOTE(review): the gold set is told about the rewrite even when
                    // TEST_CONSISTENCY_REWRITE is off and the original fill is kept -- confirm intent.
                    goldResponses.discardRewritten(fill);
                    goldResponses.registerResponse(next);
                }
                candidates.add(Props.TEST_CONSISTENCY_REWRITE ? next : fill);
            }
            if (!rewritten.isDefined()) {
                goldResponses.discardInconsistent(fill);
            }
        }
        return candidates.toArray(new KBPSlotFill[candidates.size()]);
    }

    /**
     * Re-runs the greedy enabling pass {@code Props.TEST_CONSISTENCY_MIXINGTIME} times, each time over
     * a freshly shuffled order (fixed seed 42). The assignment with the strictly highest objective
     * value wins.
     *
     * @param active the starting (greedy) assignment; it is overwritten as scratch space
     * @return a copy of the best assignment found (the starting one if nothing scored higher)
     */
    private boolean[] searchForBetterAssignment(KBPEntity entity, KBPSlotFill[] fills, boolean[] active,
                                                Function<Pair<boolean[], KBPSlotFill[]>, Double> objective) {
        boolean[] best = active.clone();
        double bestScore = objective.apply(Pair.makePair(active, fills));
        Random random = new Random(42L);
        int[] order = CollectionUtils.seq(active.length);
        Redwood.Util.log("initial objective: " + bestScore);
        for (int iter = 0; iter < Props.TEST_CONSISTENCY_MIXINGTIME; iter++) {
            Arrays.fill(active, false);
            ArrayMath.shuffle(order, random);
            for (int i : order) {
                if (blockGibbsCanTransition(entity, fills, new GibbsState(active, i, i, i))) {
                    active[i] = true;
                }
            }
            double score = objective.apply(Pair.makePair(active, fills));
            if (score > bestScore) {
                bestScore = score;
                System.arraycopy(active, 0, best, 0, active.length);
                Redwood.Util.log("found higher objective: " + bestScore);
            }
        }
        return best;
    }

    /**
     * Returns the objective function for the given type.
     *
     * <p>{@code SUM} adds up the scores of the active fills, counting a missing score as 0. The function
     * returned for {@code TOP} throws if it is ever applied.
     *
     * @throws IllegalArgumentException for objective types that are not implemented
     */
    private Function<Pair<boolean[], KBPSlotFill[]>, Double> getObjective(Props.GibbsObjective gibbsObjective) {
        switch (gibbsObjective) {
            case TOP:
                return pair -> {
                    throw new IllegalStateException("No well defined objective for GibbsObjective.TOP");
                };
            case SUM:
                return pair -> {
                    boolean[] active = pair.first;
                    KBPSlotFill[] fills = pair.second;
                    double total = 0.0d;
                    for (int i = 0; i < active.length; i++) {
                        if (active[i]) {
                            total += fills[i].score.getOrElse(0.0d);
                        }
                    }
                    return total;
                };
            default:
                throw new IllegalArgumentException("Objective type not implemented: " + gibbsObjective);
        }
    }

    /**
     * Combines this processor with another. If {@code other} is also heuristic, the result accepts a
     * fill or combination only when both processors do. Rewrites are chained: this processor's rewrite
     * is fed to {@code other}. Otherwise the call is deferred to {@link SlotfillPostProcessor#and}.
     */
    @Override
    public SlotfillPostProcessor and(SlotfillPostProcessor other) {
        if (!(other instanceof HeuristicSlotfillPostProcessor)) {
            return super.and(other);
        }
        // Capture both operands explicitly. Inside the anonymous subclass below, a bare `this` refers
        // to the combined processor itself, so delegating through it would recurse forever.
        final HeuristicSlotfillPostProcessor first = this;
        final HeuristicSlotfillPostProcessor second = (HeuristicSlotfillPostProcessor) other;
        return new HeuristicSlotfillPostProcessor() {
            @Override
            public Maybe<KBPSlotFill> isValidSlotAndRewrite(KBPEntity entity, KBPSlotFill candidate) {
                Maybe<KBPSlotFill> afterFirst = first.isValidSlotAndRewrite(entity, candidate);
                if (!afterFirst.isDefined()) {
                    return Maybe.Nothing();
                }
                Maybe<KBPSlotFill> afterSecond = second.isValidSlotAndRewrite(entity, afterFirst.get());
                return afterSecond.isDefined() ? afterSecond : Maybe.Nothing();
            }

            @Override
            public boolean pairwiseKeepLowerScoringFill(KBPEntity entity, KBPSlotFill higher, KBPSlotFill lower) {
                return first.pairwiseKeepLowerScoringFill(entity, higher, lower)
                        && second.pairwiseKeepLowerScoringFill(entity, higher, lower);
            }

            @Override
            public boolean leaveOneOutKeepHeldOutSlot(KBPEntity entity, IdentityHashSet<KBPSlotFill> others, KBPSlotFill heldOut) {
                return first.leaveOneOutKeepHeldOutSlot(entity, others, heldOut)
                        && second.leaveOneOutKeepHeldOutSlot(entity, others, heldOut);
            }
        };
    }

    /**
     * Runs two rounds of {@link #filterStep}, with the second round applied to the output of the first.
     * Presumably this lets rewritten fills be re-checked -- TODO confirm.
     */
    @Override
    public List<KBPSlotFill> postProcess(KBPEntity kBPEntity, List<KBPSlotFill> list, GoldResponseSet goldResponseSet) {
        return filterStep(kBPEntity, filterStep(kBPEntity, list, goldResponseSet), goldResponseSet);
    }

    /**
     * Checks a single fill in isolation.
     *
     * @return the fill to keep (possibly rewritten), or {@code Nothing} to reject it
     */
    public abstract Maybe<KBPSlotFill> isValidSlotAndRewrite(KBPEntity kBPEntity, KBPSlotFill kBPSlotFill);

    /**
     * Checks whether two active fills may coexist. It is called for each pair in sorted order, with the
     * earlier fill first. Presumably the second argument is the lower-scoring fill -- confirm against
     * KBPSlotFill's ordering.
     *
     * @return {@code true} to keep both fills
     */
    public abstract boolean pairwiseKeepLowerScoringFill(KBPEntity kBPEntity, KBPSlotFill kBPSlotFill, KBPSlotFill kBPSlotFill2);

    /**
     * Checks one active fill against the set of all other active fills. The set is compared by identity
     * and is temporarily mutated by the caller, so implementations should not retain it.
     *
     * @return {@code true} to keep the held-out fill
     */
    public abstract boolean leaveOneOutKeepHeldOutSlot(KBPEntity kBPEntity, IdentityHashSet<KBPSlotFill> identityHashSet, KBPSlotFill kBPSlotFill);
}
