/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.logisticregression;

import java.util.function.LongUnaryOperator;
import java.util.function.Supplier;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.gradientdescent.Training;
import org.neo4j.gds.ml.models.ClassifierTrainer;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionClassifier;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionData;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionObjective;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
import org.neo4j.gds.termination.TerminationFlag;

/**
 * {@link ClassifierTrainer} that fits a {@link LogisticRegressionClassifier} by running
 * {@link Training} over batches of the training set, minimising a
 * {@link LogisticRegressionObjective}.
 */
public final class LogisticRegressionTrainer implements ClassifierTrainer {

    private final int concurrency;
    private final LogisticRegressionTrainConfig trainConfig;
    private final int numberOfClasses;
    private final boolean reduceClassCount;
    private final TerminationFlag terminationFlag;
    private final ProgressTracker progressTracker;
    private final LogLevel messageLogLevel;

    /**
     * Estimates the memory needed to train a logistic regression model.
     *
     * <p>The estimate is composed of the model data, the weight-update state used by
     * {@link Training}, and one computation graph per concurrently processed batch. The number of
     * concurrent graphs is the smaller of the available concurrency and the number of batches,
     * i.e. {@code ceil(trainSetSize / batchSize)}.
     *
     * @param isReduced                whether the reduced class-count model layout is estimated
     * @param numberOfClasses          number of target classes
     * @param featureDimension         range of possible feature dimensions
     * @param batchSize                number of training examples per batch
     * @param numberOfTrainingExamples maps the graph's node count to the training-set size
     * @return the memory estimation for a training run
     */
    public static MemoryEstimation memoryEstimation(
        boolean isReduced,
        int numberOfClasses,
        MemoryRange featureDimension,
        int batchSize,
        LongUnaryOperator numberOfTrainingExamples
    ) {
        return MemoryEstimations.builder("train logistic regression")
            .add("model data", LogisticRegressionData.memoryEstimation(isReduced, numberOfClasses, featureDimension))
            .add("update weights", Training.memoryEstimation(featureDimension, numberOfClasses))
            .perGraphDimension("computation graph", (dimensions, threads) -> {
                long trainSetSize = numberOfTrainingExamples.applyAsLong(dimensions.nodeCount());
                // Kept in floating point on purpose: this mirrors the established estimate exactly,
                // including its behaviour for degenerate batch sizes.
                double batchCount = Math.ceil((double) trainSetSize / (double) batchSize);
                // No more graphs than threads, and no more than there are batches to process.
                int concurrentGraphs = (int) Math.min((double) threads.intValue(), batchCount);
                MemoryRange perGraph = featureDimension.apply(
                    dim -> LogisticRegressionObjective.sizeOfBatchInBytes(isReduced, batchSize, (int) dim, numberOfClasses)
                );
                return perGraph.times((long) concurrentGraphs);
            })
            .build();
    }

    /**
     * Creates a trainer.
     *
     * @param concurrency      number of threads handed to {@link Training#train}
     * @param trainConfig      penalty, focus weight, class weights and batch size for training
     * @param numberOfClasses  number of target classes
     * @param reduceClassCount if {@code true}, model data is created with
     *                         {@code LogisticRegressionData.withReducedClassCount}, otherwise with
     *                         {@code LogisticRegressionData.standard}
     * @param terminationFlag  flag passed to {@link Training}
     * @param progressTracker  progress tracker passed to {@link Training}
     * @param messageLogLevel  log level passed to {@link Training}
     */
    public LogisticRegressionTrainer(int concurrency, LogisticRegressionTrainConfig trainConfig, int numberOfClasses, boolean reduceClassCount, TerminationFlag terminationFlag, ProgressTracker progressTracker, LogLevel messageLogLevel) {
        this.concurrency = concurrency;
        this.trainConfig = trainConfig;
        this.numberOfClasses = numberOfClasses;
        this.reduceClassCount = reduceClassCount;
        this.terminationFlag = terminationFlag;
        this.progressTracker = progressTracker;
        this.messageLogLevel = messageLogLevel;
    }

    /**
     * Trains a new classifier on the given training set.
     *
     * @param features features indexed by the ids contained in {@code trainSet}
     * @param labels   class labels indexed like {@code features}
     * @param trainSet ids of the examples to train on
     * @return the classifier whose data was updated in place by the training run
     */
    @Override
    public LogisticRegressionClassifier train(Features features, HugeIntArray labels, ReadOnlyHugeLongArray trainSet) {
        LogisticRegressionData modelData;
        if (reduceClassCount) {
            modelData = LogisticRegressionData.withReducedClassCount(features.featureDimension(), numberOfClasses);
        } else {
            modelData = LogisticRegressionData.standard(features.featureDimension(), numberOfClasses);
        }

        LogisticRegressionClassifier model = LogisticRegressionClassifier.from(modelData);
        LogisticRegressionObjective loss = new LogisticRegressionObjective(
            model,
            trainConfig.penalty(),
            features,
            labels,
            trainConfig.focusWeight(),
            trainConfig.initializeClassWeights(numberOfClasses)
        );

        Training gradientDescent = new Training(trainConfig, progressTracker, messageLogLevel, trainSet.size(), terminationFlag);
        // Training asks for a fresh queue of batches whenever it needs one, hence a supplier.
        Supplier<BatchQueue> batches = () -> BatchQueue.fromArray(trainSet, (int) trainConfig.batchSize());
        gradientDescent.train(loss, batches, concurrency);

        return model;
    }
}

