package org.linqs.psl.application.learning.weight;

import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.application.ModelApplication;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.atom.PersistedAtomManager;
import org.linqs.psl.evaluation.statistics.ContinuousEvaluator;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.grounding.GroundRuleStore;
import org.linqs.psl.grounding.Grounding;
import org.linqs.psl.grounding.MemoryGroundRuleStore;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.model.rule.misc.GroundValueConstraint;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.admm.ADMMReasoner;
import org.linqs.psl.reasoner.admm.term.ADMMTermGenerator;
import org.linqs.psl.reasoner.admm.term.ADMMTermStore;
import org.linqs.psl.reasoner.term.TermGenerator;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.RandUtils;
import org.linqs.psl.util.Reflection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/linqs/psl/application/learning/weight/WeightLearningApplication.class */
/**
 * Base class for weight learning applications.
 *
 * <p>Rules are grounded against the random-variable database ({@code rvDB}). The observed
 * database ({@code observedDB}) supplies the labels used to build a {@link TrainingMap}.
 * Subclasses implement {@link #doLearn()} and update the weights of {@link #mutableRules}.
 *
 * <p>When latent variables are supported, a second "latent" ground model is built. It holds the
 * same rules plus a {@link GroundValueConstraint} pinning every labeled random variable to its
 * observed value.
 */
public abstract class WeightLearningApplication implements ModelApplication {
    /** Prefix shared by all weight-learning configuration keys. */
    public static final String CONFIG_PREFIX = "weightlearning";

    /** Config key for the {@link Reasoner} implementation used for MPE inference. */
    public static final String REASONER_KEY = "weightlearning.reasoner";
    /** Config key for the {@link GroundRuleStore} implementation. */
    public static final String GROUND_RULE_STORE_KEY = "weightlearning.groundrulestore";
    /** Config key for the {@link TermStore} implementation. */
    public static final String TERM_STORE_KEY = "weightlearning.termstore";
    /** Config key for the {@link TermGenerator} implementation. */
    public static final String TERM_GENERATOR_KEY = "weightlearning.termgenerator";
    /** Config key for the {@link Evaluator} implementation. */
    public static final String EVALUATOR_KEY = "weightlearning.evaluator";
    /** Config key: if true, every mutable rule gets a random weight before learning starts. */
    public static final String RANDOM_WEIGHTS_KEY = "weightlearning.randomweights";
    public static final boolean RANDOM_WEIGHTS_DEFAULT = false;

    /** Upper bound (inclusive) for weights chosen when random weights are enabled. */
    public static final int MAX_RANDOM_WEIGHT = 100;
    /** Lower bound on ADMM iterations, whatever budget is given to {@link #setBudget(double)}. */
    public static final int MIN_ADMM_STEPS = 3;

    protected boolean supportsLatentVariables;

    protected Database rvDB;
    protected Database observedDB;
    protected PersistedAtomManager atomManager;

    /** Every rule passed to the constructor, in order. */
    protected List<Rule> allRules = new ArrayList<>();
    /** The subset of {@link #allRules} whose weights can be learned. */
    protected List<WeightedRule> mutableRules = new ArrayList<>();

    /** Per-rule summed incompatibility under the labels; indexed parallel to {@link #mutableRules}. */
    protected double[] observedIncompatibility;
    /** Per-rule summed incompatibility under the MPE state; indexed parallel to {@link #mutableRules}. */
    protected double[] expectedIncompatibility;

    /** Maps labeled random variables to their observed atoms, and tracks unlabeled (latent) ones. */
    protected TrainingMap trainingMap;

    protected Reasoner reasoner;
    protected GroundRuleStore groundRuleStore;
    protected GroundRuleStore latentGroundRuleStore;
    protected TermGenerator termGenerator;
    protected TermStore termStore;
    protected TermStore latentTermStore;
    protected Evaluator evaluator;

    private boolean groundModelInit;

    /** True while the random variables hold the MPE state of the (non-latent) model. */
    protected boolean inMPEState;
    /** True while the random variables hold the MPE state of the latent model. */
    protected boolean inLatentMPEState;

    private static final Logger log = LoggerFactory.getLogger(WeightLearningApplication.class);

