package edu.stanford.nlp.kbp.slotfilling.evaluate.inference;

import edu.stanford.nlp.kbp.common.Pointer;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.ArrayIterable;
import edu.stanford.nlp.util.Execution;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.logging.Redwood;

import java.util.AbstractSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/evaluate/inference/BayesNet.class */
/**
 * A Bayes net over boolean predicates, exposed as the set of its {@link Factor}s.
 *
 * <p>Some predicates may be fixed to given values; the remaining ("adjustable") predicates are
 * inferred with Gibbs sampling, either as a MAP estimate ({@link #gibbsMAP(int)}) or as
 * per-predicate marginals ({@link #gibbsMarginals(int)}).
 *
 * @param <E> the predicate type
 */
public class BayesNet<E> extends AbstractSet<BayesNet.Factor> {
    /** Prior probability of being true for predicates with no explicit prior. */
    private static final double DEFAULT_PRIOR = 0.2;

    /** The predicates; array position {@code i} is predicate index {@code i}. */
    protected final E[] predicates;
    protected final Factor[] factors;
    /** For each predicate index, the factors whose {@link Factor#components()} contain it. */
    protected final Collection<Factor>[] factorsByPredicate;
    /** Stored for callers/subclasses; not read within this class. */
    protected final Index<E> index;
    /** Fixed predicate values keyed by predicate index; these are never resampled. */
    private final Map<Integer, Boolean> initialValues;
    private final boolean[] isFixed;
    /** Per-predicate probability of being true, used to draw random restarts. */
    private final double[] priors;
    /** Indices of the predicates that are not fixed. */
    protected final int[] adjustable;
    private final boolean doHillclimb;

    /**
     * Mutable state of a single Gibbs-sampling chain. Contains no synchronization; the samplers
     * below give each worker its own instance.
     */
    private class AssignmentState {
        /** Period, in iterations, of random restarts. */
        private static final long RESTART_ITERS = 100_000L;
        /** Period, in iterations, of the MAP consistency check and hill-climbing sweep. */
        private static final long HILLCLIMB_ITERS = 10_000L;

        public boolean doMAP;
        public boolean doMarginal;
        private final Random rand;
        public boolean[] assignment;
        /** counts[i] is the mean value of predicate i over sampled iterations 1..lastUpdate[i]. */
        public double[] counts;
        /** Iteration through which counts[i] has been accumulated. */
        public long[] lastUpdate;
        /** Total log score of {@link #assignment}, maintained incrementally. */
        public double logScore;
        public long numIters = 0L;
        public boolean[] bestAssignment;
        public double bestLogScore = Double.NEGATIVE_INFINITY;

        /**
         * Creates a chain starting from the fixed values plus a random restart of the rest.
         *
         * @param seed seed for this chain's random number generator
         */
        public AssignmentState(int seed) {
            this.rand = new Random(seed);
            this.assignment = new boolean[predicates.length];
            for (Map.Entry<Integer, Boolean> fixed : initialValues.entrySet()) {
                assignment[fixed.getKey()] = fixed.getValue();
            }
            this.counts = new double[predicates.length];
            this.lastUpdate = new long[predicates.length];
            this.bestAssignment = new boolean[assignment.length];
            randomRestart();  // also computes logScore
            assert !SloppyMath.isVeryDangerous(logScore) : logScore;
        }

        /** Redraws every non-fixed predicate from its prior and recomputes the score. */
        protected void randomRestart() {
            if (doMarginal) {
                // Credit current values to the marginals before they are overwritten.
                updateCounts();
            }
            for (int i = 0; i < assignment.length; i++) {
                if (!isFixed[i]) {
                    assignment[i] = rand.nextDouble() < priors[i];
                }
            }
            logScore = computeLogScore();
        }

        /** Brings every predicate's running mean up to date through the current iteration. */
        protected void updateCounts() {
            for (int i = 0; i < assignment.length; i++) {
                accumulate(i);
            }
        }

