/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodePropertyPrediction;

import java.util.Optional;
import java.util.SplittableRandom;
import java.util.function.LongUnaryOperator;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeMergeSort;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.shuffle.ShuffleUtil;
import org.neo4j.gds.ml.nodePropertyPrediction.ImmutableNodeSplits;
import org.neo4j.gds.ml.splitting.FractionSplitter;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
import org.neo4j.gds.ml.util.TrainingSetWarnings;

/**
 * Builds the outer train/test split of node examples for node property prediction.
 *
 * <p>The example indices {@code 0 .. numberOfExamples - 1} are translated to original node ids with
 * {@code toOriginalId}, ordered, translated back with {@code toMappedId}, shuffled, and then
 * split by fraction into a train set and a test set.
 */
public final class NodeSplitter {
    private final int concurrency;
    private final long numberOfExamples;
    private final ProgressTracker progressTracker;
    private final LongUnaryOperator toOriginalId;
    private final LongUnaryOperator toMappedId;

    /**
     * @param concurrency      concurrency passed to {@link HugeMergeSort#sort} when ordering examples
     * @param numberOfExamples number of examples to split; also the size of the example array
     * @param progressTracker  tracker handed to {@link TrainingSetWarnings} for small-set warnings
     * @param toOriginalId     maps an example index / mapped id to its original node id
     * @param toMappedId       maps an original node id back to its mapped id
     */
    public NodeSplitter(
        int concurrency,
        long numberOfExamples,
        ProgressTracker progressTracker,
        LongUnaryOperator toOriginalId,
        LongUnaryOperator toMappedId
    ) {
        this.concurrency = concurrency;
        this.numberOfExamples = numberOfExamples;
        this.progressTracker = progressTracker;
        this.toOriginalId = toOriginalId;
        this.toMappedId = toMappedId;
    }

    /**
     * Shuffles all examples and splits them into an outer train/test split.
     *
     * @param testFraction    fraction of examples intended for the test set; the train set is
     *                        requested with fraction {@code 1.0 - testFraction}
     * @param validationFolds number of validation folds; only used here to decide whether to warn
     *                        about too-small train/test sets
     * @param randomSeed      optional seed for the shuffle; passed to
     *                        {@link ShuffleUtil#createRandomDataGenerator}
     * @return all (shuffled, mapped-id) examples together with the outer split over them
     */
    public NodeSplits split(double testFraction, int validationFolds, Optional<Long> randomSeed) {
        HugeLongArray allTrainingExamples = HugeLongArray.newArray(numberOfExamples);

        // Put the examples into ascending original-id order before shuffling, then translate them
        // back to mapped ids. Presumably this makes the shuffle (and therefore the split) depend
        // only on the original ids and the seed, not on the internal id mapping — TODO confirm.
        allTrainingExamples.setAll(toOriginalId);
        HugeMergeSort.sort(allTrainingExamples, concurrency);
        allTrainingExamples.setAll(i -> toMappedId.applyAsLong(allTrainingExamples.get(i)));

        SplittableRandom random = ShuffleUtil.createRandomDataGenerator(randomSeed);
        ShuffleUtil.shuffleArray(allTrainingExamples, random);

        // Wrap the array once so the returned examples and the outer split share one read-only
        // view, instead of creating two independent wrappers over the same data.
        ReadOnlyHugeLongArray readOnlyExamples = ReadOnlyHugeLongArray.of(allTrainingExamples);
        TrainingExamplesSplit outerSplit = new FractionSplitter().split(readOnlyExamples, 1.0 - testFraction);

        TrainingSetWarnings.warnForSmallNodeSets(
            outerSplit.trainSet().size(),
            outerSplit.testSet().size(),
            validationFolds,
            progressTracker
        );

        return ImmutableNodeSplits.of(readOnlyExamples, outerSplit);
    }

    /**
     * Result of {@link NodeSplitter#split}: every example that was split, plus the outer
     * train/test split computed over them.
     */
    @ValueClass
    public interface NodeSplits {
        /** All examples as mapped ids, in the shuffled order used for splitting. */
        ReadOnlyHugeLongArray allTrainingExamples();

        /** The outer train/test split over {@link #allTrainingExamples()}. */
        TrainingExamplesSplit outerSplit();
    }
}