    public static final String REASONER_DEFAULT = ADMMReasoner.class.getName();
    public static final String GROUND_RULE_STORE_DEFAULT = MemoryGroundRuleStore.class.getName();
    public static final String TERM_STORE_DEFAULT = ADMMTermStore.class.getName();
    public static final String TERM_GENERATOR_DEFAULT = ADMMTermGenerator.class.getName();
    public static final String EVALUATOR_DEFAULT = ContinuousEvaluator.class.getName();

    /**
     * @param rules all rules of the model; those that are {@link WeightedRule}s become learnable.
     * @param rvDB database holding the random variables the rules are grounded against.
     * @param observedDB database holding the observed (label) values.
     * @param supportsLatentVariables whether this learner can handle random variables that have no
     *     corresponding observed atom. If false, grounding fails when any are found.
     */
    public WeightLearningApplication(List<Rule> rules, Database rvDB, Database observedDB,
            boolean supportsLatentVariables) {
        this.rvDB = rvDB;
        this.observedDB = observedDB;
        this.supportsLatentVariables = supportsLatentVariables;

        for (Rule rule : rules) {
            allRules.add(rule);
            if (rule instanceof WeightedRule) {
                mutableRules.add((WeightedRule) rule);
            }
        }

        observedIncompatibility = new double[mutableRules.size()];
        expectedIncompatibility = new double[mutableRules.size()];

        groundModelInit = false;
        inMPEState = false;
        inLatentMPEState = false;

        evaluator = (Evaluator) Config.getNewObject(EVALUATOR_KEY, EVALUATOR_DEFAULT);
    }

    /**
     * Learns the rule weights.
     *
     * <p>First grounds the model (and, when supported, the latent model), then delegates to
     * {@link #doLearn()}.
     */
    public void learn() {
        initGroundModel();
        if (supportsLatentVariables) {
            initLatentGroundModel();
        }
        doLearn();
    }

    /** Performs the actual weight learning; the ground model is already initialized. */
    protected abstract void doLearn();

    /**
     * Scales the reasoner's iteration limit by {@code budget}.
     *
     * <p>Only has an effect when the reasoner is an {@link ADMMReasoner}. The limit becomes
     * {@code ceil(configuredMaxIter * budget)}, but never fewer than {@link #MIN_ADMM_STEPS}. When
     * the term store is an {@link ADMMTermStore}, its local variables are reset as well.
     */
    public void setBudget(double budget) {
        if (!(reasoner instanceof ADMMReasoner)) {
            return;
        }

        int configuredMaxIter = Config.getInt(ADMMReasoner.MAX_ITER_KEY, ADMMReasoner.MAX_ITER_DEFAULT);
        int maxIter = Math.max(MIN_ADMM_STEPS, (int) Math.ceil(configuredMaxIter * budget));
        ((ADMMReasoner) reasoner).setMaxIter(maxIter);

        if (termStore instanceof ADMMTermStore) {
            ((ADMMTermStore) termStore).resetLocalVairables();
        }
    }

    public GroundRuleStore getGroundRuleStore() {
        return groundRuleStore;
    }

    /**
     * Grounds all rules into a newly configured ground rule store and initializes the model.
     * Does nothing if the model was already initialized.
     */
    protected void initGroundModel() {
        if (groundModelInit) {
            return;
        }

        PersistedAtomManager newAtomManager = createAtomManager();
        ensureTargets(newAtomManager);

        GroundRuleStore newGroundRuleStore =
                (GroundRuleStore) Config.getNewObject(GROUND_RULE_STORE_KEY, GROUND_RULE_STORE_DEFAULT);

        log.info("Grounding out model.");
        Grounding.groundAll(allRules, newAtomManager, newGroundRuleStore);

        initGroundModel(newAtomManager, newGroundRuleStore);
    }

    /**
     * Initializes the model from a ground rule store the caller has already populated.
     * Does nothing if the model was already initialized.
     */
    public void initGroundModel(GroundRuleStore groundRuleStore) {
        if (groundModelInit) {
            return;
        }

        initGroundModel(createAtomManager(), groundRuleStore);
    }

