package org.linqs.psl.application.learning.weight.search;

import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.util.RandUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Weight learning by Hyperband-style random search.
 *
 * <p>The search runs {@code numBrackets} brackets. Each bracket samples a batch of random weight
 * configurations (later brackets sample more configurations but start them on a smaller budget),
 * then runs several rounds of successive halving. In each round every surviving configuration is
 * evaluated with the round's budget (passed to {@code setBudget}, clamped to
 * [{@link #MIN_BUDGET_PROPORTION}, 1]). Only the best {@code floor(size / survival)}
 * configurations advance to the next round.
 *
 * <p>The configuration with the lowest objective seen in any evaluation is applied to the
 * mutable rules when learning finishes.
 *
 * <p>Configuration keys: {@link #SURVIVAL_KEY}, {@link #BASE_BRACKET_SIZE_KEY},
 * {@link #NUM_BRACKETS_KEY}.
 */
public class Hyperband extends WeightLearningApplication {
    private static final Logger log = LoggerFactory.getLogger(Hyperband.class);

    /** Prefix shared by all Hyperband configuration keys. */
    public static final String CONFIG_PREFIX = "hyperband";

    /** Elimination factor: each round keeps {@code 1 / survival} of the configurations. */
    public static final String SURVIVAL_KEY = "hyperband.survival";
    public static final int SURVIVAL_DEFAULT = 4;

    /** Multiplier used to derive the number of configurations sampled per bracket. */
    public static final String BASE_BRACKET_SIZE_KEY = "hyperband.basebracketsize";
    public static final int BASE_BRACKET_SIZE_DEFAULT = 10;

    /** Number of brackets to run. */
    public static final String NUM_BRACKETS_KEY = "hyperband.numbrackets";
    public static final int NUM_BRACKETS_DEFAULT = 4;

    /** Lower clamp on the budget handed to {@code setBudget}. */
    public static final double MIN_BUDGET_PROPORTION = 0.001d;

    /** Lower bound on the number of configurations sampled for a bracket. */
    public static final int MIN_BRACKET_SIZE = 1;

    /** Offset added to every randomly sampled weight (see {@link #chooseConfigs}). */
    public static final double MEAN = 0.5d;

    /** Square root of this value scales every randomly sampled weight (see {@link #chooseConfigs}). */
    public static final double VARIANCE = 0.1d;

    private final int survival;
    private final int numBrackets;
    private final int baseBracketSize;

    /**
     * One evaluated configuration, ordered by objective (ascending).
     *
     * <p>The ordering is intentionally inconsistent with {@code equals}. It exists only to drive
     * the {@link PriorityQueue} in {@link #doLearn()}.
     */
    private static class RunResult implements Comparable<RunResult> {
        public final double[] weights;
        public final double objective;

        public RunResult(double[] weights, double objective) {
            this.weights = weights;
            this.objective = objective;
        }

        @Override
        public int compareTo(RunResult other) {
            return Double.compare(this.objective, other.objective);
        }
    }

    public Hyperband(Model model, Database rvDB, Database observedDB) {
        this(model.getRules(), rvDB, observedDB);
    }

    /**
     * Builds a Hyperband learner, reading its parameters from {@link Config}.
     *
     * @throws IllegalArgumentException if the survival rate, bracket count, or base bracket size
     *     is less than 1
     */
    public Hyperband(List<Rule> rules, Database rvDB, Database observedDB) {
        super(rules, rvDB, observedDB, false);

        survival = Config.getInt(SURVIVAL_KEY, SURVIVAL_DEFAULT);
        if (survival < 1) {
            throw new IllegalArgumentException("Need at least one survival proportion.");
        }

        numBrackets = Config.getInt(NUM_BRACKETS_KEY, NUM_BRACKETS_DEFAULT);
        if (numBrackets < 1) {
            throw new IllegalArgumentException("Need at least one bracket.");
        }

        baseBracketSize = Config.getInt(BASE_BRACKET_SIZE_KEY, BASE_BRACKET_SIZE_DEFAULT);
        if (baseBracketSize < 1) {
            throw new IllegalArgumentException("Need at least one bracket size.");
        }
    }

    /**
     * Runs every bracket and applies the best configuration found to the mutable rules.
     *
     * <p>NOTE(review): objectives obtained under a reduced budget are compared directly against
     * full-budget objectives when choosing the overall best configuration. Confirm this is intended.
     */
    @Override
    protected void doLearn() {
        computeObservedIncompatibility();

        double bestObjective = 0.0;
        double[] bestWeights = null;
        double totalBudget = 0.0;
        int totalConfigs = 0;

        for (int bracket = 0; bracket < numBrackets; bracket++) {
            // Later brackets start with more configurations, each on a smaller initial budget.
            double sizeProportion = Math.pow(survival, bracket) / (bracket + 1);
            int bracketSize = (int) Math.max(MIN_BRACKET_SIZE, Math.ceil(sizeProportion * baseBracketSize));
            totalConfigs += bracketSize;
            double bracketBudget = Math.pow(survival, -1.0 * bracket);

            log.debug("Bracket {} / {} -- Size: {} ({}), Budget: {}", new Object[] {
                    bracket + 1, numBrackets, bracketSize, sizeProportion, bracketBudget});

            List<double[]> configs = chooseConfigs(bracketSize);

            for (int round = 0; round <= bracket; round++) {
                int roundSize = configs.size();
                double roundBudget = bracketBudget * Math.pow(survival, round);
                double effectiveBudget = Math.max(MIN_BUDGET_PROPORTION, Math.min(1.0, roundBudget));
                setBudget(effectiveBudget);

                log.debug("  Round {} / {} -- Size: {}, Budget: {}", new Object[] {
                        round + 1, bracket + 1, roundSize, roundBudget});

                PriorityQueue<RunResult> results = new PriorityQueue<>();
                for (double[] weights : configs) {
                    // Account for the budget actually used, not the unclamped request.
                    totalBudget += effectiveBudget;

                    applyCandidateWeights(weights);
                    double objective = run(weights);
                    results.add(new RunResult(weights, objective));

                    // Double.compare (same ordering as RunResult) keeps a NaN objective from
                    // permanently blocking every later candidate, which a plain '<' would do.
                    if (bestWeights == null || Double.compare(objective, bestObjective) < 0) {
                        bestObjective = objective;
                        bestWeights = weights;
                    }

                    log.debug("Training Objective: {}, Weights: {}", objective, weights);
                }

                // Successive halving: only the best floor(roundSize / survival) configurations advance.
                int numSurvivors = roundSize / survival;
                configs.clear();
                for (int i = 0; i < numSurvivors; i++) {
                    configs.add(results.poll().weights);
                }
            }
        }

        // numBrackets >= 1 and every bracket has >= MIN_BRACKET_SIZE configs, so bestWeights is set.
        applyCandidateWeights(bestWeights);

        log.debug("Hyperband complete. Configurations examined: {}. Total budget: {}", totalConfigs, totalBudget);
    }

    /**
     * Sets mutable rule {@code i} to {@code weights[i]} and clears the
     * {@code inMPEState} / {@code inLatentMPEState} flags. Clearing the flags presumably forces
     * the next inference to be recomputed with the new weights.
     */
    private void applyCandidateWeights(double[] weights) {
        for (int i = 0; i < mutableRules.size(); i++) {
            mutableRules.get(i).setWeight(weights[i]);
        }
        inMPEState = false;
        inLatentMPEState = false;
    }

    /**
     * Samples {@code numConfigs} random weight configurations, one weight per mutable rule.
     *
     * <p>Each weight is {@code RandUtils.nextDouble() * sqrt(VARIANCE) + MEAN}.
     * NOTE(review): if {@code RandUtils.nextDouble()} is uniform on [0, 1) (TODO confirm), weights
     * land in [MEAN, MEAN + sqrt(VARIANCE)) rather than being normally distributed around MEAN, as
     * the constant names suggest. Confirm whether Gaussian sampling was intended.
     *
     * @return a mutable list of freshly allocated weight arrays
     */
    private List<double[]> chooseConfigs(int numConfigs) {
        List<double[]> configs = new ArrayList<>(numConfigs);
        double scale = Math.sqrt(VARIANCE);

        for (int i = 0; i < numConfigs; i++) {
            double[] config = new double[mutableRules.size()];
            for (int j = 0; j < config.length; j++) {
                config[j] = RandUtils.nextDouble() * scale + MEAN;
            }
            configs.add(config);
        }

        return configs;
    }

    /**
     * Evaluates the rules' current weights on the training data and returns an objective to
     * minimize.
     *
     * <p>The {@code weights} argument is not read here. The caller must already have applied
     * them to the rules.
     *
     * @return the evaluator's representative metric, negated when higher values are better
     */
    public double run(double[] weights) {
        setDefaultRandomVariables();
        computeExpectedIncompatibility();
        evaluator.compute(trainingMap);

        double metric = evaluator.getRepresentativeMetric();
        // Hyperband minimizes, so flip metrics where a larger value is better.
        return evaluator.isHigherRepresentativeBetter() ? -metric : metric;
    }
}
