package ai.libs.jaicore.ml.dyadranking.algorithm;

import ai.libs.jaicore.basic.FileUtil;
import ai.libs.jaicore.ml.core.exception.ConfigurationException;
import ai.libs.jaicore.ml.core.exception.PredictionException;
import ai.libs.jaicore.ml.core.exception.TrainingException;
import ai.libs.jaicore.ml.core.predictivemodel.ICertaintyProvider;
import ai.libs.jaicore.ml.core.predictivemodel.IOnlineLearner;
import ai.libs.jaicore.ml.core.predictivemodel.IPredictiveModelConfiguration;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.aeonbits.owner.ConfigFactory;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Dyad ranker backed by a Plackett-Luce network (PLNet).
 *
 * <p>A feed-forward network maps the concatenation of a dyad's instance and alternative feature
 * vectors to a single real-valued utility. Rankings are predicted by sorting dyads by descending
 * utility. Ranking probabilities are computed with the Plackett-Luce formula, using
 * {@code exp(utility)} as each dyad's skill.
 *
 * <p>The network is created lazily from the configuration the first time it is needed, unless it
 * was created from a DL4J config file or loaded from disk beforehand. The class performs no
 * synchronization and mutates its network in place.
 */
public class PLNetDyadRanker implements IPLDyadRanker, IOnlineLearner<IDyadRankingInstance, IDyadRankingInstance, DyadRankingDataset>, ICertaintyProvider<IDyadRankingInstance, IDyadRankingInstance, DyadRankingDataset> {
    private static final Logger log = LoggerFactory.getLogger(PLNetDyadRanker.class);

    /** The underlying network; {@code null} until it is created, loaded or trained. */
    private MultiLayerNetwork plNet;

    /** Hyperparameters: architecture, optimizer and early-stopping settings. */
    private IPLNetDyadRankerConfiguration configuration;

    /** Number of epochs performed by the most recent call to {@link #train(List, int, double)}. */
    private int epoch;

    /** Number of parameter updates since the last training run started; passed to the updater. */
    private int iteration;

    /** Creates a ranker that uses the default {@link IPLNetDyadRankerConfiguration}. */
    public PLNetDyadRanker() {
        this.configuration = ConfigFactory.create(IPLNetDyadRankerConfiguration.class);
    }

    /**
     * Creates a ranker that uses the given configuration.
     *
     * @param config the configuration providing the network and training hyperparameters
     */
    public PLNetDyadRanker(IPLNetDyadRankerConfiguration config) {
        this.configuration = config;
    }

    /**
     * Trains the network on the given dataset using the configured epoch limit and early-stopping
     * settings.
     *
     * @param dataset the training data
     * @throws TrainingException declared by the learner interface
     */
    @Override
    public void train(DyadRankingDataset dataset) throws TrainingException {
        this.train(dataset.toND4j());
    }

    /**
     * Trains the network using the configured maximum number of epochs and early-stopping train ratio.
     *
     * <p>If early-stopping retraining is enabled, the network is then discarded. A fresh network is
     * trained on the whole dataset for as many epochs as the first run performed.
     *
     * @param dataset one matrix per ranking, with one dyad vector per row
     */
    public void train(List<INDArray> dataset) {
        this.train(dataset, this.configuration.plNetMaxEpochs(), this.configuration.plNetEarlyStoppingTrainRatio());
        if (this.configuration.plNetEarlyStoppingRetrain()) {
            int epochsToRetrain = this.epoch;
            this.plNet = null;
            this.train(dataset, epochsToRetrain, 1.0d);
        }
    }

    /**
     * Trains the network on the given dataset.
     *
     * @param dataset the training data
     * @param maxEpochs maximum number of epochs; {@code 0} means no epoch limit
     * @param earlyStoppingTrainRatio fraction of the dataset used for training
     * @see #train(List, int, double)
     */
    public void train(DyadRankingDataset dataset, int maxEpochs, double earlyStoppingTrainRatio) {
        this.train(dataset.toND4j(), maxEpochs, earlyStoppingTrainRatio);
    }