        /**
         * Folds predicate {@code var}'s current value into its running mean for iterations
         * (lastUpdate[var], numIters]. Its value has not changed over that span, because every
         * change is preceded by a call to this method.
         */
        private void accumulate(int var) {
            if (numIters == 0) {
                return;
            }
            double value = assignment[var] ? 1.0 : 0.0;
            counts[var] += (value - counts[var]) * (numIters - lastUpdate[var]) / numIters;
            lastUpdate[var] = numIters;
        }

        /**
         * Runs one Gibbs iteration: periodic restarts, periodic MAP hill-climbing, then one
         * resample of a uniformly chosen adjustable predicate.
         */
        public void gibbsStep() {
            if (numIters % RESTART_ITERS == 0) {
                randomRestart();
            }
            if (doMAP && numIters % HILLCLIMB_ITERS == 0) {
                assert Double.isInfinite(logScore) || Math.abs(logScore - computeLogScore()) < 0.1
                    : "incremental log score drifted: " + logScore;
                if (doHillclimb) {
                    if (doMarginal) {
                        updateCounts();  // the greedy sweep may change any predicate
                    }
                    for (int var : adjustable) {
                        gibbsStep(var, true);
                    }
                }
                assert Double.isInfinite(logScore) || Math.abs(logScore - computeLogScore()) < 0.1
                    : "incremental log score drifted: " + logScore;
                // Resynchronize to avoid accumulating floating-point error.
                logScore = computeLogScore();
            }
            int var = adjustable[rand.nextInt(adjustable.length)];
            if (doMarginal) {
                // Credit the value held so far before it may change.
                accumulate(var);
            }
            gibbsStep(var, false);
            numIters++;
        }

        /**
         * Resamples predicate {@code var} from its conditional given all other predicates.
         *
         * <p>Only factors touching {@code var} are evaluated, since all other factors cancel in
         * the conditional. This assumes a factor's logProb depends only on its components.
         *
         * @param var index of a non-fixed predicate
         * @param greedy if true, pick the more probable value instead of sampling
         */
        public void gibbsStep(int var, boolean greedy) {
            assert !isFixed[var] : "cannot resample fixed predicate " + var;
            boolean previous = assignment[var];
            double localIfTrue = 0.0;
            double localIfFalse = 0.0;
            for (Factor factor : factorsByPredicate[var]) {
                assignment[var] = true;
                localIfTrue += factor.logProb(assignment);
                assignment[var] = false;
                localIfFalse += factor.logProb(assignment);
            }

            // An infinite local score is presumably -inf, i.e. that value is impossible.
            double probTrue;
            if (Double.isInfinite(localIfTrue) && Double.isInfinite(localIfFalse)) {
                probTrue = 0.5;
            } else if (Double.isInfinite(localIfTrue)) {
                probTrue = 0.0;
            } else if (Double.isInfinite(localIfFalse)) {
                probTrue = 1.0;
            } else {
                probTrue = Math.exp(localIfTrue - SloppyMath.logAdd(localIfFalse, localIfTrue));
            }
            assert probTrue >= -0.001 && probTrue <= 1.001 : probTrue;
            if (probTrue < 0.0) {
                probTrue = 0.0;
            } else if (probTrue > 1.0) {
                probTrue = 1.0;
            }

            boolean value = greedy ? probTrue > 0.5 : rand.nextDouble() < probTrue;
            assignment[var] = value;
            if (value != previous) {
                double oldLocal = previous ? localIfTrue : localIfFalse;
                double newLocal = value ? localIfTrue : localIfFalse;
                // Subtracting an infinite old term would give NaN; recompute in that case.
                logScore = Double.isInfinite(oldLocal)
                    ? computeLogScore()
                    : logScore - oldLocal + newLocal;
            }

            if (doMAP && logScore > bestLogScore) {
                System.arraycopy(assignment, 0, bestAssignment, 0, assignment.length);
                bestLogScore = logScore;
            }
        }

