/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models;

import java.util.Optional;
import java.util.function.LongUnaryOperator;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.metrics.ModelSpecificMetricsHandler;
import org.neo4j.gds.ml.models.ClassifierTrainer;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainer;
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainConfig;
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainer;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainer;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig;
import org.neo4j.gds.termination.TerminationFlag;

/**
 * Static factory that selects a {@link ClassifierTrainer} implementation, or that trainer's
 * memory estimation, based on {@code config.method()}.
 *
 * <p>Three classifier methods are handled: logistic regression, random forest classification and
 * MLP classification. For each one, {@code config} is cast to the method-specific configuration
 * type expected by the chosen trainer.
 */
public final class ClassifierTrainerFactory {

    /** Message of the exception thrown when {@code config.method()} is not a handled classifier method. */
    private static final String UNKNOWN_METHOD_MESSAGE = "No such training method.";

    private ClassifierTrainerFactory() {
        // Static-only holder; never instantiated.
    }

    /**
     * Instantiates the classifier trainer that matches {@code config.method()}.
     *
     * @param config           trainer configuration; cast to the subtype matching its method
     * @param numberOfClasses  forwarded to every trainer
     * @param terminationFlag  forwarded to every trainer
     * @param progressTracker  forwarded to every trainer
     * @param messageLogLevel  forwarded to every trainer
     * @param concurrency      forwarded to every trainer
     * @param randomSeed       forwarded to the random forest and MLP trainers; unused for logistic regression
     * @param reduceClassCount forwarded only to the logistic regression trainer
     * @param metricsHandler   forwarded only to the random forest trainer
     * @return a newly constructed trainer for the configured method
     * @throws IllegalStateException if {@code config.method()} is not one of the three handled methods
     * @throws ClassCastException    if {@code config} is not of the configuration subtype for its method
     */
    public static ClassifierTrainer create(TrainerConfig config, int numberOfClasses, TerminationFlag terminationFlag, ProgressTracker progressTracker, LogLevel messageLogLevel, int concurrency, Optional<Long> randomSeed, boolean reduceClassCount, ModelSpecificMetricsHandler metricsHandler) {
        switch (config.method()) {
            case LogisticRegression: {
                LogisticRegressionTrainConfig lrConfig = (LogisticRegressionTrainConfig) config;
                return new LogisticRegressionTrainer(
                    concurrency,
                    lrConfig,
                    numberOfClasses,
                    reduceClassCount,
                    terminationFlag,
                    progressTracker,
                    messageLogLevel
                );
            }
            case RandomForestClassification: {
                RandomForestClassifierTrainerConfig rfConfig = (RandomForestClassifierTrainerConfig) config;
                return new RandomForestClassifierTrainer(
                    concurrency,
                    numberOfClasses,
                    rfConfig,
                    randomSeed,
                    progressTracker,
                    messageLogLevel,
                    terminationFlag,
                    metricsHandler
                );
            }
            case MLPClassification: {
                MLPClassifierTrainConfig mlpConfig = (MLPClassifierTrainConfig) config;
                return new MLPClassifierTrainer(
                    numberOfClasses,
                    mlpConfig,
                    randomSeed,
                    progressTracker,
                    messageLogLevel,
                    terminationFlag,
                    concurrency
                );
            }
            default:
                throw new IllegalStateException(UNKNOWN_METHOD_MESSAGE);
        }
    }

    /**
     * Returns the memory estimation for the trainer that {@code config.method()} would create.
     *
     * @param config                   trainer configuration; cast to the subtype matching its method
     * @param numberOfTrainingExamples forwarded to the logistic regression and random forest estimations
     * @param numberOfClasses          forwarded to the logistic regression and random forest estimations
     * @param featureDimension         forwarded to the logistic regression and random forest estimations
     * @param isReduced                consulted only by the logistic regression estimation
     * @return the method-specific estimation; for MLP classification an empty estimation
     * @throws IllegalStateException if {@code config.method()} is not one of the three handled methods
     * @throws ClassCastException    if {@code config} is not of the configuration subtype for its method
     */
    public static MemoryEstimation memoryEstimation(TrainerConfig config, LongUnaryOperator numberOfTrainingExamples, int numberOfClasses, MemoryRange featureDimension, boolean isReduced) {
        switch (config.method()) {
            case LogisticRegression: {
                LogisticRegressionTrainConfig lrConfig = (LogisticRegressionTrainConfig) config;
                return LogisticRegressionTrainer.memoryEstimation(
                    isReduced,
                    numberOfClasses,
                    featureDimension,
                    lrConfig.batchSize(),
                    numberOfTrainingExamples
                );
            }
            case RandomForestClassification: {
                RandomForestClassifierTrainerConfig rfConfig = (RandomForestClassifierTrainerConfig) config;
                return RandomForestClassifierTrainer.memoryEstimation(
                    numberOfTrainingExamples,
                    numberOfClasses,
                    featureDimension,
                    rfConfig
                );
            }
            case MLPClassification:
                // NOTE(review): MLP memory is not estimated here; an empty estimation is returned.
                // Confirm this is intentional rather than a missing estimator.
                return MemoryEstimations.empty();
            default:
                throw new IllegalStateException(UNKNOWN_METHOD_MESSAGE);
        }
    }
}