    /**
     * Trains the network with mini-batch gradient steps and optional early stopping.
     *
     * <p>The first {@code (int) (earlyStoppingTrainRatio * dataset.size())} rankings, taken in list
     * order without shuffling, form the training set. The rest form the early-stopping set. Early
     * stopping is used only when the ratio is below 1.
     *
     * <p>Every {@code plNetEarlyStoppingInterval} epochs, the average loss on the early-stopping set
     * is computed. An improvement snapshots the network; otherwise a patience counter grows.
     * Training stops when the counter reaches {@code plNetEarlyStoppingPatience}, if that is
     * positive, or when {@code maxEpochs} epochs have run, if that is positive. At the end, the best
     * snapshot becomes the network.
     *
     * <p>NOTE: if {@code maxEpochs == 0} and patience can never run out (non-positive patience, or a
     * ratio of 1), this method does not return.
     *
     * @param dataset one matrix per ranking, with one dyad vector per row
     * @param maxEpochs maximum number of epochs; {@code 0} means no epoch limit
     * @param earlyStoppingTrainRatio fraction of {@code dataset} used for training
     */
    public void train(List<INDArray> dataset, int maxEpochs, double earlyStoppingTrainRatio) {
        int splitIndex = (int) (earlyStoppingTrainRatio * dataset.size());
        List<INDArray> trainingSet = dataset.subList(0, splitIndex);
        List<INDArray> earlyStoppingSet = dataset.subList(splitIndex, dataset.size());
        if (this.plNet == null) {
            this.plNet = this.createNetwork(dataset.get(0).columns());
            this.plNet.init();
        }
        int patience = this.configuration.plNetEarlyStoppingPatience();
        int evaluationInterval = this.configuration.plNetEarlyStoppingInterval();
        boolean earlyStoppingEnabled = earlyStoppingTrainRatio < 1.0d;

        double bestScore = Double.POSITIVE_INFINITY;
        // Parameters are updated in place, so without any evaluation the "best" model is the live one.
        MultiLayerNetwork bestModel = this.plNet;
        this.epoch = 0;
        this.iteration = 0;
        int evaluationsWithoutImprovement = 0;
        int epochsSinceEvaluation = 0;
        while ((patience <= 0 || evaluationsWithoutImprovement < patience) && (maxEpochs == 0 || this.epoch < maxEpochs)) {
            this.tryUpdatingWithMinibatch(trainingSet);
            log.debug("plNet params: {}", this.plNet.params());
            epochsSinceEvaluation++;
            if (epochsSinceEvaluation == evaluationInterval && earlyStoppingEnabled) {
                double score = this.computeAvgError(earlyStoppingSet);
                if (score < bestScore) {
                    bestScore = score;
                    bestModel = this.plNet.clone();
                    log.debug("current best score: {}", bestScore);
                    evaluationsWithoutImprovement = 0;
                } else {
                    evaluationsWithoutImprovement++;
                }
                epochsSinceEvaluation = 0;
            }
            this.epoch++;
        }
        this.plNet = bestModel;
    }

    /**
     * Performs one epoch: splits the training set into consecutive mini-batches of the configured
     * size and applies one update per batch. A trailing, smaller batch is also applied.
     */
    private void tryUpdatingWithMinibatch(List<INDArray> trainingSet) {
        int batchSize = this.configuration.plNetMiniBatchSize();
        List<INDArray> miniBatch = new ArrayList<>(batchSize);
        for (INDArray rankingMatrix : trainingSet) {
            miniBatch.add(rankingMatrix);
            if (miniBatch.size() == batchSize) {
                this.updateWithMinibatch(miniBatch);
                miniBatch.clear();
            }
        }
        if (!miniBatch.isEmpty()) {
            this.updateWithMinibatch(miniBatch);
        }
    }

