/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline;

import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.neo4j.gds.BaseProc;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.loading.GraphStoreCatalog;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.NodePropertyStep;
import org.neo4j.gds.ml.linkmodels.pipeline.TrainingPipeline;
import org.neo4j.gds.ml.linkmodels.pipeline.linkFeatures.LinkFeatureExtractor;
import org.neo4j.gds.ml.linkmodels.pipeline.procedureutils.ProcedureReflection;
import org.neo4j.gds.ml.splitting.SplitRelationshipsBaseConfig;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.kernel.database.NamedDatabaseId;

/**
 * Executes the stages of a link-prediction {@link TrainingPipeline} against a named graph held in
 * the {@link GraphStoreCatalog}: relationship splitting, node-property steps, feature extraction
 * and removal of intermediate node properties.
 *
 * <p>Procedures (node-property steps and {@code splitRelationships}) are invoked through the
 * supplied {@link BaseProc} caller; progress is reported to the supplied {@link ProgressTracker}.
 */
public class PipelineExecutor {
    /**
     * Error message used when a split produced no relationships. The first argument names the
     * split ({@code "Test"} or {@code "Train"}), the second the configuration key to adjust.
     */
    public static final String SPLIT_ERROR_TEMPLATE = "%s graph contains no relationships. Consider increasing the `%s` or provide a larger graph";

    private final TrainingPipeline pipeline;
    private final String userName;
    private final NamedDatabaseId databaseId;
    private final BaseProc caller;
    private final String graphName;
    private final ProgressTracker progressTracker;

    /**
     * @param pipeline        the pipeline whose steps and split configuration are executed
     * @param caller          procedure used to invoke node-property steps and the split procedure
     * @param databaseId      database the graph is registered under in the catalog
     * @param userName        user owning the graph in the catalog
     * @param graphName       name of the graph in the catalog
     * @param progressTracker tracker receiving sub-task begin/end notifications
     */
    public PipelineExecutor(TrainingPipeline pipeline, BaseProc caller, NamedDatabaseId databaseId, String userName, String graphName, ProgressTracker progressTracker) {
        this.pipeline = pipeline;
        this.caller = caller;
        this.userName = userName;
        this.databaseId = databaseId;
        this.graphName = graphName;
        this.progressTracker = progressTracker;
    }

    /**
     * Loads the catalog graph restricted to {@code nodeLabels} and the single
     * {@code relationshipType}, validates the pipeline against it and extracts the link features
     * defined by the pipeline's feature steps.
     *
     * @param nodeLabels       node labels to include in the graph view
     * @param relationshipType the only relationship type to include in the graph view
     * @param concurrency      concurrency passed to the feature extractor
     * @return the extracted feature vectors, as produced by {@link LinkFeatureExtractor}
     */
    public HugeObjectArray<double[]> computeFeatures(Collection<NodeLabel> nodeLabels, RelationshipType relationshipType, int concurrency) {
        Graph graph = GraphStoreCatalog.get(userName, databaseId, graphName)
            .graphStore()
            .getGraph(nodeLabels, List.of(relationshipType), Optional.empty());
        pipeline.validate(graph);
        return LinkFeatureExtractor.extractFeatures(graph, pipeline.featureSteps(), concurrency, progressTracker);
    }

    /**
     * Creates a {@link LinkFeatureExtractor} for {@code graph} using the pipeline's feature steps.
     */
    public LinkFeatureExtractor linkFeatureExtractor(Graph graph) {
        return LinkFeatureExtractor.of(graph, pipeline.featureSteps());
    }

    /**
     * Convenience overload of {@link #executeNodePropertySteps(Collection, Collection)} for a
     * single relationship type.
     */
    public void executeNodePropertySteps(Collection<NodeLabel> nodeLabels, RelationshipType relationshipType) {
        executeNodePropertySteps(nodeLabels, List.of(relationshipType));
    }

    /**
     * Executes every node-property step of the pipeline, in order, on the catalog graph.
     * One progress sub-task wraps the whole run, with a nested sub-task per step.
     *
     * @param nodeLabels        node labels each step is run on
     * @param relationshipTypes relationship types each step is run on
     */
    public void executeNodePropertySteps(Collection<NodeLabel> nodeLabels, Collection<RelationshipType> relationshipTypes) {
        progressTracker.beginSubTask();
        for (NodePropertyStep step : pipeline.nodePropertySteps()) {
            progressTracker.beginSubTask();
            step.execute(caller, graphName, nodeLabels, relationshipTypes);
            progressTracker.endSubTask();
        }
        progressTracker.endSubTask();
    }

