package org.linqs.psl.reasoner.bool;

import org.linqs.psl.config.Config;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.reasoner.term.blocker.ConstraintBlockerTerm;
import org.linqs.psl.reasoner.term.blocker.ConstraintBlockerTermStore;
import org.linqs.psl.util.RandUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Marginal inference over Boolean random variables grouped into constraint blocks.
 *
 * <p>Each sweep visits every non-empty block and resamples its joint assignment from the
 * distribution proportional to {@code exp(-energy)}. Here energy is the weighted incompatibility
 * of the block's incident ground rules, evaluated with the rest of the state held fixed. After
 * the configured burn-in, the per-atom sample values are accumulated. Each atom's final value is
 * its empirical mean over the recorded sweeps.
 *
 * <p>NOTE(review): despite the class name, no MC-SAT slice-sampling step is visible here; the
 * procedure is a block-wise resampling of the kind described above.
 */
public class BooleanMCSat implements Reasoner {
    private static final Logger log = LoggerFactory.getLogger(BooleanMCSat.class);

    public static final String CONFIG_PREFIX = "booleanmcsat";

    /** Config key for the total number of sweeps, including burn-in sweeps. */
    public static final String NUM_SAMPLES_KEY = CONFIG_PREFIX + ".numsamples";
    public static final int NUM_SAMPLES_DEFAULT = 2500;

    /** Config key for the number of initial sweeps that are not recorded. */
    public static final String NUM_BURN_IN_KEY = CONFIG_PREFIX + ".numburnin";
    public static final int NUM_BURN_IN_DEFAULT = 500;

    private final int numSamples;
    private final int numBurnIn;

    /**
     * Reads the sampling parameters from {@link Config}.
     *
     * @throws IllegalArgumentException if the number of samples is not positive, the burn-in is
     *     negative, or the burn-in is not strictly smaller than the number of samples
     */
    public BooleanMCSat() {
        numSamples = Config.getInt(NUM_SAMPLES_KEY, NUM_SAMPLES_DEFAULT);
        if (numSamples <= 0) {
            throw new IllegalArgumentException("Number of samples must be positive.");
        }

        numBurnIn = Config.getInt(NUM_BURN_IN_KEY, NUM_BURN_IN_DEFAULT);
        // Zero burn-in is valid (every sweep is recorded). A negative value would make the
        // averaging divisor (numSamples - numBurnIn) larger than the number of recorded sweeps.
        if (numBurnIn < 0) {
            throw new IllegalArgumentException("Number of burn in samples must be non-negative.");
        }
        if (numBurnIn >= numSamples) {
            throw new IllegalArgumentException("Number of burn in samples must be less than number of samples.");
        }
    }

    /**
     * Runs the sampler and writes each atom's estimated marginal back as its value.
     *
     * @param termStore must be a {@link ConstraintBlockerTermStore}
     * @throws IllegalArgumentException if {@code termStore} is of any other type
     */
    @Override
    public void optimize(TermStore termStore) {
        if (!(termStore instanceof ConstraintBlockerTermStore)) {
            throw new IllegalArgumentException("ConstraintBlockerTermStore required.");
        }
        ConstraintBlockerTermStore blocker = (ConstraintBlockerTermStore) termStore;

        // Start from the term store's own random initialization.
        blocker.randomlyInitialize();

        // totals[block][atom] accumulates the atom's sampled value over the recorded sweeps.
        double[][] totals = new double[blocker.size()][];
        for (int blockIndex = 0; blockIndex < blocker.size(); blockIndex++) {
            totals[blockIndex] = new double[blocker.get(blockIndex).size()];
        }

        log.info("Beginning inference.");

        for (int sweep = 0; sweep < numSamples; sweep++) {
            boolean recordSweep = sweep >= numBurnIn;
            for (int blockIndex = 0; blockIndex < blocker.size(); blockIndex++) {
                ConstraintBlockerTerm block = blocker.get(blockIndex);
                if (block.size() == 0) {
                    continue;
                }

                double[] sample = sampleBlock(block);
                for (int atomIndex = 0; atomIndex < block.getAtoms().length; atomIndex++) {
                    block.getAtoms()[atomIndex].setValue((float) sample[atomIndex]);
                    if (recordSweep) {
                        totals[blockIndex][atomIndex] += sample[atomIndex];
                    }
                }
            }
        }

        log.info("Inference complete.");

        // The constructor guarantees this divisor is positive.
        int recordedSweeps = numSamples - numBurnIn;
        for (int blockIndex = 0; blockIndex < blocker.size(); blockIndex++) {
            ConstraintBlockerTerm block = blocker.get(blockIndex);
            for (int atomIndex = 0; atomIndex < block.size(); atomIndex++) {
                block.getAtoms()[atomIndex].setValue((float) (totals[blockIndex][atomIndex] / recordedSweeps));
            }
        }
    }

