/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.train;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.schema.GraphSchema;
import org.neo4j.gds.config.ToMapConvertible;
import org.neo4j.gds.core.collections.ReadOnlyHugeLongIdentityArray;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.ml.Training;
import org.neo4j.gds.ml.TrainingConfig;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.batch.HugeBatchQueue;
import org.neo4j.gds.ml.linkmodels.SignedProbabilities;
import org.neo4j.gds.ml.linkmodels.metrics.LinkMetric;
import org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionModelInfo;
import org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionPipeline;
import org.neo4j.gds.ml.linkmodels.pipeline.linkFeatures.LinkFeatureExtractor;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionData;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionPredictor;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionTrain;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionTrainConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.train.FeaturesAndTargets;
import org.neo4j.gds.ml.linkmodels.pipeline.train.ImmutableFeaturesAndTargets;
import org.neo4j.gds.ml.linkmodels.pipeline.train.ImmutableModelSelectResult;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.train.ReadOnlyHugeDoubleToLongArrayWrapper;
import org.neo4j.gds.ml.nodemodels.ImmutableModelStats;
import org.neo4j.gds.ml.nodemodels.MetricData;
import org.neo4j.gds.ml.nodemodels.ModelStats;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
import org.neo4j.gds.model.ModelConfig;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Trains a link prediction pipeline model backed by link logistic regression.
 *
 * <p>The training run implemented by {@link #compute()}:
 * <ol>
 *   <li>extracts link features and 0/1 targets for every relationship of the train graph;</li>
 *   <li>selects the best parameter set from the pipeline's parameter space using stratified k-fold
 *       cross-validation, ranking candidates by the average validation score of the first
 *       configured metric;</li>
 *   <li>retrains with the winning parameters on all training relationships;</li>
 *   <li>computes metrics on the training data and on {@code validationGraph}, whose results are
 *       reported as "test" metrics;</li>
 *   <li>wraps everything in a {@link Model} of type {@value #MODEL_TYPE}.</li>
 * </ol>
 */
public class LinkPredictionTrain
extends Algorithm<LinkPredictionTrain, Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo>> {
    public static final String MODEL_TYPE = "Link prediction pipeline";
    private final Graph trainGraph;
    // Used for the final "evaluate on test data" phase, not for cross-validation.
    private final Graph validationGraph;
    private final LinkPredictionPipeline pipeline;
    private final LinkPredictionTrainConfig trainConfig;
    private final AllocationTracker allocationTracker;

    /**
     * @param trainGraph      graph whose relationships (with 0/1 target property) are used for training
     * @param validationGraph graph whose relationships are used for the final test evaluation
     * @param pipeline        feature steps, split configuration and parameter space to search
     * @param trainConfig     metrics, concurrency, random seed and model naming for this run
     * @param progressTracker tracker driven through the phases listed in {@link #progressTask()}
     */
    public LinkPredictionTrain(Graph trainGraph, Graph validationGraph, LinkPredictionPipeline pipeline, LinkPredictionTrainConfig trainConfig, ProgressTracker progressTracker) {
        super(progressTracker);
        this.trainGraph = trainGraph;
        this.validationGraph = validationGraph;
        this.pipeline = pipeline;
        this.trainConfig = trainConfig;
        this.allocationTracker = AllocationTracker.empty();
    }

    /**
     * Builds the task tree mirroring the sub-task names used in {@link #compute()}.
     * The names must stay in sync with the begin/end calls there.
     */
    static Task progressTask() {
        return Tasks.task(
            LinkPredictionTrain.class.getSimpleName(),
            Tasks.leaf("extract train features"),
            Tasks.leaf("select model"),
            Training.progressTask("train best model"),
            Tasks.leaf("compute train metrics"),
            Tasks.task(
                "evaluate on test data",
                Tasks.leaf("extract test features"),
                Tasks.leaf("compute test metrics")
            )
        );
    }

    /**
     * Runs feature extraction, model selection, final training and evaluation.
     *
     * @return the trained model, carrying the best parameters and the merged train/validation/test metrics
     */
    @Override
    public Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo> compute() {
        this.progressTracker.beginSubTask();

        this.progressTracker.beginSubTask("extract train features");
        FeaturesAndTargets trainData = this.extractFeaturesAndTargets(this.trainGraph);
        // Every extracted training relationship index 0..size-1 is part of the outer training set.
        ReadOnlyHugeLongArray trainRelationshipIds = new ReadOnlyHugeLongIdentityArray(trainData.size());
        this.progressTracker.endSubTask("extract train features");

        this.progressTracker.beginSubTask("select model");
        ModelSelectResult modelSelectResult = this.modelSelect(trainData, trainRelationshipIds);
        LinkLogisticRegressionTrainConfig bestParameters = modelSelectResult.bestParameters();
        this.progressTracker.endSubTask("select model");

        this.progressTracker.beginSubTask("train best model");
        LinkLogisticRegressionData modelData = this.trainModel(trainRelationshipIds, trainData, bestParameters, this.progressTracker);
        this.progressTracker.endSubTask("train best model");

        this.progressTracker.beginSubTask("compute train metrics");
        Map<LinkMetric, Double> outerTrainMetrics = this.computeTrainMetric(trainData, modelData, trainRelationshipIds, this.progressTracker);
        this.progressTracker.endSubTask("compute train metrics");

        this.progressTracker.beginSubTask("evaluate on test data");
        Map<LinkMetric, Double> testMetrics = this.computeTestMetric(modelData);
        this.progressTracker.endSubTask("evaluate on test data");

        Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo> model = this.createModel(
            modelSelectResult,
            modelData,
            this.mergeMetrics(modelSelectResult, outerTrainMetrics, testMetrics)
        );
        this.progressTracker.endSubTask();
        return model;
    }

    /**
     * Extracts link features via the pipeline's feature steps, plus one target per relationship.
     * Progress volume covers both passes: features and targets, one unit per relationship each.
     */
    private FeaturesAndTargets extractFeaturesAndTargets(Graph graph) {
        this.progressTracker.setVolume(graph.relationshipCount() * 2L);
        HugeObjectArray<double[]> features = LinkFeatureExtractor.extractFeatures(
            graph,
            this.pipeline.featureSteps(),
            this.trainConfig.concurrency(),
            this.progressTracker
        );
        HugeDoubleArray targets = this.extractTargets(graph, features.size());
        return ImmutableFeaturesAndTargets.of(features, targets);
    }

    /**
     * Reads the relationship property of every relationship as its target, in
     * {@code forEachNode}/{@code forEachRelationship} visitation order.
     * NOTE(review): assumes this order matches the feature order produced by LinkFeatureExtractor; confirm.
     *
     * @throws IllegalArgumentException if a relationship's value is neither {@code 0} nor {@code 1}
     */
    private HugeDoubleArray extractTargets(Graph graph, long numberOfTargets) {
        HugeDoubleArray globalTargets = HugeDoubleArray.newArray(numberOfTargets, this.allocationTracker);
        MutableLong relationshipIdx = new MutableLong();
        graph.forEachNode(nodeId -> {
            // -10.0 is passed as the property fallback value. Presumably it is used when a relationship
            // has no property. It is neither 0 nor 1, so such relationships are rejected below.
            graph.forEachRelationship(nodeId, -10.0, (src, trg, weight) -> {
                if (weight != 0.0 && weight != 1.0) {
                    throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                        "Target should be either `1` or `0`. But got %f for relationship (%d, %d)",
                        weight,
                        src,
                        trg
                    ));
                }
                globalTargets.set(relationshipIdx.getAndIncrement(), weight);
                return true;
            });
            this.progressTracker.logProgress(graph.degree(nodeId));
            return true;
        });
        return globalTargets;
    }

    /**
     * Cross-validates every parameter set of the pipeline's parameter space. For each parameter set,
     * a model is trained per fold and scored on that fold's train and validation parts.
     * The winner has the highest average validation score for the first configured metric.
     */
    private ModelSelectResult modelSelect(FeaturesAndTargets trainData, ReadOnlyHugeLongArray trainRelationshipIds) {
        List<TrainingExamplesSplit> validationSplits = this.trainValidationSplits(trainRelationshipIds, trainData.targets());
        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> trainStats = this.initStatsMap();
        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> validationStats = this.initStatsMap();
        List<LinkLogisticRegressionTrainConfig> parameterSpace = this.pipeline.trainingParameterSpace();
        int validationFolds = this.pipeline.splitConfig().validationFolds();

        this.progressTracker.setVolume(parameterSpace.size());
        for (LinkLogisticRegressionTrainConfig modelParams : parameterSpace) {
            ModelStatsBuilder trainStatsBuilder = new ModelStatsBuilder(modelParams, validationFolds);
            ModelStatsBuilder validationStatsBuilder = new ModelStatsBuilder(modelParams, validationFolds);
            for (TrainingExamplesSplit relSplit : validationSplits) {
                ReadOnlyHugeLongArray trainSet = ReadOnlyHugeLongArray.of(relSplit.trainSet());
                ReadOnlyHugeLongArray validationSet = ReadOnlyHugeLongArray.of(relSplit.testSet());
                // Inner-fold training is not reported to the user-facing progress tracker.
                LinkLogisticRegressionData modelData = this.trainModel(trainSet, trainData, modelParams, ProgressTracker.NULL_TRACKER);
                this.computeTrainMetric(trainData, modelData, trainSet, ProgressTracker.NULL_TRACKER).forEach(trainStatsBuilder::update);
                this.computeTrainMetric(trainData, modelData, validationSet, ProgressTracker.NULL_TRACKER).forEach(validationStatsBuilder::update);
            }
            for (LinkMetric metric : this.trainConfig.metrics()) {
                validationStats.get(metric).add(validationStatsBuilder.modelStats(metric));
                trainStats.get(metric).add(trainStatsBuilder.modelStats(metric));
            }
            this.progressTracker.logProgress();
        }

        LinkMetric mainMetric = this.trainConfig.metrics().get(0);
        ModelStats<LinkLogisticRegressionTrainConfig> winner = Collections.max(validationStats.get(mainMetric), ModelStats.COMPARE_AVERAGE);
        return ModelSelectResult.of(winner.params(), trainStats, validationStats);
    }

    /** Extracts features/targets from {@code validationGraph} and scores {@code modelData} on all of them. */
    private Map<LinkMetric, Double> computeTestMetric(LinkLogisticRegressionData modelData) {
        this.progressTracker.beginSubTask("extract test features");
        FeaturesAndTargets testData = this.extractFeaturesAndTargets(this.validationGraph);
        this.progressTracker.endSubTask("extract test features");

        this.progressTracker.beginSubTask("compute test metrics");
        Map<LinkMetric, Double> result = this.computeMetric(testData, modelData, new BatchQueue(testData.size()), this.progressTracker);
        this.progressTracker.endSubTask("compute test metrics");
        return result;
    }

    /**
     * Combines, per metric, the cross-validation train/validation stats with the outer train score
     * and the test score.
     */
    private Map<LinkMetric, MetricData<LinkLogisticRegressionTrainConfig>> mergeMetrics(ModelSelectResult modelSelectResult, Map<LinkMetric, Double> outerTrainMetrics, Map<LinkMetric, Double> testMetrics) {
        return modelSelectResult.validationStats().keySet().stream().collect(Collectors.toMap(
            Function.identity(),
            metric -> MetricData.of(
                modelSelectResult.trainStats().get(metric),
                modelSelectResult.validationStats().get(metric),
                outerTrainMetrics.get(metric),
                testMetrics.get(metric)
            )
        ));
    }

    /** Splits the training relationships into stratified folds, stratified by their 0/1 targets. */
    private List<TrainingExamplesSplit> trainValidationSplits(ReadOnlyHugeLongArray trainRelationshipIds, HugeDoubleArray actualTargets) {
        StratifiedKFoldSplitter splitter = new StratifiedKFoldSplitter(
            this.pipeline.splitConfig().validationFolds(),
            trainRelationshipIds,
            new ReadOnlyHugeDoubleToLongArrayWrapper(actualTargets),
            this.trainConfig.randomSeed()
        );
        return splitter.splits();
    }

    /**
     * Creates an empty per-metric stats list for every configured metric.
     * The key set must cover every metric in {@code trainConfig.metrics()}: {@link #modelSelect} appends
     * to these lists for each configured metric, and {@link #mergeMetrics} iterates over this key set.
     */
    private Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> initStatsMap() {
        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> statsMap = new HashMap<>();
        for (LinkMetric metric : this.trainConfig.metrics()) {
            statsMap.putIfAbsent(metric, new ArrayList<>());
        }
        return statsMap;
    }

    /** Trains link logistic regression on the relationship indices in {@code trainSet}. */
    private LinkLogisticRegressionData trainModel(ReadOnlyHugeLongArray trainSet, FeaturesAndTargets trainData, LinkLogisticRegressionTrainConfig llrConfig, ProgressTracker progressTracker) {
        LinkLogisticRegressionTrain llrTrain = new LinkLogisticRegressionTrain(
            trainSet,
            trainData.features(),
            trainData.targets(),
            llrConfig,
            progressTracker,
            this.terminationFlag,
            this.trainConfig.concurrency()
        );
        return llrTrain.compute();
    }

    /** Scores {@code modelData} on the relationship indices listed in {@code evaluationSet}. */
    private Map<LinkMetric, Double> computeTrainMetric(FeaturesAndTargets trainData, LinkLogisticRegressionData modelData, ReadOnlyHugeLongArray evaluationSet, ProgressTracker progressTracker) {
        return this.computeMetric(trainData, modelData, new HugeBatchQueue(evaluationSet), progressTracker);
    }

    /**
     * Predicts a probability for every relationship index produced by {@code evaluationQueue}.
     * Each probability is stored as positive for target {@code 1}, negated otherwise. Every configured
     * metric is then computed on the collected signed probabilities.
     */
    private Map<LinkMetric, Double> computeMetric(FeaturesAndTargets inputData, LinkLogisticRegressionData modelData, BatchQueue evaluationQueue, ProgressTracker progressTracker) {
        progressTracker.setVolume(inputData.size());
        LinkLogisticRegressionPredictor predictor = new LinkLogisticRegressionPredictor(modelData);
        SignedProbabilities signedProbabilities = SignedProbabilities.create(inputData.size());
        HugeDoubleArray targets = inputData.targets();
        HugeObjectArray<double[]> features = inputData.features();
        evaluationQueue.parallelConsume(this.trainConfig.concurrency(), thread -> batch -> {
            for (Long relationshipIdx : batch.nodeIds()) {
                long idx = relationshipIdx;
                double predictedProbability = predictor.predictedProbability(features.get(idx));
                boolean isEdge = targets.get(idx) == 1.0;
                // The sign encodes the true class, so the metric can tell positives from negatives.
                double signedProbability = isEdge ? predictedProbability : -1.0 * predictedProbability;
                // NOTE(review): invoked from parallel consumers; assumes SignedProbabilities.add is thread-safe; confirm.
                signedProbabilities.add(signedProbability);
            }
            progressTracker.logProgress(batch.size());
        }, this.terminationFlag);
        return this.trainConfig.metrics().stream().collect(Collectors.toMap(
            Function.identity(),
            metric -> metric.compute(signedProbabilities, this.trainConfig.negativeClassWeight())
        ));
    }

    /** Wraps the trained data, train config, graph schema and model info into a {@link Model}. */
    private Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo> createModel(ModelSelectResult modelSelectResult, LinkLogisticRegressionData modelData, Map<LinkMetric, MetricData<LinkLogisticRegressionTrainConfig>> metrics) {
        return Model.of(
            this.trainConfig.username(),
            this.trainConfig.modelName(),
            MODEL_TYPE,
            this.trainGraph.schema(),
            modelData,
            this.trainConfig,
            LinkPredictionModelInfo.of(modelSelectResult.bestParameters(), metrics, this.pipeline.copy())
        );
    }

    @Override
    public LinkPredictionTrain me() {
        return this;
    }

    @Override
    public void release() {
        // Nothing to release.
    }

    /**
     * Accumulates min, max and sum per metric over the cross-validation folds of one parameter set.
     */
    static class ModelStatsBuilder {
        private final Map<LinkMetric, Double> min;
        private final Map<LinkMetric, Double> max;
        private final Map<LinkMetric, Double> sum;
        private final LinkLogisticRegressionTrainConfig modelParams;
        private final int numberOfSplits;

        /**
         * @param modelParams    the parameter set being evaluated
         * @param numberOfSplits divisor used for the average; expected to equal the number of folds
         */
        ModelStatsBuilder(LinkLogisticRegressionTrainConfig modelParams, int numberOfSplits) {
            this.modelParams = modelParams;
            this.numberOfSplits = numberOfSplits;
            this.min = new HashMap<>();
            this.max = new HashMap<>();
            this.sum = new HashMap<>();
        }

        /** Folds one fold's score for {@code metric} into the running min/max/sum. */
        void update(LinkMetric metric, double value) {
            this.min.merge(metric, value, Math::min);
            this.max.merge(metric, value, Math::max);
            this.sum.merge(metric, value, Double::sum);
        }

        /**
         * Builds the statistics for {@code metric}.
         * Requires at least one prior {@link #update} call for that metric.
         */
        ModelStats<LinkLogisticRegressionTrainConfig> modelStats(LinkMetric metric) {
            double average = this.sum.get(metric) / (double) this.numberOfSplits;
            return ImmutableModelStats.of(this.modelParams, average, this.min.get(metric), this.max.get(metric));
        }
    }

    /** Outcome of model selection: the winning parameters and per-metric cross-validation statistics. */
    @ValueClass
    public interface ModelSelectResult {
        LinkLogisticRegressionTrainConfig bestParameters();

        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> trainStats();

        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> validationStats();

        static ModelSelectResult of(LinkLogisticRegressionTrainConfig bestConfig, Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> trainStats, Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> validationStats) {
            return ImmutableModelSelectResult.of(bestConfig, trainStats, validationStats);
        }
    }
}

