/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.splitting;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.MutateComputationResultConsumer;
import org.neo4j.gds.MutateProc;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.GraphProjectConfig;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.utils.ProgressTimer;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.executor.validation.AfterLoadValidation;
import org.neo4j.gds.executor.validation.ValidationConfiguration;
import org.neo4j.gds.ml.splitting.EdgeSplitter;
import org.neo4j.gds.ml.splitting.SplitRelationships;
import org.neo4j.gds.ml.splitting.SplitRelationshipsAlgorithmFactory;
import org.neo4j.gds.ml.splitting.SplitRelationshipsMutateConfig;
import org.neo4j.gds.result.AbstractResultBuilder;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
import org.neo4j.values.storable.NumberType;

/**
 * Procedure entry point for {@code gds.alpha.ml.splitRelationships.mutate}.
 *
 * <p>Runs the {@link SplitRelationships} algorithm and adds its two outputs to the
 * in-memory graph store as new relationship types: the "remaining" relationships under
 * {@code config.remainingRelationshipType()} and the "holdout" relationships under
 * {@code config.holdoutRelationshipType()}.
 *
 * <p>NOTE(review): the procedure is registered with {@code Mode.READ} although it is a
 * mutate procedure; presumably this is because only the in-memory graph store is
 * modified and not the database — confirm before changing.
 */
