package org.linqs.psl.application.learning.weight.maxlikelihood;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.linqs.psl.application.learning.weight.VotedPerceptron;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.application.learning.weight.search.grid.RandomGridSearch;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.grounding.AtomRegisterGroundRuleStore;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.term.blocker.ConstraintBlockerTerm;
import org.linqs.psl.reasoner.term.blocker.ConstraintBlockerTermGenerator;
import org.linqs.psl.reasoner.term.blocker.ConstraintBlockerTermStore;

/* loaded from: input_file:org/linqs/psl/application/learning/weight/maxlikelihood/MaxPseudoLikelihood.class */
public class MaxPseudoLikelihood extends VotedPerceptron {
    public static final String CONFIG_PREFIX = "maxspeudolikelihood";
    public static final String BOOLEAN_KEY = "maxspeudolikelihood.bool";
    public static final boolean BOOLEAN_DEFAULT = false;
    public static final String NUM_SAMPLES_KEY = "maxspeudolikelihood.numsamples";
    public static final int NUM_SAMPLES_DEFAULT = 10;
    public static final String MIN_WIDTH_KEY = "maxspeudolikelihood.minwidth";
    public static final double MIN_WIDTH_DEFAULT = 0.01d;
    private final boolean bool;
    private final double minWidth;
    private final int maxNumSamples;
    private int numSamples;

    public MaxPseudoLikelihood(Model model, Database database, Database database2) {
        this(model.getRules(), database, database2);
    }