    /**
     * Splits the input relationships into test / train / feature-input sets using the
     * pipeline's split configuration.
     *
     * <ol>
     *   <li>Split {@code relationshipTypes} into the test set and the test-complement set.</li>
     *   <li>Split the test-complement set into the train set and the feature-input set.</li>
     *   <li>Delete the now-unneeded test-complement relationships from {@code graphStore}.</li>
     * </ol>
     *
     * <p>NOTE(review): if a validation fails, the exception is thrown before the test-complement
     * relationships are deleted, so relationship types written by the split procedure remain in
     * {@code graphStore}; any cleanup on failure is left to the caller.
     *
     * @throws IllegalStateException if the test, train or feature-input set is empty
     */
    public void splitRelationships(GraphStore graphStore, List<String> relationshipTypes, List<String> nodeLabels, Optional<Long> randomSeed) {
        progressTracker.beginSubTask();

        LinkPredictionSplitConfig splitConfig = pipeline.splitConfig();
        splitConfig.validateAgainstGraphStore(graphStore);
        String testComplementRelationshipType = splitConfig.testComplementRelationshipType();

        relationshipSplit(splitConfig.testSplit(), nodeLabels, relationshipTypes, randomSeed);
        validateTestSplit(graphStore);

        relationshipSplit(splitConfig.trainSplit(), nodeLabels, List.of(testComplementRelationshipType), randomSeed);
        validateTrainSplit(graphStore);

        graphStore.deleteRelationships(RelationshipType.of(testComplementRelationshipType));

        progressTracker.endSubTask();
    }

    /** Fails if the test relationship type produced by the first split is empty. */
    private void validateTestSplit(GraphStore graphStore) {
        String testRelationshipType = pipeline.splitConfig().testRelationshipType();
        if (relationshipCount(graphStore, testRelationshipType) <= 0L) {
            throw new IllegalStateException(StringFormatting.formatWithLocale(SPLIT_ERROR_TEMPLATE, "Test", "testFraction"));
        }
    }

    /** Fails if the train or feature-input relationship type produced by the second split is empty. */
    private void validateTrainSplit(GraphStore graphStore) {
        LinkPredictionSplitConfig splitConfig = pipeline.splitConfig();
        if (relationshipCount(graphStore, splitConfig.trainRelationshipType()) <= 0L) {
            throw new IllegalStateException(StringFormatting.formatWithLocale(SPLIT_ERROR_TEMPLATE, "Train", "trainFraction"));
        }
        if (relationshipCount(graphStore, splitConfig.featureInputRelationshipType()) <= 0L) {
            throw new IllegalStateException(StringFormatting.formatWithLocale(
                "Feature graph contains no relationships. Consider decreasing %s or %s or provide a larger graph.",
                "testFraction",
                "trainFraction"
            ));
        }
    }

    /** Returns the number of relationships of the given type in {@code graphStore}. */
    private static long relationshipCount(GraphStore graphStore, String relationshipType) {
        return graphStore.getGraph(RelationshipType.of(relationshipType)).relationshipCount();
    }

    /**
     * Removes, for every given label, the node property written by each node-property step whose
     * configuration contains a String-valued {@code "mutateProperty"} entry. Steps without such an
     * entry are skipped.
     */
    public void removeNodeProperties(GraphStore graphstore, Collection<NodeLabel> nodeLabels) {
        pipeline.nodePropertySteps().forEach(step -> {
            Object intermediateProperty = step.config.get("mutateProperty");
            if (intermediateProperty instanceof String) {
                String propertyKey = (String) intermediateProperty;
                nodeLabels.forEach(label -> graphstore.removeNodeProperty(label, propertyKey));
            }
        });
    }

    /**
     * Invokes the {@code splitRelationships} procedure with the given split configuration,
     * restricted to {@code nodeLabels} and {@code relationshipTypes}, passing
     * {@code randomSeed} only when present.
     */
    private void relationshipSplit(SplitRelationshipsBaseConfig splitConfig, List<String> nodeLabels, List<String> relationshipTypes, Optional<Long> randomSeed) {
        // A plain map copy rather than an anonymous HashMap subclass ("double-brace
        // initialization"): that form cannot call super(...) from an initializer and would also
        // capture a hidden reference to this executor.
        Map<String, Object> splitRelationshipProcConfig = new HashMap<>(splitConfig.toSplitMap());
        splitRelationshipProcConfig.put("nodeLabels", nodeLabels);
        splitRelationshipProcConfig.put("relationshipTypes", relationshipTypes);
        randomSeed.ifPresent(seed -> splitRelationshipProcConfig.put("randomSeed", seed));

        ProcedureReflection procReflection = ProcedureReflection.INSTANCE;
        procReflection.invokeProc(
            caller,
            graphName,
            procReflection.findProcedureMethod("splitRelationships"),
            splitRelationshipProcConfig
        );
    }
}