    /**
     * Computes the parameter step for one ranking. This sums, over every position {@code k}, the
     * updater-processed backprop gradient of the PL loss derivative with respect to output {@code k}.
     *
     * @param rankingMatrix one dyad vector per row, in ranking order
     * @return a flat vector with one entry per network parameter
     */
    private INDArray computeScaledGradient(INDArray rankingMatrix) {
        int rankingLength = rankingMatrix.rows();
        List<INDArray> activations = this.plNet.feedForward(rankingMatrix);
        INDArray outputs = activations.get(activations.size() - 1).transpose();
        INDArray summedGradient = Nd4j.zeros(new long[] {this.plNet.params().length()});
        // Backprop runs on a clone so that per-dyad forward passes do not disturb this.plNet's state.
        MultiLayerNetwork plNetClone = this.plNet.clone();
        for (int k = 0; k < rankingLength; k++) {
            plNetClone.setInput(rankingMatrix.getRow(k));
            plNetClone.feedForward(true, false);
            INDArray lossGradient = PLNetLoss.computeLossGradient(outputs, k);
            Gradient gradient = plNetClone.backpropGradient(lossGradient, (LayerWorkspaceMgr) null).getFirst();
            // The updater (Adam, per createNetwork) rescales the raw gradient in place.
            this.plNet.getUpdater().update(this.plNet, gradient, this.iteration, this.epoch, 1, LayerWorkspaceMgr.noWorkspaces());
            summedGradient.addi(gradient.gradient());
        }
        return summedGradient;
    }

    /** Same as {@link #computeScaledGradient(INDArray)} for a ranking given as dyads. */
    private INDArray computeScaledGradient(IDyadRankingInstance ranking) {
        return this.computeScaledGradient(this.dyadRankingToMatrix(ranking));
    }

    /**
     * Applies one gradient step, averaged over the given rankings. An empty batch is a no-op:
     * averaging over zero rankings would write NaN into every parameter.
     */
    private void updateWithMinibatch(List<INDArray> miniBatch) {
        if (miniBatch.isEmpty()) {
            return;
        }
        INDArray averagedGradient = Nd4j.zeros(new long[] {this.plNet.params().length()});
        for (INDArray rankingMatrix : miniBatch) {
            averagedGradient.addi(this.computeScaledGradient(rankingMatrix));
        }
        averagedGradient.muli(1.0d / miniBatch.size());
        this.plNet.params().subi(averagedGradient);
        this.iteration++;
    }

    /**
     * Applies a single online gradient step for the given ranking. Creates the network first if
     * necessary.
     *
     * @param ranking the observed ranking of dyads
     * @throws TrainingException declared by the online-learner interface
     */
    @Override
    public void update(IDyadRankingInstance ranking) throws TrainingException {
        this.initializeNetworkIfAbsent(ranking);
        this.plNet.params().subi(this.computeScaledGradient(ranking));
        this.iteration++;
    }

    /**
     * Applies one gradient step averaged over the given rankings. Creates the network first if
     * necessary. An empty set leaves the model unchanged.
     *
     * @param rankings the observed rankings
     * @throws TrainingException declared by the online-learner interface
     */
    @Override
    public void update(Set<IDyadRankingInstance> rankings) throws TrainingException {
        List<INDArray> miniBatch = new ArrayList<>(rankings.size());
        for (IDyadRankingInstance ranking : rankings) {
            this.initializeNetworkIfAbsent(ranking);
            miniBatch.add(ranking.toMatrix());
        }
        this.updateWithMinibatch(miniBatch);
    }

    /**
     * Ranks the dyads of the given instance by descending predicted utility. Creates an untrained
     * network first if none exists.
     *
     * @param instance the dyads to rank
     * @return a new instance containing the same dyads in predicted order
     * @throws PredictionException declared by the predictive-model interface
     */
    @Override
    public IDyadRankingInstance predict(IDyadRankingInstance instance) throws PredictionException {
        List<Pair<Dyad, Double>> rankedPairs = this.getSortedDyadUtilityPairsForInstance(instance);
        List<Dyad> ranking = new ArrayList<>(rankedPairs.size());
        for (Pair<Dyad, Double> pair : rankedPairs) {
            ranking.add(pair.getLeft());
        }
        return new DyadRankingInstance(ranking);
    }

    /**
     * Predicts a ranking for every instance of the dataset.
     *
     * @param dataset the instances to rank
     * @return the predicted rankings, in dataset order
     * @throws PredictionException declared by the predictive-model interface
     */
    @Override
    public List<IDyadRankingInstance> predict(DyadRankingDataset dataset) throws PredictionException {
        List<IDyadRankingInstance> predictions = new ArrayList<>(dataset.size());
        Iterator<IDyadRankingInstance> instances = dataset.iterator();
        while (instances.hasNext()) {
            predictions.add(this.predict(instances.next()));
        }
        return predictions;
    }

