/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.linkPipeline.train;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.ToLongFunction;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.DegreePartition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureExtractor;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureStep;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.FeaturesAndLabels;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.ImmutableFeaturesAndLabels;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Static helpers that extract link features and binary (0/1) labels from the relationships of a
 * graph, plus a memory estimation for holding them.
 */
final class LinkFeaturesAndLabelsExtractor {

    private LinkFeaturesAndLabelsExtractor() {
    }

    /**
     * Estimates the memory needed to hold the extracted features and labels of one relationship set.
     *
     * <p>The estimation has two components:
     * <ul>
     *   <li>features: one {@code double[]} of the (fudged) link feature dimension per relationship,
     *       plus a {@link HugeObjectArray} instance;</li>
     *   <li>targets: a {@link HugeIntArray} with one entry per relationship.</li>
     * </ul>
     *
     * @param fudgedLinkFeatureDim range of the link feature dimension used to size each feature array
     * @param relSetSizeExtractor  maps the graph dimension's relationship counts per type to the number
     *                             of relationships in the set being estimated
     * @param setDesc              prefix for the descriptions of both estimation components
     * @return the combined memory estimation
     */
    static MemoryEstimation estimate(
        MemoryRange fudgedLinkFeatureDim,
        ToLongFunction<Map<RelationshipType, Long>> relSetSizeExtractor,
        String setDesc
    ) {
        return MemoryEstimations.builder()
            .rangePerGraphDimension(
                setDesc + " relationship features",
                (graphDim, threads) -> fudgedLinkFeatureDim
                    .apply(MemoryUsage::sizeOfDoubleArray)
                    .times(relSetSizeExtractor.applyAsLong(graphDim.relationshipCounts()))
                    .add(MemoryUsage.sizeOfInstance(HugeObjectArray.class))
            )
            // Leading space keeps this label consistent with " relationship features" above;
            // without it the description rendered as e.g. "trainrelationship targets".
            .perGraphDimension(
                setDesc + " relationship targets",
                (graphDim, threads) -> MemoryRange.of(
                    HugeIntArray.memoryEstimation(relSetSizeExtractor.applyAsLong(graphDim.relationshipCounts()))
                )
            )
            .build();
    }

    /**
     * Extracts link features for the relationships of {@code graph} and the matching 0/1 labels read
     * from the relationship property values.
     *
     * @param graph           graph whose relationships are featurized and labelled
     * @param featureSteps    link feature steps passed to {@link LinkFeatureExtractor}
     * @param concurrency     number of concurrent workers
     * @param progressTracker receives {@code 2 * relationshipCount} as the total step count
     * @param terminationFlag checked while running the label extraction tasks
     * @return the extracted features together with one label per feature row
     * @throws IllegalArgumentException if any relationship property value is neither {@code 0} nor {@code 1}
     */
    static FeaturesAndLabels extractFeaturesAndLabels(
        Graph graph,
        List<LinkFeatureStep> featureSteps,
        int concurrency,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        // Label extraction logs the summed partition degrees; presumably feature extraction logs the
        // other relationshipCount steps — confirm against LinkFeatureExtractor.
        progressTracker.setSteps(graph.relationshipCount() * 2L);
        Features features = LinkFeatureExtractor.extractFeatures(
            graph,
            featureSteps,
            concurrency,
            progressTracker,
            terminationFlag
        );
        HugeIntArray labels = extractLabels(graph, features.size(), concurrency, progressTracker, terminationFlag);
        return ImmutableFeaturesAndLabels.of(features, labels);
    }

    /**
     * Reads the relationship property of every relationship as a binary label.
     *
     * <p>Relationships are visited per degree partition. Each partition writes a contiguous range of
     * the result starting at the sum of the total degrees of the partitions before it, so partitions
     * can be processed concurrently without overlapping writes.
     * NOTE(review): row i of the labels only matches row i of the features if LinkFeatureExtractor
     * traverses relationships in the same partition order — confirm.
     *
     * @param graph           graph whose relationship property values are the labels
     * @param numberOfTargets size of the returned array
     * @param concurrency     number of concurrent workers
     * @param progressTracker logged with each partition's total degree once it is done
     * @param terminationFlag passed to the task runner
     * @return one label (0 or 1) per visited relationship
     * @throws IllegalArgumentException if a property value is neither {@code 0} nor {@code 1}
     */
    private static HugeIntArray extractLabels(
        Graph graph,
        long numberOfTargets,
        int concurrency,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        HugeIntArray globalLabels = HugeIntArray.newArray(numberOfTargets);
        List<DegreePartition> partitions = PartitionUtils.degreePartition(
            graph,
            concurrency,
            Function.identity(),
            Optional.of(100)
        );

        List<Runnable> tasks = new ArrayList<>(partitions.size());
        long relationshipOffset = 0L;
        for (DegreePartition partition : partitions) {
            // Effectively-final snapshot of where this partition's labels start.
            long startRelationshipOffset = relationshipOffset;
            tasks.add(() -> {
                MutableLong currentRelationshipOffset = new MutableLong(startRelationshipOffset);
                Graph localGraph = graph.concurrentCopy();
                // -10.0 is presumably the fallback for a missing property value; it is not a valid
                // label, so such relationships are rejected by the check below.
                partition.consume(nodeId -> localGraph.forEachRelationship(nodeId, -10.0, (src, trg, weight) -> {
                    if (weight != 0.0 && weight != 1.0) {
                        throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                            "Label should be either `1` or `0`. But got %f for relationship (%d, %d)",
                            weight,
                            src,
                            trg
                        ));
                    }
                    globalLabels.set(currentRelationshipOffset.getAndIncrement(), (int) weight);
                    return true;
                }));
                progressTracker.logSteps(partition.totalDegree());
            });
            relationshipOffset += partition.totalDegree();
        }

        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(tasks)
            .terminationFlag(terminationFlag)
            .run();
        return globalLabels;
    }
}

