/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodemodels.pipeline;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.schema.GraphSchema;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.config.ToMapConvertible;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.nodemodels.BestMetricData;
import org.neo4j.gds.ml.nodemodels.BestModelStats;
import org.neo4j.gds.ml.nodemodels.ImmutableModelSelectResult;
import org.neo4j.gds.ml.nodemodels.Metric;
import org.neo4j.gds.ml.nodemodels.MetricData;
import org.neo4j.gds.ml.nodemodels.ModelStats;
import org.neo4j.gds.ml.nodemodels.NodeClassificationModelInfo;
import org.neo4j.gds.ml.nodemodels.NodeClassificationTrain;
import org.neo4j.gds.ml.nodemodels.NodeClassificationTrainConfig;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionData;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionTrainConfig;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionTrainCoreConfig;
import org.neo4j.gds.ml.nodemodels.pipeline.ImmutableNodeClassificationPipelineTrainResult;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationPipeline;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationPipelineModelInfo;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationPipelineTrainResult;
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
import org.neo4j.gds.ml.pipeline.Pipeline;
import org.neo4j.gds.ml.pipeline.PipelineExecutor;
import org.neo4j.gds.model.ModelConfig;

public class NodeClassificationTrainPipelineExecutor
extends PipelineExecutor<NodeClassificationPipelineTrainConfig, NodeClassificationPipeline, NodeClassificationPipelineTrainResult> {
    public static final String MODEL_TYPE = "Node classification pipeline";

    public NodeClassificationTrainPipelineExecutor(NodeClassificationPipeline pipeline, NodeClassificationPipelineTrainConfig config, ExecutionContext executionContext, GraphStore graphStore, String graphName, ProgressTracker progressTracker) {
        super((Pipeline)pipeline, (AlgoBaseConfig)config, executionContext, graphStore, graphName, progressTracker);
    }

    public Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> splitDataset() {
        return Map.of(PipelineExecutor.DatasetSplits.FEATURE_INPUT, ImmutableGraphFilter.of((Collection)((NodeClassificationPipelineTrainConfig)this.config).nodeLabelIdentifiers(this.graphStore), (Collection)((NodeClassificationPipelineTrainConfig)this.config).internalRelationshipTypes(this.graphStore)));
    }

    protected NodeClassificationPipelineTrainResult execute(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> dataSplits) {
        Collection nodeLabels = ((NodeClassificationPipelineTrainConfig)this.config).nodeLabelIdentifiers(this.graphStore);
        Collection relationshipTypes = ((NodeClassificationPipelineTrainConfig)this.config).internalRelationshipTypes(this.graphStore);
        Graph graph = this.graphStore.getGraph(nodeLabels, relationshipTypes, Optional.empty());
        Model innerModel = NodeClassificationTrain.create((Graph)graph, (NodeClassificationTrainConfig)this.innerConfig(), (AllocationTracker)this.executionContext.allocationTracker(), (ProgressTracker)this.progressTracker).compute();
        NodeClassificationModelInfo innerInfo = (NodeClassificationModelInfo)innerModel.customInfo();
        Map<Metric, BestMetricData> bestMetrics = innerInfo.metrics().entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, metricStats -> this.extractWinningModelStats((MetricData<NodeLogisticRegressionTrainConfig>)((MetricData)metricStats.getValue()), innerInfo.bestParameters())));
        NodeClassificationPipelineModelInfo modelInfo = NodeClassificationPipelineModelInfo.builder().classes(innerInfo.classes()).bestParameters(innerInfo.bestParameters()).metrics(bestMetrics).trainingPipeline(((NodeClassificationPipeline)this.pipeline).copy()).build();
        return ImmutableNodeClassificationPipelineTrainResult.of((Model<NodeLogisticRegressionData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo>)Model.of((String)innerModel.creator(), (String)innerModel.name(), (String)MODEL_TYPE, (GraphSchema)innerModel.graphSchema(), (Object)((NodeLogisticRegressionData)innerModel.data()), (ModelConfig)((NodeClassificationPipelineTrainConfig)this.config), (ToMapConvertible)modelInfo), ImmutableModelSelectResult.of((NodeLogisticRegressionTrainConfig)innerInfo.bestParameters(), this.getTrainingStats(innerInfo), this.getValidationStats(innerInfo)));
    }

    private Map<Metric, List<ModelStats<NodeLogisticRegressionTrainConfig>>> getTrainingStats(NodeClassificationModelInfo innerInfo) {
        return innerInfo.metrics().entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, metricStats -> ((MetricData)metricStats.getValue()).train()));
    }

    private Map<Metric, List<ModelStats<NodeLogisticRegressionTrainConfig>>> getValidationStats(NodeClassificationModelInfo innerInfo) {
        return innerInfo.metrics().entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, metricStats -> ((MetricData)metricStats.getValue()).validation()));
    }

    private BestMetricData extractWinningModelStats(MetricData<NodeLogisticRegressionTrainConfig> oldStats, NodeLogisticRegressionTrainConfig bestParams) {
        return BestMetricData.of((BestModelStats)NodeClassificationTrainPipelineExecutor.findBestModelStats(oldStats.train(), bestParams), (BestModelStats)NodeClassificationTrainPipelineExecutor.findBestModelStats(oldStats.validation(), bestParams), (double)oldStats.outerTrain(), (double)oldStats.test());
    }

    NodeClassificationTrainConfig innerConfig() {
        List params = ((NodeClassificationPipeline)this.pipeline).trainingParameterSpace().stream().map(NodeLogisticRegressionTrainCoreConfig::toMap).collect(Collectors.toList());
        return NodeClassificationTrainConfig.builder().modelName(((NodeClassificationPipelineTrainConfig)this.config).modelName()).concurrency(((NodeClassificationPipelineTrainConfig)this.config).concurrency()).username(((NodeClassificationPipelineTrainConfig)this.config).username()).metrics((Iterable)((NodeClassificationPipelineTrainConfig)this.config).metrics()).targetProperty(((NodeClassificationPipelineTrainConfig)this.config).targetProperty()).featureProperties(((NodeClassificationPipeline)this.pipeline).featureProperties()).params(params).randomSeed(((NodeClassificationPipelineTrainConfig)this.config).randomSeed()).holdoutFraction(((NodeClassificationPipeline)this.pipeline).splitConfig().testFraction()).validationFolds(((NodeClassificationPipeline)this.pipeline).splitConfig().validationFolds()).nodeLabels((Iterable)((NodeClassificationPipelineTrainConfig)this.config).nodeLabels()).relationshipTypes((Iterable)((NodeClassificationPipelineTrainConfig)this.config).relationshipTypes()).minBatchSize(((NodeClassificationPipelineTrainConfig)this.config).minBatchSize()).build();
    }

    private static BestModelStats findBestModelStats(List<ModelStats<NodeLogisticRegressionTrainConfig>> metricStatsForModels, NodeLogisticRegressionTrainConfig bestParams) {
        return metricStatsForModels.stream().filter(metricStatsForModel -> metricStatsForModel.params() == bestParams).findFirst().map(BestModelStats::of).orElseThrow();
    }
}