    /**
     * Builds the terms, training map and reasoner for an already-grounded model.
     *
     * @throws IllegalArgumentException if latent variables are found but this learner does not
     *     support them.
     */
    private void initGroundModel(PersistedAtomManager newAtomManager, GroundRuleStore newGroundRuleStore) {
        if (groundModelInit) {
            return;
        }

        TermStore newTermStore = (TermStore) Config.getNewObject(TERM_STORE_KEY, TERM_STORE_DEFAULT);
        TermGenerator newTermGenerator =
                (TermGenerator) Config.getNewObject(TERM_GENERATOR_KEY, TERM_GENERATOR_DEFAULT);

        log.debug("Initializing objective terms for {} ground rules.", newGroundRuleStore.size());
        newTermStore.ensureVariableCapacity(newAtomManager.getCachedRVACount());
        int termCount = newTermGenerator.generateTerms(newGroundRuleStore, newTermStore);
        log.debug("Generated {} objective terms from {} ground rules.", termCount, newGroundRuleStore.size());

        TrainingMap newTrainingMap = new TrainingMap(newAtomManager, observedDB, false);
        Set<RandomVariableAtom> latentVariables = newTrainingMap.getLatentVariables();
        if (!supportsLatentVariables && !latentVariables.isEmpty()) {
            // Nothing will own the term store we just built, so release it before failing.
            newTermStore.close();
            throw new IllegalArgumentException(String.format(
                    "All RandomVariableAtoms must have corresponding ObservedAtoms, found %d latent variables. Latent variables are not supported by this WeightLearningApplication (%s). Example latent variable: [%s].",
                    latentVariables.size(), getClass().getName(), latentVariables.iterator().next()));
        }

        Reasoner newReasoner = (Reasoner) Config.getNewObject(REASONER_KEY, REASONER_DEFAULT);
        initGroundModel(newReasoner, newGroundRuleStore, newTermStore, newTermGenerator, newAtomManager, newTrainingMap);
    }

    /**
     * Installs fully built model components. If random weights are configured, randomizes the
     * mutable rule weights; then runs {@link #postInitGroundModel()}.
     * Does nothing if the model was already initialized.
     */
    public void initGroundModel(Reasoner reasoner, GroundRuleStore groundRuleStore, TermStore termStore,
            TermGenerator termGenerator, PersistedAtomManager atomManager, TrainingMap trainingMap) {
        if (groundModelInit) {
            return;
        }

        this.reasoner = reasoner;
        this.groundRuleStore = groundRuleStore;
        this.termStore = termStore;
        this.termGenerator = termGenerator;
        this.atomManager = atomManager;
        this.trainingMap = trainingMap;

        if (Config.getBoolean(RANDOM_WEIGHTS_KEY, RANDOM_WEIGHTS_DEFAULT)) {
            initRandomWeights();
        }

        postInitGroundModel();

        groundModelInit = true;
    }

    /**
     * Assigns each mutable rule a weight of {@code RandUtils.nextInt(MAX_RANDOM_WEIGHT) + 1}.
     * This lies in [1, MAX_RANDOM_WEIGHT], assuming {@code nextInt(n)} returns [0, n) — TODO confirm.
     */
    private void initRandomWeights() {
        log.trace("Randomly Weighted Rules:");
        for (WeightedRule rule : mutableRules) {
            rule.setWeight(RandUtils.nextInt(MAX_RANDOM_WEIGHT) + 1);
            log.trace("    {}", rule);
        }
    }

    /** Hook for subclasses; called once the ground model components are installed. Default: no-op. */
    public void postInitGroundModel() {
    }

