/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.splitting;

import java.util.List;
import java.util.Optional;
import java.util.SortedSet;
import java.util.SplittableRandom;
import java.util.function.ToLongFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.mutable.MutableInt;
import org.eclipse.collections.api.block.function.primitive.LongToLongFunction;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.GraphDimensions;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.shuffle.ShuffleUtil;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;

/**
 * Splits a set of node ids into {@code k} train/test folds for stratified k-fold cross-validation.
 *
 * <p>Ids are visited class by class, in the iteration order of {@code distinctInternalTargets}.
 * They are assigned round-robin to the folds' test sets, and the round-robin pointer carries over
 * from one class to the next. Every id is added to the train set of every fold except the one
 * whose test set received it. The train and test sets of every fold are shuffled with the
 * splitter's random generator before being returned.
 */
public class StratifiedKFoldSplitter {
    private final int k;
    private final ReadOnlyHugeLongArray ids;
    private final LongToLongFunction targets;
    private final SplittableRandom random;
    private final SortedSet<Long> distinctInternalTargets;

    /**
     * Estimates the memory used by {@link #splits()} when the id set is a fraction of the graph's nodes.
     *
     * @param k             number of folds
     * @param trainFraction fraction of the graph's node count that makes up the id set being split
     * @return the memory estimation for {@code k} folds over {@code nodeCount * trainFraction} ids
     */
    public static MemoryEstimation memoryEstimationForNodeSet(int k, double trainFraction) {
        return StratifiedKFoldSplitter.memoryEstimation(k, dim -> (long) ((double) dim.nodeCount() * trainFraction));
    }

    /**
     * Estimates the memory used by {@link #splits()}: one test array and one train array per fold.
     *
     * @param k                   number of folds
     * @param idsSetSizeExtractor derives the size of the id set from the graph dimensions
     * @return a memory estimation with one "Fold i" entry per fold, each holding "Test" and "Train"
     */
    public static MemoryEstimation memoryEstimation(int k, ToLongFunction<GraphDimensions> idsSetSizeExtractor) {
        return MemoryEstimations.setup("", dimensions -> {
            long idSetSize = idsSetSizeExtractor.applyAsLong((GraphDimensions) dimensions);
            MemoryEstimations.Builder builder = MemoryEstimations.builder(StratifiedKFoldSplitter.class.getSimpleName());
            long baseBucketSize = idSetSize / k;
            for (int fold = 0; fold < k; ++fold) {
                // The first (idSetSize % k) folds take one extra test element, matching allocateArrays.
                long testSize = fold < idSetSize % k ? baseBucketSize + 1L : baseBucketSize;
                long trainSize = idSetSize - testSize;
                builder.add(
                    "Fold " + fold,
                    MemoryEstimations.builder()
                        .add(MemoryEstimations.of("Test", MemoryRange.of(HugeLongArray.memoryEstimation(testSize))))
                        .add(MemoryEstimations.of("Train", MemoryRange.of(HugeLongArray.memoryEstimation(trainSize))))
                        .build()
                );
            }
            return builder.build();
        });
    }

    /**
     * @param k                       number of folds
     * @param ids                     the ids to split
     * @param targets                 maps an id to its (internal) target class
     * @param randomSeed              optional seed for the shuffling random generator
     * @param distinctInternalTargets the target classes; ids are distributed class by class in this order
     */
    public StratifiedKFoldSplitter(int k, ReadOnlyHugeLongArray ids, LongToLongFunction targets, Optional<Long> randomSeed, SortedSet<Long> distinctInternalTargets) {
        this.k = k;
        this.ids = ids;
        this.targets = targets;
        this.random = ShuffleUtil.createRandomDataGenerator(randomSeed);
        this.distinctInternalTargets = distinctInternalTargets;
    }

    /**
     * Builds the {@code k} stratified train/test splits.
     *
     * <p>NOTE(review): the pre-sized arrays are only completely filled if every id's target is
     * contained in {@code distinctInternalTargets}. Ids with other targets are skipped, which leaves
     * default-valued slots at the end of the arrays. Callers are presumably expected to guarantee
     * this; confirm.
     *
     * @return one split per fold, in fold order
     */
    public List<TrainingExamplesSplit> splits() {
        long nodeCount = ids.size();
        HugeLongArray[] trainSets = new HugeLongArray[k];
        HugeLongArray[] testSets = new HugeLongArray[k];
        // long counters: a train set holds (nodeCount - testSize) ids, which can exceed Integer.MAX_VALUE.
        long[] trainNodesAdded = new long[k];
        long[] testNodesAdded = new long[k];
        allocateArrays(nodeCount, trainSets, testSets);

        // The pointer deliberately carries over between classes, so that fold test sizes match allocateArrays.
        int roundRobinPointer = 0;
        for (long currentClass : distinctInternalTargets) {
            for (long offset = 0L; offset < nodeCount; ++offset) {
                long id = ids.get(offset);
                if (targets.applyAsLong(id) != currentClass) {
                    continue;
                }
                for (int fold = 0; fold < k; ++fold) {
                    if (fold == roundRobinPointer) {
                        testSets[fold].set(testNodesAdded[fold]++, id);
                    } else {
                        trainSets[fold].set(trainNodesAdded[fold]++, id);
                    }
                }
                roundRobinPointer = (roundRobinPointer + 1) % k;
            }
        }

        return IntStream.range(0, k).mapToObj(fold -> {
            ShuffleUtil.shuffleArray(trainSets[fold], random);
            ShuffleUtil.shuffleArray(testSets[fold], random);
            return TrainingExamplesSplit.of(ReadOnlyHugeLongArray.of(trainSets[fold]), ReadOnlyHugeLongArray.of(testSets[fold]));
        }).collect(Collectors.toList());
    }

    /**
     * Allocates each fold's test and train arrays. The first {@code nodeCount % k} folds get one
     * extra test slot. All sizes are computed in {@code long} so that id sets larger than
     * {@code Integer.MAX_VALUE} are sized correctly.
     */
    private void allocateArrays(long nodeCount, HugeLongArray[] trainSets, HugeLongArray[] testSets) {
        long baseBucketSize = nodeCount / k;
        for (int fold = 0; fold < k; ++fold) {
            long testSize = fold < nodeCount % k ? baseBucketSize + 1L : baseBucketSize;
            testSets[fold] = HugeLongArray.newArray(testSize);
            trainSets[fold] = HugeLongArray.newArray(nodeCount - testSize);
        }
    }
}