    /**
     * Returns the mean {@link PLNetLoss} over the given rankings. If the list is empty, this is NaN,
     * as returned by {@link DescriptiveStatistics#getMean()}.
     */
    private double computeAvgError(List<INDArray> rankings) {
        DescriptiveStatistics lossStatistics = new DescriptiveStatistics();
        for (INDArray rankingMatrix : rankings) {
            INDArray outputs = this.plNet.output(rankingMatrix).transpose();
            lossStatistics.addValue(PLNetLoss.computeLoss(outputs).getDouble(0L));
        }
        return lossStatistics.getMean();
    }

    /**
     * Replaces the configuration. The new configuration affects only networks created afterwards and
     * subsequent training runs.
     *
     * @param config must be an {@link IPLNetDyadRankerConfiguration}
     * @throws IllegalArgumentException if {@code config} has a different type
     */
    @Override
    public void setConfiguration(IPredictiveModelConfiguration config) throws ConfigurationException {
        if (!(config instanceof IPLNetDyadRankerConfiguration)) {
            throw new IllegalArgumentException("The configuration is no PLNetDyadRankerConfiguration!");
        }
        this.configuration = (IPLNetDyadRankerConfiguration) config;
    }

    @Override
    public IPredictiveModelConfiguration getConfiguration() {
        return this.configuration;
    }

    /**
     * Builds (but does not initialize) the network. It has one dense layer per configured hidden size,
     * all using the configured activation, followed by a single linear output unit.
     *
     * @param numInputs length of a dyad vector (instance features + alternative features)
     * @throws IllegalArgumentException if no hidden layer is configured
     */
    private MultiLayerNetwork createNetwork(int numInputs) {
        List<Integer> hiddenNodes = this.configuration.plNetHiddenNodes();
        if (hiddenNodes.isEmpty()) {
            throw new IllegalArgumentException("There must be at least one hidden layer in specified in the config file!");
        }
        NeuralNetConfiguration.ListBuilder layers = new NeuralNetConfiguration.Builder()
                .seed(this.configuration.plNetSeed())
                .updater(new Adam(this.configuration.plNetLearningRate()))
                .list();
        Activation hiddenActivation = Activation.fromString(this.configuration.plNetActivationFunction());
        int layerInputs = numInputs;
        for (int layerIndex = 0; layerIndex < hiddenNodes.size(); layerIndex++) {
            int layerOutputs = hiddenNodes.get(layerIndex);
            layers.layer(layerIndex, new DenseLayer.Builder().nIn(layerInputs).nOut(layerOutputs)
                    .weightInit(WeightInit.SIGMOID_UNIFORM).activation(hiddenActivation).hasBias(true).build());
            layerInputs = layerOutputs;
        }
        layers.layer(hiddenNodes.size(), new DenseLayer.Builder().nIn(layerInputs).nOut(1)
                .weightInit(WeightInit.UNIFORM).activation(Activation.IDENTITY).hasBias(true).build());
        return new MultiLayerNetwork(layers.build());
    }

    /**
     * Creates and initializes the network if none exists. The input width is taken from the first
     * dyad of {@code instance}.
     */
    private void initializeNetworkIfAbsent(IDyadRankingInstance instance) {
        if (this.plNet == null) {
            Dyad firstDyad = instance.getDyadAtPosition(0);
            this.plNet = this.createNetwork(firstDyad.getInstance().length() + firstDyad.getAlternative().length());
            this.plNet.init();
        }
    }

    /** Concatenates a dyad's instance and alternative features into one network input vector. */
    private INDArray dyadToVector(Dyad dyad) {
        return Nd4j.hstack(new INDArray[] {Nd4j.create(dyad.getInstance().asArray()), Nd4j.create(dyad.getAlternative().asArray())});
    }

    /** Stacks the dyad vectors of a ranking into a matrix, one row per dyad, in ranking order. */
    private INDArray dyadRankingToMatrix(IDyadRankingInstance ranking) {
        List<INDArray> rows = new ArrayList<>(ranking.length());
        for (Dyad dyad : ranking) {
            rows.add(this.dyadToVector(dyad));
        }
        return Nd4j.vstack(rows);
    }

