/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.mlp;

import java.util.Optional;
import java.util.SplittableRandom;
import java.util.function.Supplier;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.gradientdescent.Training;
import org.neo4j.gds.ml.models.ClassifierTrainer;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.mlp.MLPClassifier;
import org.neo4j.gds.ml.models.mlp.MLPClassifierData;
import org.neo4j.gds.ml.models.mlp.MLPClassifierObjective;
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainConfig;
import org.neo4j.gds.termination.TerminationFlag;

/**
 * {@link ClassifierTrainer} that produces an {@link MLPClassifier} by optimising an
 * {@link MLPClassifierObjective} with the gradient-descent {@link Training} loop.
 *
 * <p>Every call to {@link #train} builds brand-new model data via {@code MLPClassifierData.create}.
 * The single {@link SplittableRandom} held by this trainer is handed to every such call.
 */
public class MLPClassifierTrainer
implements ClassifierTrainer {
    private final int numberOfClasses;
    private final MLPClassifierTrainConfig trainConfig;
    private final SplittableRandom random;
    private final ProgressTracker progressTracker;
    private final LogLevel messageLogLevel;
    private final TerminationFlag terminationFlag;
    private final int concurrency;

    /**
     * Creates a trainer.
     *
     * @param numberOfClasses  number of classes, forwarded to model creation and class-weight initialisation
     * @param trainConfig      training hyper-parameters (hidden layer sizes, penalty, focus weight,
     *                         class weights, batch size, and whatever {@link Training} reads from it)
     * @param randomSeed       seed for the trainer's random source; when empty, a seed is drawn from a
     *                         freshly constructed, unseeded {@link SplittableRandom}
     * @param progressTracker  progress tracker forwarded to {@link Training}
     * @param messageLogLevel  log level forwarded to {@link Training}
     * @param terminationFlag  termination flag forwarded to {@link Training}
     * @param concurrency      concurrency used when running {@link Training#train}
     */
    public MLPClassifierTrainer(int numberOfClasses, MLPClassifierTrainConfig trainConfig, Optional<Long> randomSeed, ProgressTracker progressTracker, LogLevel messageLogLevel, TerminationFlag terminationFlag, int concurrency) {
        // Fall back to a non-deterministic seed only when the caller did not supply one.
        long seed = randomSeed.orElseGet(() -> new SplittableRandom().nextLong());

        this.numberOfClasses = numberOfClasses;
        this.trainConfig = trainConfig;
        this.random = new SplittableRandom(seed);
        this.progressTracker = progressTracker;
        this.messageLogLevel = messageLogLevel;
        this.terminationFlag = terminationFlag;
        this.concurrency = concurrency;
    }

    /**
     * Trains a new MLP classifier on the given training subset.
     *
     * @param features  feature source; its {@code featureDimension()} sizes the model data
     * @param labels    labels passed to the objective
     * @param trainSet  ids of the training examples, consumed in batches of {@code trainConfig.batchSize()}
     * @return the classifier whose underlying data was optimised in place by the training loop
     */
    @Override
    public MLPClassifier train(Features features, HugeIntArray labels, ReadOnlyHugeLongArray trainSet) {
        MLPClassifierData modelData = MLPClassifierData.create(
            this.numberOfClasses,
            features.featureDimension(),
            this.trainConfig.hiddenLayerSizes(),
            this.random
        );
        MLPClassifier model = new MLPClassifier(modelData);

        MLPClassifierObjective lossObjective = new MLPClassifierObjective(
            model,
            features,
            labels,
            this.trainConfig.penalty(),
            this.trainConfig.focusWeight(),
            this.trainConfig.initializeClassWeights(this.numberOfClasses)
        );

        Training optimiser = new Training(
            this.trainConfig,
            this.progressTracker,
            this.messageLogLevel,
            trainSet.size(),
            this.terminationFlag
        );

        // A supplier rather than a single queue: Training may request a fresh batch queue as needed.
        Supplier<BatchQueue> batchQueues =
            () -> BatchQueue.fromArray(trainSet, (int) this.trainConfig.batchSize());

        optimiser.train(lossObjective, batchQueues, this.concurrency);
        return model;
    }
}