    /**
     * Builds the latent ground model: every rule is re-grounded into a separate store, and a
     * {@link GroundValueConstraint} pins each labeled random variable to its observed value.
     * Does nothing if the latent model was already built, so repeated {@link #learn()} calls do not
     * leak stores.
     */
    protected void initLatentGroundModel() {
        if (latentGroundRuleStore != null) {
            return;
        }

        latentGroundRuleStore =
                (GroundRuleStore) Config.getNewObject(GROUND_RULE_STORE_KEY, GROUND_RULE_STORE_DEFAULT);
        latentTermStore = (TermStore) Config.getNewObject(TERM_STORE_KEY, TERM_STORE_DEFAULT);

        log.info("Grounding out latent model.");
        int groundRuleCount = Grounding.groundAll(allRules, atomManager, latentGroundRuleStore);

        for (Map.Entry<RandomVariableAtom, ObservedAtom> label : trainingMap.getTrainingMap().entrySet()) {
            latentGroundRuleStore.addGroundRule(
                    new GroundValueConstraint(label.getKey(), label.getValue().getValue()));
        }
        int totalGroundRules = groundRuleCount + trainingMap.getTrainingMap().size();

        log.debug("Initializing latent objective terms for {} ground rules.", totalGroundRules);
        // Capacity must be reserved on the latent store: it is the one receiving the latent terms.
        latentTermStore.ensureVariableCapacity(atomManager.getCachedRVACount());
        int termCount = termGenerator.generateTerms(latentGroundRuleStore, latentTermStore);
        log.debug("Generated {} latent objective terms from {} ground rules.", termCount, totalGroundRules);
    }

    /**
     * Runs MPE inference on the full (non-latent) model, unless it is already in that state.
     * Terms are regenerated first so they reflect the current rule weights.
     */
    public void computeMPEState() {
        if (inMPEState) {
            return;
        }

        termStore.clear();
        termStore.ensureVariableCapacity(atomManager.getCachedRVACount());
        termGenerator.generateTerms(groundRuleStore, termStore);

        reasoner.optimize(termStore);

        inMPEState = true;
        // Both models share the same random variables. Assuming the reasoner writes its solution
        // back into them, any latent MPE state is now overwritten.
        inLatentMPEState = false;
    }

    /**
     * Runs MPE inference on the latent model (labels pinned, latent variables free), unless it is
     * already in that state. The latent terms are regenerated first so they reflect the current
     * rule weights.
     */
    public void computeLatentMPEState() {
        if (inLatentMPEState) {
            return;
        }

        latentTermStore.clear();
        latentTermStore.ensureVariableCapacity(atomManager.getCachedRVACount());
        termGenerator.generateTerms(latentGroundRuleStore, latentTermStore);

        reasoner.optimize(latentTermStore);

        inLatentMPEState = true;
        // Shared random variables: the full-model MPE state (if any) has been overwritten.
        inMPEState = false;
    }

    /**
     * Sets labeled random variables to their observed values, then sums each mutable rule's ground
     * incompatibilities into {@link #observedIncompatibility}.
     */
    public void computeObservedIncompatibility() {
        setLabeledRandomVariables();
        sumIncompatibilities(observedIncompatibility);
    }

    /**
     * Computes the MPE state, then sums each mutable rule's ground incompatibilities into
     * {@link #expectedIncompatibility}.
     */
    public void computeExpectedIncompatibility() {
        computeMPEState();
        sumIncompatibilities(expectedIncompatibility);
    }

    /**
     * Overwrites {@code totals} so that {@code totals[i]} is the summed incompatibility of every
     * ground rule of {@code mutableRules.get(i)}, under the current random variable values.
     */
    private void sumIncompatibilities(double[] totals) {
        for (int i = 0; i < totals.length; i++) {
            totals[i] = 0.0;
        }

        for (int i = 0; i < mutableRules.size(); i++) {
            for (GroundRule groundRule : groundRuleStore.getGroundRules(mutableRules.get(i))) {
                totals[i] += ((WeightedGroundRule) groundRule).getIncompatibility();
            }
        }
    }

    /**
     * Returns {@code sum_i weight_i * (observedIncompatibility[i] - expectedIncompatibility[i])}.
     * Uses whatever values the incompatibility arrays currently hold; it does not recompute them.
     */
    public double computeLoss() {
        double loss = 0.0;
        for (int i = 0; i < mutableRules.size(); i++) {
            loss += mutableRules.get(i).getWeight() * (observedIncompatibility[i] - expectedIncompatibility[i]);
        }
        return loss;
    }