    /**
     * Replaces the network with one built from a DL4J {@link MultiLayerConfiguration} JSON file, and
     * initializes it.
     *
     * @param configFile file containing the configuration JSON
     * @throws UncheckedIOException if the file cannot be read
     */
    public void createNetworkFromDl4jConfigFile(File configFile) {
        String json;
        try {
            json = FileUtil.readFileAsString(configFile);
        } catch (IOException e) {
            throw new UncheckedIOException("Could not read DL4J network configuration from " + configFile, e);
        }
        this.plNet = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json));
        // Without init() the network has no parameter array, so training or updating would fail.
        this.plNet.init();
    }

    /**
     * Writes the network, including its updater state, to {@code filePath + ".zip"}.
     *
     * @throws IllegalStateException if no network exists yet
     * @throws IOException if writing fails
     */
    public void saveModelToFile(String filePath) throws IOException {
        if (this.plNet == null) {
            throw new IllegalStateException("Cannot save untrained model.");
        }
        ModelSerializer.writeModel(this.plNet, new File(filePath + ".zip"), true);
    }

    /**
     * Replaces the network with one restored from the given model file.
     *
     * @throws IOException if reading fails
     */
    public void loadModelFromFile(String filePath) throws IOException {
        this.plNet = ModelSerializer.restoreMultiLayerNetwork(filePath);
    }

    /** @return the underlying network, or {@code null} if none has been created yet */
    public MultiLayerNetwork getPlNet() {
        return this.plNet;
    }

    /** @return the number of epochs performed by the most recent training run */
    public int getEpoch() {
        return this.epoch;
    }

    /**
     * Returns the absolute difference between the predicted utilities of the two dyads. Creates an
     * untrained network first if none exists.
     *
     * @param pair an instance containing exactly two dyads
     * @throws IllegalArgumentException if {@code pair} does not contain exactly two dyads
     */
    @Override
    public double getCertainty(IDyadRankingInstance pair) {
        if (pair.length() != 2) {
            throw new IllegalArgumentException("Can only provide certainty for pairs of dyads!");
        }
        List<Pair<Dyad, Double>> utilities = this.getDyadUtilityPairsForInstance(pair);
        return Math.abs(utilities.get(0).getRight() - utilities.get(1).getRight());
    }

    /**
     * Finds the two dyads that are adjacent in the predicted ranking and whose utilities differ the
     * least. Ties go to the higher-ranked pair.
     *
     * @param instance the candidate dyads
     * @return a new instance containing that pair, higher-ranked dyad first
     * @throws IllegalArgumentException if {@code instance} contains fewer than two dyads
     */
    public IDyadRankingInstance getPairWithLeastCertainty(IDyadRankingInstance instance) {
        if (instance.length() < 2) {
            throw new IllegalArgumentException("The query instance must contain at least 2 dyads!");
        }
        List<Pair<Dyad, Double>> rankedPairs = this.getSortedDyadUtilityPairsForInstance(instance);
        int leastCertainIndex = 0;
        double smallestGap = Double.MAX_VALUE;
        for (int i = 0; i < rankedPairs.size() - 1; i++) {
            double gap = Math.abs(rankedPairs.get(i).getRight() - rankedPairs.get(i + 1).getRight());
            if (gap < smallestGap) {
                smallestGap = gap;
                leastCertainIndex = i;
            }
        }
        List<Dyad> leastCertainPair = new ArrayList<>(2);
        leastCertainPair.add(rankedPairs.get(leastCertainIndex).getLeft());
        leastCertainPair.add(rankedPairs.get(leastCertainIndex + 1).getLeft());
        return new DyadRankingInstance(leastCertainPair);
    }

    /** Plackett-Luce probability of the predicted (utility-sorted) ranking of all dyads. */
    public double getProbabilityOfTopRanking(IDyadRankingInstance instance) {
        return this.getProbabilityOfTopKRanking(instance, instance.length());
    }

    /**
     * Pairs each dyad with its predicted utility, in the instance's own order. Creates the network
     * first if necessary.
     */
    private List<Pair<Dyad, Double>> getDyadUtilityPairsForInstance(IDyadRankingInstance instance) {
        this.initializeNetworkIfAbsent(instance);
        List<Pair<Dyad, Double>> utilities = new ArrayList<>(instance.length());
        for (Dyad dyad : instance) {
            utilities.add(new Pair<>(dyad, this.plNet.output(this.dyadToVector(dyad)).getDouble(0L)));
        }
        return utilities;
    }

    /** Like {@link #getDyadUtilityPairsForInstance} but sorted by descending utility. */
    private List<Pair<Dyad, Double>> getSortedDyadUtilityPairsForInstance(IDyadRankingInstance instance) {
        List<Pair<Dyad, Double>> utilities = this.getDyadUtilityPairsForInstance(instance);
        utilities.sort(Comparator.comparingDouble(pair -> -pair.getRight()));
        return utilities;
    }

    /**
     * Plackett-Luce probability of the order of the top {@code k} predicted dyads.
     *
     * <p>NOTE(review): the normalizing sums run only over the first {@code k} dyads. That is the
     * probability of ordering those k among themselves, not the PL top-k marginal over all dyads.
     */
    public double getProbabilityOfTopKRanking(IDyadRankingInstance instance, int k) {
        return plackettLuceProbability(this.getSortedDyadUtilityPairsForInstance(instance), k);
    }

    /** Log of {@link #getProbabilityOfTopRanking}. */
    public double getLogProbabilityOfTopRanking(IDyadRankingInstance instance) {
        return this.getLogProbabilityOfTopKRanking(instance, Integer.MAX_VALUE);
    }

    /** Log of {@link #getProbabilityOfTopKRanking}, summed in log space. */
    public double getLogProbabilityOfTopKRanking(IDyadRankingInstance instance, int k) {
        return plackettLuceLogProbability(this.getSortedDyadUtilityPairsForInstance(instance), k);
    }

    /** Plackett-Luce probability of the ranking given by the instance's own dyad order. */
    public double getProbabilityRanking(IDyadRankingInstance ranking) {
        return plackettLuceProbability(this.getDyadUtilityPairsForInstance(ranking), Integer.MAX_VALUE);
    }

    /** Log of {@link #getProbabilityRanking}, summed in log space. */
    public double getLogProbabilityRanking(IDyadRankingInstance ranking) {
        return plackettLuceLogProbability(this.getDyadUtilityPairsForInstance(ranking), Integer.MAX_VALUE);
    }

    /**
     * @return the predicted utility of the dyad, or NaN if no network exists yet
     */
    public double getSkillForDyad(Dyad dyad) {
        if (this.plNet == null) {
            return Double.NaN;
        }
        return this.plNet.output(this.dyadToVector(dyad)).getDouble(0L);
    }

    /**
     * Computes prod_i exp(u_i) / sum_{j=i}^{m-1} exp(u_j) over the first m = min(k, size) utilities,
     * in list order. Yields NaN if a denominator is exactly zero.
     */
    private static double plackettLuceProbability(List<Pair<Dyad, Double>> utilities, int k) {
        int limit = Math.min(k, utilities.size());
        double probability = 1.0d;
        for (int i = 0; i < limit; i++) {
            double denominator = 0.0d;
            for (int j = i; j < limit; j++) {
                denominator += Math.exp(utilities.get(j).getRight());
            }
            probability = denominator != 0.0d ? probability * (Math.exp(utilities.get(i).getRight()) / denominator) : Double.NaN;
        }
        return probability;
    }

    /**
     * Log-space counterpart of {@link #plackettLuceProbability}:
     * sum_i [u_i - log(sum_{j=i}^{m-1} exp(u_j))] with m = min(k, size).
     */
    private static double plackettLuceLogProbability(List<Pair<Dyad, Double>> utilities, int k) {
        int limit = Math.min(k, utilities.size());
        double logProbability = 0.0d;
        for (int i = 0; i < limit; i++) {
            double denominator = 0.0d;
            for (int j = i; j < limit; j++) {
                denominator += Math.exp(utilities.get(j).getRight());
            }
            logProbability += utilities.get(i).getRight() - Math.log(denominator);
        }
        return logProbability;
    }
}
