/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.ml.linkmodels;

import org.neo4j.gds.ml.linkmodels.logisticregression.LinkLogisticRegressionData;
import org.neo4j.gds.ml.linkmodels.logisticregression.LinkLogisticRegressionPredictor;
import org.neo4j.graphalgo.AbstractAlgorithmFactory;
import org.neo4j.graphalgo.AlgorithmFactory;
import org.neo4j.graphalgo.MutateProc;
import org.neo4j.graphalgo.Orientation;
import org.neo4j.graphalgo.RelationshipType;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.config.GraphCreateConfig;
import org.neo4j.graphalgo.core.Aggregation;
import org.neo4j.graphalgo.core.CypherMapWrapper;
import org.neo4j.graphalgo.core.concurrency.Pools;
import org.neo4j.graphalgo.core.loading.construction.GraphFactory;
import org.neo4j.graphalgo.core.model.ModelCatalog;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.ProgressTimer;
import org.neo4j.graphalgo.core.utils.mem.AllocationTracker;
import org.neo4j.graphalgo.core.utils.mem.MemoryEstimation;
import org.neo4j.graphalgo.exceptions.MemoryEstimationNotImplementedException;
import org.neo4j.graphalgo.result.AbstractResultBuilder;
import org.neo4j.graphalgo.results.StandardMutateResult;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
import org.neo4j.values.storable.NumberType;

import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;

import static org.neo4j.graphalgo.config.GraphCreateConfigValidations.validateIsUndirectedGraph;

/**
 * Mutate-mode procedure for link prediction.
 * <p>
 * The procedure looks up a previously stored model in the {@link ModelCatalog}, runs
 * {@link LinkPredictionPredict} with it and adds the predicted links to the graph store
 * as a new relationship type whose property carries the predicted probability.
 */
public class LinkPredictionPredictMutateProc extends MutateProc<LinkPredictionPredict, LinkPredictionResult, LinkPredictionPredictMutateProc.MutateResult, LinkPredictionPredictMutateConfig> {

    /**
     * Procedure entry point: runs the computation and then the mutate step.
     *
     * @param graphNameOrConfig either the name of a graph or an implicit graph configuration;
     *                          forwarded unchanged to {@code compute}
     * @param configuration     the user supplied procedure configuration
     * @return the stream of result rows produced by the inherited {@code mutate} step
     */
    @Procedure(name = "gds.alpha.ml.linkPrediction.predict.mutate", mode = Mode.READ)
    @Description("Predicts relationships for all node pairs based on a previously trained link prediction model")
    public Stream<MutateResult> mutate(
        @Name(value = "graphName") Object graphNameOrConfig,
        @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
    ) {
        var result = compute(graphNameOrConfig, configuration);
        return mutate(result);
    }

    /**
     * Delegates to {@code validateIsUndirectedGraph}.
     * NOTE(review): presumably this rejects graph configurations that are not undirected;
     * confirm against {@code GraphCreateConfigValidations}.
     */
    @Override
    protected void validateConfigs(
        GraphCreateConfig graphCreateConfig, LinkPredictionPredictMutateConfig config
    ) {
        validateIsUndirectedGraph(graphCreateConfig, config);
    }

    /**
     * Builds the procedure configuration from the raw user input.
     */
    @Override
    protected LinkPredictionPredictMutateConfig newConfig(
        String username,
        Optional<String> graphName,
        Optional<GraphCreateConfig> maybeImplicitCreate,
        CypherMapWrapper config
    ) {
        return LinkPredictionPredictMutateConfig.of(
            username,
            graphName,
            maybeImplicitCreate,
            config
        );
    }