    /** Closes every owned store and the reasoner, then drops all model and database references. */
    @Override
    public void close() {
        if (groundRuleStore != null) {
            groundRuleStore.close();
            groundRuleStore = null;
        }

        if (latentGroundRuleStore != null) {
            latentGroundRuleStore.close();
            latentGroundRuleStore = null;
        }

        if (termStore != null) {
            termStore.close();
            termStore = null;
        }

        if (latentTermStore != null) {
            latentTermStore.close();
            latentTermStore = null;
        }

        if (reasoner != null) {
            reasoner.close();
            reasoner = null;
        }

        termGenerator = null;
        trainingMap = null;
        atomManager = null;
        rvDB = null;
        observedDB = null;
    }

    /**
     * Sets every labeled random variable to its observed value. Invalidates both MPE-state flags.
     */
    public void setLabeledRandomVariables() {
        inMPEState = false;
        inLatentMPEState = false;

        for (Map.Entry<RandomVariableAtom, ObservedAtom> label : trainingMap.getTrainingMap().entrySet()) {
            label.getKey().setValue(label.getValue().getValue());
        }
    }

    /**
     * Sets every labeled and every latent random variable to 0. Invalidates both MPE-state flags.
     */
    public void setDefaultRandomVariables() {
        inMPEState = false;
        inLatentMPEState = false;

        for (RandomVariableAtom atom : trainingMap.getTrainingMap().keySet()) {
            atom.setValue(0.0f);
        }

        for (RandomVariableAtom atom : trainingMap.getLatentVariables()) {
            atom.setValue(0.0f);
        }
    }

    /** Creates the atom manager over {@link #rvDB}; overridable by subclasses. */
    protected PersistedAtomManager createAtomManager() {
        return new PersistedAtomManager(rvDB);
    }

    /**
     * For every predicate of {@code observedDB}'s data store that is not closed in
     * {@code observedDB}, resolves each of its observed atoms through {@code manager}. Any that
     * resolves to a random variable is set to 0. Finally commits the manager's persisted atoms.
     * Presumably this makes sure each labeled atom has a persisted random variable in
     * {@link #rvDB} — TODO confirm against PersistedAtomManager.
     */
    private void ensureTargets(PersistedAtomManager manager) {
        for (StandardPredicate predicate : observedDB.getDataStore().getRegisteredPredicates()) {
            if (observedDB.isClosed(predicate)) {
                continue;
            }

            for (ObservedAtom observedAtom : observedDB.getAllGroundObservedAtoms(predicate)) {
                GroundAtom atom = manager.getAtom(observedAtom.getPredicate(), observedAtom.getArguments());
                if (!(atom instanceof ObservedAtom)) {
                    ((RandomVariableAtom) atom).setValue(0.0f);
                }
            }
        }

        manager.commitPersistedAtoms();
    }

    /**
     * Reflectively constructs a weight learner.
     *
     * @param name a class name, resolved through {@link Reflection#resolveClassName(String)}. The
     *     class must have a public {@code (List, Database, Database)} constructor.
     * @throws IllegalArgumentException if the class or a suitable constructor cannot be found.
     * @throws RuntimeException if the constructor is inaccessible, cannot be instantiated, or throws.
     */
    public static WeightLearningApplication getWLA(String name, List<Rule> rules, Database rvDB, Database observedDB) {
        String className = Reflection.resolveClassName(name);
        if (className == null) {
            throw new IllegalArgumentException("Could not find class: " + name);
        }

        try {
            Class<?> learnerClass = Class.forName(className);
            return (WeightLearningApplication) learnerClass
                    .getConstructor(List.class, Database.class, Database.class)
                    .newInstance(rules, rvDB, observedDB);
        } catch (ClassNotFoundException ex) {
            throw new IllegalArgumentException("Could not find class: " + className, ex);
        } catch (NoSuchMethodException ex) {
            throw new IllegalArgumentException("No sutible constructor found for weight learner: " + className + ".", ex);
        } catch (IllegalAccessException ex) {
            throw new RuntimeException("Insufficient access to constructor for " + className, ex);
        } catch (InstantiationException ex) {
            throw new RuntimeException("Unable to instantiate weight learner (" + className + ")", ex);
        } catch (InvocationTargetException ex) {
            throw new RuntimeException("Error thrown while constructing " + className, ex);
        }
    }
}