    /**
     * Draws a new assignment for one block given the current values of all other atoms.
     *
     * <p>Candidate {@code c < block.size()} sets atom {@code c} to 1 and every other atom to 0.
     * When the block is not "exactly one", an extra final candidate sets every atom to 0.
     * As a side effect, the block's atoms are left at the last candidate evaluated; the caller
     * is responsible for applying the returned assignment.
     *
     * @return a one-hot array over the candidates marking the chosen one
     */
    private double[] sampleBlock(ConstraintBlockerTerm block) {
        int numCandidates = block.getExactlyOne() ? block.size() : block.size() + 1;
        double[] energies = new double[numCandidates];
        for (int candidate = 0; candidate < numCandidates; candidate++) {
            for (int atomIndex = 0; atomIndex < block.size(); atomIndex++) {
                block.getAtoms()[atomIndex].setValue(atomIndex == candidate ? 1.0f : 0.0f);
            }
            energies[candidate] = computeEnergy(block.getIncidentGRs());
        }
        return sampleWithProbability(toRelativeProbabilities(energies));
    }

    /** Returns the weighted incompatibility summed over the given ground rules. */
    private static double computeEnergy(WeightedGroundRule[] groundRules) {
        double energy = 0.0;
        for (WeightedGroundRule groundRule : groundRules) {
            energy += groundRule.getWeight() * groundRule.getIncompatibility();
        }
        return energy;
    }

    /**
     * Maps energies to unnormalized probabilities {@code exp(-energy)}, shifted by the minimum
     * energy.
     *
     * <p>The shift leaves the normalized distribution unchanged. It keeps the lowest-energy
     * candidate at weight 1, so the total never underflows to 0 (which would produce NaN after
     * normalization). It also prevents negative energies from overflowing to Infinity.
     */
    private static double[] toRelativeProbabilities(double[] energies) {
        double minEnergy = Double.POSITIVE_INFINITY;
        for (double energy : energies) {
            minEnergy = Math.min(minEnergy, energy);
        }

        double[] weights = new double[energies.length];
        for (int i = 0; i < energies.length; i++) {
            weights[i] = Math.exp(minEnergy - energies[i]);
        }
        return weights;
    }

    /**
     * Normalizes {@code weights} in place and draws one index with probability proportional to
     * its weight.
     *
     * @return a one-hot array of the same length with 1.0 at the drawn index; if rounding leaves
     *     the cumulative sum below the random draw, the last index is chosen
     */
    private static double[] sampleWithProbability(double[] weights) {
        double total = 0.0;
        for (double weight : weights) {
            total += weight;
        }
        for (int i = 0; i < weights.length; i++) {
            weights[i] /= total;
        }

        double[] oneHot = new double[weights.length];
        double draw = RandUtils.nextDouble();
        double cumulative = 0.0;
        for (int i = 0; i < weights.length; i++) {
            cumulative += weights[i];
            if (cumulative >= draw) {
                oneHot[i] = 1.0;
                return oneHot;
            }
        }

        oneHot[oneHot.length - 1] = 1.0;
        return oneHot;
    }

    /** Holds no resources; nothing to release. */
    @Override
    public void close() {
    }
}
