/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.ElementIdentifier;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.schema.GraphSchema;
import org.neo4j.gds.config.ModelConfig;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.TrainingConfig;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.batch.HugeBatchQueue;
import org.neo4j.gds.ml.linkmodels.SignedProbabilities;
import org.neo4j.gds.ml.linkmodels.metrics.LinkMetric;
import org.neo4j.gds.ml.linkmodels.pipeline.FeaturesAndTargets;
import org.neo4j.gds.ml.linkmodels.pipeline.ImmutableFeaturesAndTargets;
import org.neo4j.gds.ml.linkmodels.pipeline.ImmutableModelSelectResult;
import org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionModelInfo;
import org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionTrainConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.PipelineExecutor;
import org.neo4j.gds.ml.linkmodels.pipeline.TrainingPipeline;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionData;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionPredictor;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionTrain;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionTrainConfig;
import org.neo4j.gds.ml.nodemodels.ImmutableModelStats;
import org.neo4j.gds.ml.nodemodels.MetricData;
import org.neo4j.gds.ml.nodemodels.ModelStats;
import org.neo4j.gds.ml.splitting.NodeSplit;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Trains a link prediction pipeline end-to-end and wraps the result in a {@link Model}.
 *
 * <p>{@link #compute()} performs, in order:
 * <ol>
 *   <li>splitting the configured relationship types into the train, test and feature-input
 *       relationship types named by the pipeline's split config;</li>
 *   <li>executing the pipeline's node property steps on the feature-input relationship type;</li>
 *   <li>extracting link features and 0/1 targets for the train relationships;</li>
 *   <li>selecting the best logistic regression parameters with stratified k-fold cross validation;</li>
 *   <li>retraining the winning parameters on the full train set and computing train and test metrics;</li>
 *   <li>deleting the temporary relationship types and node properties from the graph store.</li>
 * </ol>
 */
public class LinkPredictionTrain
    extends Algorithm<LinkPredictionTrain, Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo>> {

    /** Model type stored on the produced {@link Model}. */
    public static final String MODEL_TYPE = "Link prediction pipeline";

    private final GraphStore graphStore;
    private final LinkPredictionTrainConfig trainConfig;
    private final PipelineExecutor pipelineExecutor;
    private final TrainingPipeline pipeline;
    private final AllocationTracker allocationTracker;

    public LinkPredictionTrain(
        GraphStore graphStore,
        LinkPredictionTrainConfig trainConfig,
        TrainingPipeline pipeline,
        PipelineExecutor pipelineExecutor,
        ProgressTracker progressTracker
    ) {
        this.graphStore = graphStore;
        this.trainConfig = trainConfig;
        this.pipelineExecutor = pipelineExecutor;
        this.pipeline = pipeline;
        // progressTracker is not declared in this class; it is the field inherited from Algorithm.
        this.progressTracker = progressTracker;
        this.allocationTracker = AllocationTracker.empty();
    }

    /**
     * Runs the full training procedure described on the class.
     *
     * @return the trained model, carrying the best parameters, the merged metrics and a copy of the pipeline
     */
    @Override
    public Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo> compute() {
        progressTracker.beginSubTask();

        List<String> relationshipTypes = trainConfig.internalRelationshipTypes(graphStore)
            .stream()
            .map(ElementIdentifier::name)
            .collect(Collectors.toList());
        pipelineExecutor.splitRelationships(graphStore, relationshipTypes, trainConfig.nodeLabels(), trainConfig.randomSeed());
        assertRunning();

        pipelineExecutor.executeNodePropertySteps(
            trainConfig.nodeLabelIdentifiers(graphStore),
            RelationshipType.of(pipeline.splitConfig().featureInputRelationshipType())
        );
        assertRunning();

        FeaturesAndTargets trainData = extractFeaturesAndTargets(pipeline.splitConfig().trainRelationshipType());
        // Train relationships are addressed by their position in trainData: ids 0 .. size - 1.
        HugeLongArray trainRelationshipIds = HugeLongArray.newArray(trainData.size(), allocationTracker);
        trainRelationshipIds.setAll(i -> i);

        ModelSelectResult modelSelectResult = modelSelect(trainData, trainRelationshipIds);
        LinkLogisticRegressionTrainConfig bestParameters = modelSelectResult.bestParameters();

        // Retrain the winning parameters on the complete train set.
        LinkLogisticRegressionData modelData = trainModel(trainRelationshipIds, trainData, bestParameters, progressTracker);
        Map<LinkMetric, Double> outerTrainMetrics = computeTrainMetric(trainData, modelData, trainRelationshipIds, progressTracker);
        Map<LinkMetric, Double> testMetrics = computeTestMetric(modelData);

        // Clean-up runs before createModel, which captures graphStore.schema().
        cleanUpGraphStore();

        Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo> model =
            createModel(modelSelectResult, modelData, mergeMetrics(modelSelectResult, outerTrainMetrics, testMetrics));
        progressTracker.endSubTask();
        return model;
    }

    /**
     * Computes link features for every relationship of the given type, together with its 0/1 target.
     *
     * @param relationshipType name of the relationship type to extract from the graph store
     * @return features and targets; the targets array has one entry per feature row
     */
    FeaturesAndTargets extractFeaturesAndTargets(String relationshipType) {
        HugeObjectArray<double[]> features = pipelineExecutor.computeFeatures(
            trainConfig.nodeLabelIdentifiers(graphStore),
            RelationshipType.of(relationshipType),
            trainConfig.concurrency()
        );
        HugeDoubleArray targets = extractTargets(features.size(), relationshipType);
        return ImmutableFeaturesAndTargets.of(features, targets);
    }

    /**
     * Reads the {@code "label"} relationship property of every relationship of the given type.
     *
     * <p>Relationships are visited node by node. Their values are written to consecutive indices
     * of the returned array.
     *
     * @param numberOfTargets  size of the returned array
     * @param relationshipType relationship type whose {@code "label"} property is read
     * @return the target of each visited relationship, in visiting order
     * @throws IllegalArgumentException if a relationship's value is neither {@code 0} nor {@code 1}
     */
    public HugeDoubleArray extractTargets(long numberOfTargets, String relationshipType) {
        HugeDoubleArray globalTargets = HugeDoubleArray.newArray(numberOfTargets, allocationTracker);
        Graph trainGraph = graphStore.getGraph(RelationshipType.of(relationshipType), Optional.of("label"));
        MutableLong relationshipIdx = new MutableLong();
        trainGraph.forEachNode(nodeId -> {
            // -10.0 is passed as the fallback property value (presumably used when "label" is absent);
            // it is rejected by the 0/1 check below.
            trainGraph.forEachRelationship(nodeId, -10.0, (src, trg, weight) -> {
                if (weight != 0.0 && weight != 1.0) {
                    // weight is a double: it must be formatted with %f. %d would make String.format
                    // throw IllegalFormatConversionException instead of this exception.
                    throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                        "Target should be either `1` or `0`. But got %f for relationship (%d, %d)",
                        weight,
                        src,
                        trg
                    ));
                }
                globalTargets.set(relationshipIdx.getAndIncrement(), weight);
                return true;
            });
            return true;
        });
        return globalTargets;
    }

    /**
     * Cross-validates every parameter candidate of the pipeline and picks a winner.
     *
     * <p>For each candidate and each validation fold, a model is trained on the fold's train part.
     * Metrics are computed on both the fold's train part and its validation part. The winner is the
     * candidate whose validation stats for the first configured metric are maximal under
     * {@code ModelStats.COMPARE_AVERAGE}.
     */
    private ModelSelectResult modelSelect(FeaturesAndTargets trainData, HugeLongArray trainRelationshipIds) {
        progressTracker.beginSubTask();
        List<NodeSplit> validationSplits = trainValidationSplits(trainRelationshipIds, trainData.targets());

        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> trainStats = initStatsMap();
        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> validationStats = initStatsMap();
        int folds = pipeline.splitConfig().validationFolds();

        pipeline.parameterConfigs(trainConfig.concurrency()).forEach(candidate -> {
            LinkLogisticRegressionTrainConfig modelParams = (LinkLogisticRegressionTrainConfig) candidate;
            ModelStatsBuilder trainStatsBuilder = new ModelStatsBuilder(modelParams, folds);
            ModelStatsBuilder validationStatsBuilder = new ModelStatsBuilder(modelParams, folds);

            for (NodeSplit split : validationSplits) {
                HugeLongArray trainSet = split.trainSet();
                HugeLongArray validationSet = split.testSet();
                // Inner (per-fold) training is not reported to the progress tracker.
                LinkLogisticRegressionData modelData = trainModel(trainSet, trainData, modelParams, ProgressTracker.NULL_TRACKER);
                computeTrainMetric(trainData, modelData, trainSet, ProgressTracker.NULL_TRACKER)
                    .forEach(trainStatsBuilder::update);
                computeTrainMetric(trainData, modelData, validationSet, ProgressTracker.NULL_TRACKER)
                    .forEach(validationStatsBuilder::update);
            }

            for (LinkMetric metric : trainConfig.metrics()) {
                validationStats.get(metric).add(validationStatsBuilder.modelStats(metric));
                trainStats.get(metric).add(trainStatsBuilder.modelStats(metric));
            }
            progressTracker.logProgress();
        });

        LinkMetric mainMetric = trainConfig.metrics().get(0);
        ModelStats<LinkLogisticRegressionTrainConfig> winner =
            Collections.max(validationStats.get(mainMetric), ModelStats.COMPARE_AVERAGE);
        LinkLogisticRegressionTrainConfig bestConfig = (LinkLogisticRegressionTrainConfig) winner.params();

        progressTracker.endSubTask();
        return ModelSelectResult.of(bestConfig, trainStats, validationStats);
    }

    /** Extracts the test relationships and evaluates the given model on all of them. */
    private Map<LinkMetric, Double> computeTestMetric(LinkLogisticRegressionData modelData) {
        progressTracker.beginSubTask();
        FeaturesAndTargets testData = extractFeaturesAndTargets(pipeline.splitConfig().testRelationshipType());
        Map<LinkMetric, Double> result = computeMetric(testData, modelData, new BatchQueue(testData.size()), progressTracker);
        progressTracker.endSubTask();
        return result;
    }

    /**
     * Combines, per metric, the cross-validation train/validation stats with the outer train
     * and test scores.
     */
    private Map<LinkMetric, MetricData<LinkLogisticRegressionTrainConfig>> mergeMetrics(
        ModelSelectResult modelSelectResult,
        Map<LinkMetric, Double> outerTrainMetrics,
        Map<LinkMetric, Double> testMetrics
    ) {
        return modelSelectResult.validationStats().keySet().stream().collect(Collectors.toMap(
            Function.identity(),
            metric -> MetricData.of(
                modelSelectResult.trainStats().get(metric),
                modelSelectResult.validationStats().get(metric),
                outerTrainMetrics.get(metric),
                testMetrics.get(metric)
            )
        ));
    }

    /**
     * Builds stratified k-fold splits over the train relationship ids.
     *
     * <p>The splits are stratified by the 0/1 targets, converted to long class ids.
     */
    private List<NodeSplit> trainValidationSplits(HugeLongArray trainRelationshipIds, HugeDoubleArray actualTargets) {
        HugeLongArray globalTargets = HugeLongArray.newArray(trainRelationshipIds.size(), allocationTracker);
        globalTargets.setAll(i -> (long) actualTargets.get(i));
        StratifiedKFoldSplitter splitter = new StratifiedKFoldSplitter(
            pipeline.splitConfig().validationFolds(),
            trainRelationshipIds,
            globalTargets,
            trainConfig.randomSeed()
        );
        return splitter.splits();
    }

    /**
     * Creates an empty, mutable stats list for every configured metric.
     *
     * <p>The map is keyed by {@code trainConfig.metrics()}, the same keys {@link #computeMetric} produces.
     * This lets {@link #modelSelect} and {@link #mergeMetrics} look up every configured metric.
     */
    private Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> initStatsMap() {
        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> statsMap = new HashMap<>();
        for (LinkMetric metric : trainConfig.metrics()) {
            statsMap.put(metric, new ArrayList<>());
        }
        return statsMap;
    }

    /** Trains a logistic regression model on the relationships with the given ids. */
    private LinkLogisticRegressionData trainModel(
        HugeLongArray trainSet,
        FeaturesAndTargets trainData,
        LinkLogisticRegressionTrainConfig llrConfig,
        ProgressTracker progressTracker
    ) {
        progressTracker.beginSubTask(llrConfig.maxEpochs());
        LinkLogisticRegressionTrain llrTrain = new LinkLogisticRegressionTrain(
            trainSet,
            trainData.features(),
            trainData.targets(),
            llrConfig,
            progressTracker,
            terminationFlag
        );
        LinkLogisticRegressionData modelData = llrTrain.compute();
        progressTracker.endSubTask();
        return modelData;
    }

    /** Evaluates the model on the given subset of relationship ids of {@code trainData}. */
    private Map<LinkMetric, Double> computeTrainMetric(
        FeaturesAndTargets trainData,
        LinkLogisticRegressionData modelData,
        HugeLongArray evaluationSet,
        ProgressTracker progressTracker
    ) {
        return computeMetric(trainData, modelData, new HugeBatchQueue(evaluationSet), progressTracker);
    }

    /**
     * Predicts every relationship id yielded by {@code evaluationQueue} and scores the predictions.
     *
     * <p>Each prediction is stored as a signed probability. It is positive when the target is
     * {@code 1} and negated otherwise. Each configured metric is then computed from these values
     * and {@code trainConfig.negativeClassWeight()}.
     */
    private Map<LinkMetric, Double> computeMetric(
        FeaturesAndTargets inputData,
        LinkLogisticRegressionData modelData,
        BatchQueue evaluationQueue,
        ProgressTracker progressTracker
    ) {
        progressTracker.beginSubTask(inputData.size());

        LinkLogisticRegressionPredictor predictor = new LinkLogisticRegressionPredictor(modelData);
        // Sized for the whole input; the evaluation queue may cover only a subset of it (a CV fold).
        SignedProbabilities signedProbabilities = SignedProbabilities.create(inputData.size());
        HugeDoubleArray targets = inputData.targets();
        HugeObjectArray<double[]> features = inputData.features();

        evaluationQueue.parallelConsume(trainConfig.concurrency(), thread -> batch -> {
            for (Long relationshipIdx : batch.nodeIds()) {
                long idx = relationshipIdx.longValue();
                double predictedProbability = predictor.predictedProbability(features.get(idx));
                boolean isEdge = targets.get(idx) == 1.0;
                signedProbabilities.add(isEdge ? predictedProbability : -predictedProbability);
            }
            progressTracker.logProgress(batch.size());
        }, terminationFlag);

        progressTracker.endSubTask();
        return trainConfig.metrics().stream().collect(Collectors.toMap(
            Function.identity(),
            metric -> metric.compute(signedProbabilities, trainConfig.negativeClassWeight())
        ));
    }

    /** Wraps the trained data, config, schema and model info into a {@link Model}. */
    private Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo> createModel(
        ModelSelectResult modelSelectResult,
        LinkLogisticRegressionData modelData,
        Map<LinkMetric, MetricData<LinkLogisticRegressionTrainConfig>> metrics
    ) {
        return Model.of(
            trainConfig.username(),
            trainConfig.modelName(),
            MODEL_TYPE,
            graphStore.schema(),
            modelData,
            trainConfig,
            LinkPredictionModelInfo.of(modelSelectResult.bestParameters(), metrics, pipeline.copy())
        );
    }

    /**
     * Removes the train, test and feature-input relationship types from the graph store.
     * Also removes the node properties added by the pipeline executor.
     */
    private void cleanUpGraphStore() {
        LinkPredictionSplitConfig splitConfig = pipeline.splitConfig();
        List<String> temporaryRelTypes = List.of(
            splitConfig.trainRelationshipType(),
            splitConfig.testRelationshipType(),
            splitConfig.featureInputRelationshipType()
        );
        temporaryRelTypes.forEach(relType -> graphStore.deleteRelationships(RelationshipType.of(relType)));
        pipelineExecutor.removeNodeProperties(graphStore, trainConfig.nodeLabelIdentifiers(graphStore));
    }

    @Override
    public LinkPredictionTrain me() {
        return this;
    }

    @Override
    public void release() {
    }

    /**
     * Accumulates per-metric min, max and sum over the folds for one parameter candidate.
     * The resulting {@link ModelStats} average is {@code sum / numberOfSplits}.
     */
    static class ModelStatsBuilder {
        private final Map<LinkMetric, Double> min;
        private final Map<LinkMetric, Double> max;
        private final Map<LinkMetric, Double> sum;
        private final LinkLogisticRegressionTrainConfig modelParams;
        private final int numberOfSplits;

        ModelStatsBuilder(LinkLogisticRegressionTrainConfig modelParams, int numberOfSplits) {
            this.modelParams = modelParams;
            this.numberOfSplits = numberOfSplits;
            this.min = new HashMap<>();
            this.max = new HashMap<>();
            this.sum = new HashMap<>();
        }

        /** Records one fold's score for the given metric. */
        void update(LinkMetric metric, double value) {
            min.merge(metric, value, Math::min);
            max.merge(metric, value, Math::max);
            sum.merge(metric, value, Double::sum);
        }

        /**
         * Returns the aggregated stats for the given metric.
         * Assumes {@link #update} was called at least once for it.
         */
        ModelStats<LinkLogisticRegressionTrainConfig> modelStats(LinkMetric metric) {
            return ImmutableModelStats.of(
                modelParams,
                sum.get(metric) / numberOfSplits,
                min.get(metric),
                max.get(metric)
            );
        }
    }

    /**
     * Outcome of model selection: the winning parameters plus per-metric train and validation stats.
     * There is one stats entry per candidate.
     */
    @ValueClass
    public interface ModelSelectResult {
        // Accessor order is kept as-is: the generated ImmutableModelSelectResult.of follows it.
        LinkLogisticRegressionTrainConfig bestParameters();

        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> trainStats();

        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> validationStats();

        static ModelSelectResult of(
            LinkLogisticRegressionTrainConfig bestConfig,
            Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> trainStats,
            Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> validationStats
        ) {
            return ImmutableModelSelectResult.of(bestConfig, trainStats, validationStats);
        }
    }
}