    /**
     * Creates the factory that assembles a {@link LinkPredictionPredict} instance for a graph
     * and a configuration.
     */
    @Override
    protected AlgorithmFactory<LinkPredictionPredict, LinkPredictionPredictMutateConfig> algorithmFactory() {
        return new AbstractAlgorithmFactory<>() {

            /**
             * Reports the square of the node count as the volume of the task.
             */
            @Override
            protected long taskVolume(Graph graph, LinkPredictionPredictMutateConfig configuration) {
                return graph.nodeCount() * graph.nodeCount();
            }

            @Override
            protected String taskName() {
                return "LinkPredictionPredict";
            }

            /**
             * Fetches the model named in the configuration from the {@link ModelCatalog}
             * for the configured user and wires its data into a predictor.
             */
            @Override
            protected LinkPredictionPredict build(
                Graph graph,
                LinkPredictionPredictMutateConfig configuration,
                AllocationTracker tracker,
                ProgressLogger progressLogger
            ) {
                var model = ModelCatalog.get(
                    configuration.username(),
                    configuration.modelName(),
                    LinkLogisticRegressionData.class,
                    LinkPredictionTrainConfig.class
                );

                return new LinkPredictionPredict(
                    new LinkLogisticRegressionPredictor(model.data()),
                    graph,
                    configuration.batchSize(),
                    configuration.concurrency(),
                    configuration.topN(),
                    tracker,
                    progressLogger,
                    configuration.threshold()
                );
            }

            /**
             * Memory estimation is not available for this algorithm.
             *
             * @throws MemoryEstimationNotImplementedException always
             */
            @Override
            public MemoryEstimation memoryEstimation(LinkPredictionPredictMutateConfig configuration) {
                throw new MemoryEstimationNotImplementedException();
            }
        };
    }

    /**
     * Returns a fresh result builder; the computation result is not consulted here.
     */
    @Override
    protected AbstractResultBuilder<MutateResult> resultBuilder(
        ComputationResult<LinkPredictionPredict, LinkPredictionResult, LinkPredictionPredictMutateConfig> computeResult
    ) {
        return new MutateResult.Builder();
    }

    /**
     * Turns the predicted links into undirected relationships, stores them in the graph store
     * under the configured relationship type and property, and records the number of
     * relationships written on the result builder.
     * <p>
     * The whole step is timed: the timer reports to {@code withMutateMillis} when the
     * try-with-resources block is left, so the reported time includes the construction of
     * the relationships and not only their registration in the graph store.
     *
     * @param resultBuilder     receives the mutate time and the number of relationships written
     * @param computationResult provides the graph, the predicted links, the graph store and the config
     */
    @Override
    protected void updateGraphStore(
        AbstractResultBuilder<?> resultBuilder,
        ComputationResult<LinkPredictionPredict, LinkPredictionResult, LinkPredictionPredictMutateConfig> computationResult
    ) {
        // Start timing before any work is done so that building the relationships is measured as well.
        try (ProgressTimer ignored = ProgressTimer.start(resultBuilder::withMutateMillis)) {
            var relationshipsBuilder = GraphFactory.initRelationshipsBuilder()
                .aggregation(Aggregation.SINGLE)
                .nodes(computationResult.graph().nodeMapping())
                .orientation(Orientation.UNDIRECTED)
                .loadRelationshipProperty(true)
                // links are added from a sequential stream below
                .concurrency(1)
                .executorService(Pools.DEFAULT)
                .tracker(AllocationTracker.empty())
                .build();

            // Each predicted link becomes one relationship with its probability as the property value.
            computationResult
                .result()
                .stream()
                .forEach(predictedLink -> relationshipsBuilder.addFromInternal(
                    predictedLink.sourceId(),
                    predictedLink.targetId(),
                    predictedLink.probability()
                ));
            var relationships = relationshipsBuilder.build();

            var config = computationResult.config();
            computationResult.graphStore().addRelationshipType(
                RelationshipType.of(config.mutateRelationshipType()),
                Optional.of(config.mutateProperty()),
                Optional.of(NumberType.FLOATING_POINT),
                relationships
            );
            resultBuilder.withRelationshipsWritten(relationships.topology().elementCount());
        }
    }

    /**
     * Result row of the procedure: the fields of {@link StandardMutateResult} plus the
     * number of relationships written.
     */
    @SuppressWarnings("unused")
    public static final class MutateResult extends StandardMutateResult {

        public final long relationshipsWritten;

        MutateResult(
            long createMillis,
            long computeMillis,
            long mutateMillis,
            long relationshipsWritten,
            Map<String, Object> configuration
        ) {
            super(
                createMillis,
                computeMillis,
                // NOTE(review): presumably the post-processing time, which this procedure
                // does not measure; confirm against StandardMutateResult.
                0L,
                mutateMillis,
                configuration
            );
            this.relationshipsWritten = relationshipsWritten;
        }

        /**
         * Assembles a {@link MutateResult} from the values collected on the result builder.
         */
        static class Builder extends AbstractResultBuilder<LinkPredictionPredictMutateProc.MutateResult> {

            @Override
            public LinkPredictionPredictMutateProc.MutateResult build() {
                return new LinkPredictionPredictMutateProc.MutateResult(
                    createMillis,
                    computeMillis,
                    mutateMillis,
                    relationshipsWritten,
                    config.toMap()
                );
            }
        }
    }
}