    public MaxPseudoLikelihood(List<Rule> list, Database database, Database database2) {
        super(list, database, database2, false);
        this.bool = Config.getBoolean(BOOLEAN_KEY, false);
        this.maxNumSamples = Config.getInt(NUM_SAMPLES_KEY, 10);
        this.numSamples = this.maxNumSamples;
        if (this.numSamples <= 0) {
            throw new IllegalArgumentException("Number of samples must be positive integer.");
        }
        this.minWidth = Config.getDouble(MIN_WIDTH_KEY, 0.01d);
        if (this.minWidth <= 0.0d) {
            throw new IllegalArgumentException("Minimum width must be positive double.");
        }
        Config.setProperty(WeightLearningApplication.GROUND_RULE_STORE_KEY, AtomRegisterGroundRuleStore.class.getName());
        Config.setProperty(WeightLearningApplication.TERM_STORE_KEY, ConstraintBlockerTermStore.class.getName());
        Config.setProperty(WeightLearningApplication.TERM_GENERATOR_KEY, ConstraintBlockerTermGenerator.class.getName());
        this.cutObjective = false;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v33 */
    /* JADX WARN: Type inference failed for: r13v0 */
    /* JADX WARN: Type inference failed for: r13v1 */
    /* JADX WARN: Type inference failed for: r13v2 */
    /* JADX WARN: Type inference failed for: r1v25 */
    /* JADX WARN: Type inference failed for: r1v51 */
    /* JADX WARN: Type inference failed for: r2v28 */
    @Override // org.linqs.psl.application.learning.weight.WeightLearningApplication
    public void computeExpectedIncompatibility() {
        double[][] dArr;
        if (!(this.termStore instanceof ConstraintBlockerTermStore)) {
            throw new IllegalArgumentException("ConstraintBlockerTermStore required.");
        }
        ConstraintBlockerTermStore constraintBlockerTermStore = (ConstraintBlockerTermStore) this.termStore;
        for (int i = 0; i < this.expectedIncompatibility.length; i++) {
            this.expectedIncompatibility[i] = 0.0d;
        }
        Iterator<ConstraintBlockerTerm> it = constraintBlockerTermStore.iterator();
        while (it.hasNext()) {
            ConstraintBlockerTerm next = it.next();
            if (next.size() != 0) {
                if (this.bool) {
                    dArr = new double[next.getExactlyOne() ? next.size() : next.size() + 1];
                    int i2 = 0;
                    while (true) {
                        if (i2 >= (next.getExactlyOne() ? dArr.length : dArr.length - 1)) {
                            break;
                        }
                        dArr[i2] = new double[next.size()];
                        dArr[i2][i2] = 4607182418800017408;
                        i2++;
                    }
                    if (!next.getExactlyOne()) {
                        dArr[dArr.length - 1] = new double[next.size()];
                    }
                } else {
                    dArr = new double[Math.max(this.numSamples * next.size(), RandomGridSearch.MAX_LOCATIONS_DEFAULT)];
                    SimplexSampler simplexSampler = new SimplexSampler();
                    for (int i3 = 0; i3 < dArr.length; i3++) {
                        dArr[i3] = simplexSampler.getNext(dArr.length);
                    }
                }
                HashMap hashMap = new HashMap();
                float[] fArr = new float[next.size()];
                for (int i4 = 0; i4 < next.size(); i4++) {
                    fArr[i4] = next.getAtoms()[i4].getValue();
                }
                for (WeightedGroundRule weightedGroundRule : next.getIncidentGRs()) {
                    if (weightedGroundRule instanceof WeightedGroundRule) {
                        WeightedRule weightedRule = (WeightedRule) weightedGroundRule.getRule();
                        if (!hashMap.containsKey(weightedRule)) {
                            hashMap.put(weightedRule, new double[dArr.length]);
                        }
                        double[] dArr2 = (double[]) hashMap.get(weightedRule);
                        for (int i5 = 0; i5 < dArr.length; i5++) {
                            for (int i6 = 0; i6 < next.size(); i6++) {
                                next.getAtoms()[i6].setValue((float) dArr[i5][i6]);
                            }
                            int i7 = i5;
                            dArr2[i7] = dArr2[i7] + weightedGroundRule.getIncompatibility();
                        }
                    }
                }
                for (int i8 = 0; i8 < next.size(); i8++) {
                    next.getAtoms()[i8].setValue(fArr[i8]);
                }
                HashMap hashMap2 = new HashMap();
                double d = 0.0d;
                for (int i9 = 0; i9 < dArr.length; i9++) {
                    double d2 = 0.0d;
                    for (Map.Entry entry : hashMap.entrySet()) {
                        d2 -= ((WeightedRule) entry.getKey()).getWeight() * ((double[]) entry.getValue())[i9];
                    }
                    double exp = Math.exp(d2);
                    d += exp;
                    Iterator it2 = hashMap.entrySet().iterator();
                    while (it2.hasNext()) {
                        WeightedRule weightedRule2 = (WeightedRule) ((Map.Entry) it2.next()).getKey();
                        if (!hashMap2.containsKey(weightedRule2)) {
                            hashMap2.put(weightedRule2, Double.valueOf(0.0d));
                        }
                        hashMap2.put(weightedRule2, Double.valueOf(((Double) hashMap2.get(weightedRule2)).doubleValue() + (exp * ((double[]) hashMap.get(weightedRule2))[i9])));
                    }
                }
                for (int i10 = 0; i10 < this.mutableRules.size(); i10++) {
                    WeightedRule weightedRule3 = this.mutableRules.get(i10);
                    if (hashMap2.containsKey(weightedRule3) && ((Double) hashMap2.get(weightedRule3)).doubleValue() > 0.0d) {
                        double[] dArr3 = this.expectedIncompatibility;
                        int i11 = i10;
                        dArr3[i11] = dArr3[i11] + (((Double) hashMap2.get(weightedRule3)).doubleValue() / d);
                    }
                }
            }
        }
    }

    @Override // org.linqs.psl.application.learning.weight.VotedPerceptron, org.linqs.psl.application.learning.weight.WeightLearningApplication
    public void setBudget(double d) {
        super.setBudget(d);
        this.numSamples = (int) Math.ceil(d * this.maxNumSamples);
    }
}