        /** @return the sum of all factors' log-probabilities under the current assignment */
        public double computeLogScore() {
            return logScoreOf(assignment);
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append(logScore).append(" ");
            for (int i = 0; i < assignment.length; i++) {
                if (!assignment[i]) {
                    sb.append('!');
                }
                sb.append(predicates[i]).append(", ");
            }
            return sb.toString();
        }
    }

    /** A factor of the net: a log-probability over some of the predicates. */
    public interface Factor {
        /** @return a human-readable name for this factor */
        String getName();

        /**
         * @param assignment truth value of every predicate, indexed by predicate index
         * @return the log-probability this factor assigns to {@code assignment}
         */
        double logProb(boolean[] assignment);

        /** @return indices of the predicates this factor reads */
        Collection<Integer> components();
    }

    /**
     * Creates a Bayes net.
     *
     * @param index index over the predicates (stored; not read in this class)
     * @param predicates the predicates; position {@code i} is predicate index {@code i}
     * @param factors the factors of the net
     * @param logPriors log prior probability of being true, keyed by predicate index; predicates
     *     without an entry get prior probability 0.2
     * @param fixedValues fixed values keyed by predicate index; these are never resampled
     * @param doHillclimb whether MAP inference periodically runs a greedy sweep
     */
    @SuppressWarnings("unchecked")  // generic array creation for factorsByPredicate
    public BayesNet(Index<E> index, E[] predicates, Factor[] factors, Map<Integer, Double> logPriors,
                    Map<Integer, Boolean> fixedValues, boolean doHillclimb) {
        this.index = index;
        this.predicates = predicates;
        this.factors = factors;
        this.priors = new double[predicates.length];
        for (int i = 0; i < predicates.length; i++) {
            Double logPrior = logPriors.get(i);
            priors[i] = logPrior != null ? Math.exp(logPrior) : DEFAULT_PRIOR;
        }
        this.initialValues = fixedValues;
        this.isFixed = new boolean[predicates.length];
        for (int fixed : fixedValues.keySet()) {
            isFixed[fixed] = true;
        }
        this.adjustable = new int[predicates.length - fixedValues.size()];
        int next = 0;
        for (int i = 0; i < isFixed.length; i++) {
            if (!isFixed[i]) {
                adjustable[next++] = i;
            }
        }
        assert next == adjustable.length;
        this.doHillclimb = doHillclimb;
        this.factorsByPredicate = new Collection[predicates.length];
        for (int i = 0; i < factorsByPredicate.length; i++) {
            factorsByPredicate[i] = new ArrayList<>();
        }
        for (Factor factor : factors) {
            for (int component : factor.components()) {
                factorsByPredicate[component].add(factor);
            }
        }
    }

    @Override
    public Iterator<Factor> iterator() {
        return Arrays.asList(factors).iterator();
    }

    @Override
    public int size() {
        return factors.length;
    }

    /** @return the number of predicates (variables) in the net */
    public int variableCount() {
        return predicates.length;
    }

    /** Sums every factor's log-probability under {@code assignment}. */
    private double logScoreOf(boolean[] assignment) {
        double total = 0.0;
        for (Factor factor : factors) {
            total += factor.logProb(assignment);
        }
        return total;
    }

    /** Fallback result when nothing is adjustable: 1.0/0.0 for each fixed-true/false predicate. */
    private Counter<E> fixedValueCounts() {
        Redwood.Util.log("Warning: no-non-fixed predicates in BayesNet;");
        ClassicCounter<E> counts = new ClassicCounter<>();
        for (Map.Entry<Integer, Boolean> entry : initialValues.entrySet()) {
            counts.setCount(predicates[entry.getKey()], entry.getValue() ? 1.0 : 0.0);
        }
        return counts;
    }

