/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.linearregression;

import java.util.function.Supplier;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.gradientdescent.Training;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.RegressorTrainer;
import org.neo4j.gds.ml.models.linearregression.LinearRegressionObjective;
import org.neo4j.gds.ml.models.linearregression.LinearRegressionTrainConfig;
import org.neo4j.gds.ml.models.linearregression.LinearRegressor;
import org.neo4j.gds.termination.TerminationFlag;

/**
 * {@link RegressorTrainer} that fits a {@link LinearRegressor} by running the shared
 * {@link Training} loop over a {@link LinearRegressionObjective}.
 *
 * <p>All training settings (penalty, batch size and whatever else {@link Training} reads)
 * come from the supplied {@link LinearRegressionTrainConfig}.
 */
public final class LinearRegressionTrainer
implements RegressorTrainer {
    private final LinearRegressionTrainConfig trainConfig;
    private final int concurrency;
    private final TerminationFlag terminationFlag;
    private final ProgressTracker progressTracker;
    private final LogLevel messageLogLevel;

    /**
     * Creates a trainer; nothing is computed until {@link #train} is called.
     *
     * @param concurrency     number of workers handed to {@link Training#train}
     * @param config          training configuration (penalty, batch size, and settings read by {@link Training})
     * @param terminationFlag flag forwarded to {@link Training}
     * @param progressTracker tracker forwarded to {@link Training}
     * @param messageLogLevel log level forwarded to {@link Training}
     */
    public LinearRegressionTrainer(int concurrency, LinearRegressionTrainConfig config, TerminationFlag terminationFlag, ProgressTracker progressTracker, LogLevel messageLogLevel) {
        this.trainConfig = config;
        this.concurrency = concurrency;
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
        this.messageLogLevel = messageLogLevel;
    }

    /**
     * Fits a linear regression model on the rows selected by {@code trainSet}.
     *
     * @param features feature vectors, passed to the objective
     * @param targets  regression targets, passed to the objective
     * @param trainSet ids of the training examples; batched for each pass and whose size is reported to {@link Training}
     * @return a regressor wrapping the objective's model data after training finishes
     */
    @Override
    public LinearRegressor train(Features features, HugeDoubleArray targets, ReadOnlyHugeLongArray trainSet) {
        LinearRegressionObjective objective = new LinearRegressionObjective(features, targets, trainConfig.penalty());
        Supplier<BatchQueue> batchQueues = batchQueueSupplier(trainSet);
        Training gradientDescent = new Training(trainConfig, progressTracker, messageLogLevel, trainSet.size(), terminationFlag);
        gradientDescent.train(objective, batchQueues, concurrency);
        return new LinearRegressor(objective.modelData());
    }

    /**
     * Returns a supplier that builds a fresh {@link BatchQueue} over {@code trainSet} on each call.
     * The batch size is read from the config at call time, not captured up front.
     */
    private Supplier<BatchQueue> batchQueueSupplier(ReadOnlyHugeLongArray trainSet) {
        return () -> BatchQueue.fromArray(trainSet, (int) trainConfig.batchSize());
    }
}