@GdsCallable(name="gds.alpha.ml.splitRelationships.mutate", description="Splits a graph into holdout and remaining relationship types and adds them to the graph.", executionMode=ExecutionMode.MUTATE_RELATIONSHIP)
public class SplitRelationshipsMutateProc
extends MutateProc<SplitRelationships, EdgeSplitter.SplitResult, MutateResult, SplitRelationshipsMutateConfig> {
    /**
     * Computes the split for the named graph and mutates the graph store with the result.
     *
     * @param graphName name of the in-memory graph to split
     * @param configuration user-supplied procedure configuration; defaults to an empty map
     * @return the stream of result rows produced by the inherited {@code mutate} pipeline
     */
    @Procedure(name="gds.alpha.ml.splitRelationships.mutate", mode=Mode.READ)
    @Description(value="Splits a graph into holdout and remaining relationship types and adds them to the graph.")
    public Stream<MutateResult> mutate(@Name(value="graphName") String graphName, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        // Passing the result straight through keeps its generic type instead of degrading to a raw type.
        return mutate(compute(graphName, configuration));
    }

    /**
     * Builds the typed configuration from the raw procedure map.
     *
     * @param username the calling user; not used by this implementation
     * @param config wrapped user configuration
     * @return the parsed {@link SplitRelationshipsMutateConfig}
     */
    @Override
    protected SplitRelationshipsMutateConfig newConfig(String username, CypherMapWrapper config) {
        return SplitRelationshipsMutateConfig.of(config);
    }

    /**
     * @return a new factory that creates the {@link SplitRelationships} algorithm
     */
    @Override
    public GraphStoreAlgorithmFactory<SplitRelationships, SplitRelationshipsMutateConfig> algorithmFactory() {
        return new SplitRelationshipsAlgorithmFactory();
    }

    /**
     * @param computeResult the finished computation; not inspected here
     * @param executionContext the execution context; not inspected here
     * @return a fresh, empty {@link MutateResult.Builder}
     */
    @Override
    protected AbstractResultBuilder<MutateResult> resultBuilder(ComputationResult<SplitRelationships, EdgeSplitter.SplitResult, SplitRelationshipsMutateConfig> computeResult, ExecutionContext executionContext) {
        return new MutateResult.Builder();
    }

    /**
     * Creates the consumer that writes the split result into the graph store and records
     * the mutate timing and the number of relationships written on the result builder.
     *
     * @return the result consumer used by the mutate pipeline
     */
    @Override
    public MutateComputationResultConsumer<SplitRelationships, EdgeSplitter.SplitResult, SplitRelationshipsMutateConfig, MutateResult> computationResultConsumer() {
        return new MutateComputationResultConsumer<SplitRelationships, EdgeSplitter.SplitResult, SplitRelationshipsMutateConfig, MutateResult>(this::resultBuilder){

            @Override
            protected void updateGraphStore(AbstractResultBuilder<?> resultBuilder, ComputationResult<SplitRelationships, EdgeSplitter.SplitResult, SplitRelationshipsMutateConfig> computationResult, ExecutionContext executionContext) {
                SplitRelationshipsMutateConfig config = computationResult.config();
                GraphStore graphStore = computationResult.graphStore();
                EdgeSplitter.SplitResult splitResult = computationResult.result();

                // Only the two graph-store additions are timed; the timer reports to withMutateMillis.
                try (ProgressTimer ignored = ProgressTimer.start(resultBuilder::withMutateMillis)) {
                    // Remaining relationships keep the configured weight property (if any) as a floating-point value.
                    graphStore.addRelationshipType(
                        config.remainingRelationshipType(),
                        Optional.ofNullable(config.relationshipWeightProperty()),
                        Optional.of(NumberType.FLOATING_POINT),
                        splitResult.remainingRels()
                    );
                    // Holdout relationships carry an integral property named "label"
                    // (presumably the class label assigned by the splitter — confirm against EdgeSplitter).
                    graphStore.addRelationshipType(
                        config.holdoutRelationshipType(),
                        Optional.of("label"),
                        Optional.of(NumberType.INTEGRAL),
                        splitResult.selectedRels()
                    );
                }

                // The reported count is the sum of both newly added relationship sets.
                long holdoutWritten = splitResult.selectedRels().topology().elementCount();
                long remainingWritten = splitResult.remainingRels().topology().elementCount();
                resultBuilder.withRelationshipsWritten(holdoutWritten + remainingWritten);
            }
        };
    }

    /**
     * @return a validation configuration whose only after-load check is {@link Validation}
     */
    @Override
    public ValidationConfiguration<SplitRelationshipsMutateConfig> validationConfig() {
        return new ValidationConfiguration<SplitRelationshipsMutateConfig>(){

            @Override
            public List<AfterLoadValidation<SplitRelationshipsMutateConfig>> afterLoadValidations() {
                return List.of(new Validation());
            }
        };
    }

    /**
     * After-load validation: the two target relationship types must not exist yet, and
     * every configured non-negative relationship type must already exist in the graph store.
     */
    static class Validation
    implements AfterLoadValidation<SplitRelationshipsMutateConfig> {
        /**
         * Runs all checks against the loaded graph store.
         *
         * @param graphStore the loaded in-memory graph store
         * @param graphProjectConfig the projection config; not used by these checks
         * @param config the procedure configuration being validated
         * @throws IllegalArgumentException if a target type already exists or a required type is missing
         */
        @Override
        public void validateConfigsAfterLoad(GraphStore graphStore, GraphProjectConfig graphProjectConfig, SplitRelationshipsMutateConfig config) {
            validateTypeDoesNotExist(graphStore, config.holdoutRelationshipType());
            validateTypeDoesNotExist(graphStore, config.remainingRelationshipType());
            validateNonNegativeRelationshipTypesExist(graphStore, config);
        }

        /** Fails on the first configured non-negative relationship type that the graph store does not contain. */
        private void validateNonNegativeRelationshipTypesExist(GraphStore graphStore, SplitRelationshipsMutateConfig config) {
            config.nonNegativeRelationshipTypes().forEach(relationshipType -> {
                if (!graphStore.hasRelationshipType(RelationshipType.of(relationshipType))) {
                    throw new IllegalArgumentException(StringFormatting.formatWithLocale("Relationship type `%s` does not exist in the in-memory graph.", relationshipType));
                }
            });
        }

        /** Fails if the given relationship type (holdout or remaining) is already present in the graph store. */
        private void validateTypeDoesNotExist(GraphStore graphStore, RelationshipType relationshipType) {
            if (graphStore.hasRelationshipType(relationshipType)) {
                throw new IllegalArgumentException(StringFormatting.formatWithLocale("Relationship type `%s` already exists in the in-memory graph.", relationshipType.name()));
            }
        }
    }

    /**
     * Result row of the procedure. The fields are public and final; their values are
     * copied unchanged from the {@link Builder} that creates the row.
     */
    public static class MutateResult {
        /** Pre-processing time as recorded on the result builder. */
        public final long preProcessingMillis;
        /** Compute time as recorded on the result builder. */
        public final long computeMillis;
        /** Mutate time as recorded on the result builder. */
        public final long mutateMillis;
        /** Sum of the holdout and remaining relationship counts added to the graph store. */
        public final long relationshipsWritten;
        /** The configuration the procedure ran with, as a map. */
        public final Map<String, Object> configuration;

        MutateResult(long preProcessingMillis, long computeMillis, long mutateMillis, long relationshipsWritten, Map<String, Object> configuration) {
            this.preProcessingMillis = preProcessingMillis;
            this.computeMillis = computeMillis;
            this.mutateMillis = mutateMillis;
            this.relationshipsWritten = relationshipsWritten;
            this.configuration = configuration;
        }

        /** Builder that assembles a {@link MutateResult} from the state collected in {@link AbstractResultBuilder}. */
        static class Builder
        extends AbstractResultBuilder<MutateResult> {
            /**
             * @return a result row built from the timings, relationship count and configuration held by this builder
             */
            @Override
            public MutateResult build() {
                return new MutateResult(this.preProcessingMillis, this.computeMillis, this.mutateMillis, this.relationshipsWritten, this.config.toMap());
            }
        }
    }
}