    /**
     * @param iterations Gibbs iterations per sampling thread
     * @return predicates whose estimated marginal exceeds 0.5, with that marginal as count
     */
    public Counter<E> gibbsMLE(int iterations) {
        ClassicCounter<E> likely = new ClassicCounter<>();
        for (Map.Entry<E, Double> entry : gibbsMarginals(iterations).entrySet()) {
            double marginal = entry.getValue();
            if (marginal > 0.5) {
                likely.setCount(entry.getKey(), marginal);
            }
        }
        return likely;
    }

    /**
     * Approximates the MAP assignment with one Gibbs chain per {@code Execution.threads}.
     *
     * @param iterations Gibbs iterations per chain
     * @return predicates true in the best assignment found, each with count +infinity; or, if
     *     nothing is adjustable, the fixed values as 1.0/0.0 counts
     */
    public Counter<E> gibbsMAP(int iterations) {
        if (adjustable.length == 0) {
            return fixedValueCounts();
        }
        int numThreads = Execution.threads;
        boolean[] best = new boolean[predicates.length];
        double[] bestScore = {Double.NEGATIVE_INFINITY};  // guarded by `best`
        ArrayList<Runnable> tasks = new ArrayList<>(numThreads);
        for (int t = 0; t < numThreads; t++) {
            final int seed = t;
            tasks.add(() -> {
                AssignmentState state = new AssignmentState(seed);
                state.doMAP = true;
                for (int iter = 0; iter < iterations; iter++) {
                    state.gibbsStep();
                }
                synchronized (best) {
                    if (bestScore[0] < state.bestLogScore) {
                        bestScore[0] = state.bestLogScore;
                        System.arraycopy(state.bestAssignment, 0, best, 0, best.length);
                    }
                }
            });
        }
        Redwood.Util.threadAndRun("Gibbs Sampling", tasks, tasks.size());
        Redwood.Util.log("Best assignment had log score: " + bestScore[0]);
        ClassicCounter<E> result = new ClassicCounter<>();
        for (int i = 0; i < predicates.length; i++) {
            if (best[i]) {
                result.setCount(predicates[i], Double.POSITIVE_INFINITY);
            }
        }
        return result;
    }

    /**
     * Estimates each predicate's marginal probability of being true, averaged over one Gibbs
     * chain per {@code Execution.threads}.
     *
     * @param iterations Gibbs iterations per chain
     * @return estimated marginal for every predicate; or, if nothing is adjustable, the fixed
     *     values as 1.0/0.0 counts
     */
    public Counter<E> gibbsMarginals(int iterations) {
        if (adjustable.length == 0) {
            return fixedValueCounts();
        }
        int numThreads = Execution.threads;
        ClassicCounter<E> marginals = new ClassicCounter<>();
        ArrayList<Runnable> tasks = new ArrayList<>(numThreads);
        for (int t = 0; t < numThreads; t++) {
            final int seed = t;
            tasks.add(() -> {
                AssignmentState state = new AssignmentState(seed);
                state.doMarginal = true;
                for (int iter = 0; iter < iterations; iter++) {
                    state.gibbsStep();
                }
                state.updateCounts();  // credit the tail of the chain
                synchronized (marginals) {
                    for (int p = 0; p < predicates.length; p++) {
                        marginals.incrementCount(predicates[p], state.counts[p] / numThreads);
                    }
                }
            });
        }
        Redwood.Util.threadAndRun("Gibbs Sampling", tasks, tasks.size());
        return marginals;
    }

    /**
     * @param set the predicates that are true; all others are false
     * @return total log-probability of that assignment (negative infinity if it is impossible)
     */
    public double logProb(Set<E> set) {
        boolean[] assignment = new boolean[predicates.length];
        for (int i = 0; i < assignment.length; i++) {
            assignment[i] = set.contains(predicates[i]);
        }
        return logScoreOf(assignment);
    }
}
