/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.regression;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.Trainable;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.contants.TribuoOutputType;
import org.opensearch.ml.engine.utils.ModelSerDeSer;
import org.opensearch.ml.engine.utils.TribuoUtil;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.sgd.LabelObjective;
import org.tribuo.classification.sgd.linear.LinearSGDTrainer;
import org.tribuo.classification.sgd.objectives.Hinge;
import org.tribuo.classification.sgd.objectives.LogMulticlass;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.optimisers.AdaDelta;
import org.tribuo.math.optimisers.AdaGrad;
import org.tribuo.math.optimisers.Adam;
import org.tribuo.math.optimisers.RMSProp;
import org.tribuo.math.optimisers.SGD;

@Function(value=FunctionName.LOGISTIC_REGRESSION)
public class LogisticRegression
implements Trainable,
Predictable {
    public static final String VERSION = "1.0.0";
    private static final LogisticRegressionParams.ObjectiveType DEFAULT_OBJECTIVE_TYPE = LogisticRegressionParams.ObjectiveType.LOGMULTICLASS;
    private static final LogisticRegressionParams.OptimizerType DEFAULT_OPTIMIZER_TYPE = LogisticRegressionParams.OptimizerType.ADA_GRAD;
    private static final LogisticRegressionParams.MomentumType DEFAULT_MOMENTUM_TYPE = LogisticRegressionParams.MomentumType.STANDARD;
    private static final double DEFAULT_LEARNING_RATE = 1.0;
    private static final double DEFAULT_EPSILON = 0.1;
    private static final int DEFAULT_EPOCHS = 5;
    private static final int DEFAULT_LOGGING_INTERVAL = 1000;
    private static final int DEFAULT_BATCH_SIZE = 1;
    private static final Long DEFAULT_SEED = 12345L;
    private static final double DEFAULT_MOMENTUM_FACTOR = 0.0;
    private static final double DEFAULT_BETA1 = 0.9;
    private static final double DEFAULT_BETA2 = 0.99;
    private static final double DEFAULT_DECAY_RATE = 0.9;
    private int epochs;
    private int loggingInterval;
    private int minibatchSize;
    private long seed;
    private LogisticRegressionParams parameters;
    private StochasticGradientOptimiser optimiser;
    private LabelObjective objective;
    private Model<Label> classificationModel;

    public LogisticRegression(MLAlgoParams parameters) {
        this.parameters = parameters == null ? LogisticRegressionParams.builder().build() : (LogisticRegressionParams)parameters;
        this.validateParameters();
        this.createObjective();
        this.createOptimiser();
    }

    private void validateParameters() {
        if (this.parameters.getLearningRate() != null && this.parameters.getLearningRate() < 0.0) {
            throw new IllegalArgumentException("Learning rate should not be negative.");
        }
        if (this.parameters.getEpsilon() != null && this.parameters.getEpsilon() < 0.0) {
            throw new IllegalArgumentException("Epsilon should not be negative.");
        }
        if (this.parameters.getEpochs() != null && this.parameters.getEpochs() < 0) {
            throw new IllegalArgumentException("Epochs should not be negative.");
        }
        if (this.parameters.getBatchSize() != null && this.parameters.getBatchSize() < 0) {
            throw new IllegalArgumentException("MiniBatchSize should not be negative.");
        }
        if (this.parameters.getLoggingInterval() != null && this.parameters.getLoggingInterval() < -1) {
            throw new IllegalArgumentException("Invalid Logging intervals");
        }
        this.epochs = Optional.ofNullable(this.parameters.getEpochs()).orElse(5);
        this.loggingInterval = Optional.ofNullable(this.parameters.getLoggingInterval()).orElse(1000);
        this.minibatchSize = Optional.ofNullable(this.parameters.getBatchSize()).orElse(1);
        this.seed = Optional.ofNullable(this.parameters.getSeed()).orElse(DEFAULT_SEED);
    }

    private void createObjective() {
        LogisticRegressionParams.ObjectiveType objectiveType = Optional.ofNullable(this.parameters.getObjectiveType()).orElse(DEFAULT_OBJECTIVE_TYPE);
        switch (objectiveType) {
            case HINGE: {
                this.objective = new Hinge();
                break;
            }
            default: {
                this.objective = new LogMulticlass();
            }
        }
    }

    private void createOptimiser() {
        SGD.Momentum momentum;
        LogisticRegressionParams.OptimizerType optimizerType = Optional.ofNullable(this.parameters.getOptimizerType()).orElse(DEFAULT_OPTIMIZER_TYPE);
        Double learningRate = Optional.ofNullable(this.parameters.getLearningRate()).orElse(1.0);
        Double epsilon = Optional.ofNullable(this.parameters.getEpsilon()).orElse(0.1);
        Double momentumFactor = Optional.ofNullable(this.parameters.getMomentumFactor()).orElse(0.0);
        LogisticRegressionParams.MomentumType momentumType = Optional.ofNullable(this.parameters.getMomentumType()).orElse(DEFAULT_MOMENTUM_TYPE);
        Double beta1 = Optional.ofNullable(this.parameters.getBeta1()).orElse(0.9);
        Double beta2 = Optional.ofNullable(this.parameters.getBeta2()).orElse(0.99);
        Double decayRate = Optional.ofNullable(this.parameters.getDecayRate()).orElse(0.9);
        switch (momentumType) {
            case NESTEROV: {
                momentum = SGD.Momentum.NESTEROV;
                break;
            }
            default: {
                momentum = SGD.Momentum.STANDARD;
            }
        }
        switch (optimizerType) {
            case LINEAR_DECAY_SGD: {
                this.optimiser = SGD.getLinearDecaySGD((double)learningRate, (double)momentumFactor, (SGD.Momentum)momentum);
                break;
            }
            case SQRT_DECAY_SGD: {
                this.optimiser = SGD.getSqrtDecaySGD((double)learningRate, (double)momentumFactor, (SGD.Momentum)momentum);
                break;
            }
            case ADA_DELTA: {
                this.optimiser = new AdaDelta(momentumFactor.doubleValue(), epsilon.doubleValue());
                break;
            }
            case ADAM: {
                this.optimiser = new Adam(learningRate.doubleValue(), beta1.doubleValue(), beta2.doubleValue(), epsilon.doubleValue());
                break;
            }
            case RMS_PROP: {
                this.optimiser = new RMSProp(learningRate.doubleValue(), momentumFactor.doubleValue(), epsilon.doubleValue(), decayRate.doubleValue());
                break;
            }
            case SIMPLE_SGD: {
                this.optimiser = SGD.getSimpleSGD((double)learningRate, (double)momentumFactor, (SGD.Momentum)momentum);
                break;
            }
            default: {
                this.optimiser = new AdaGrad(learningRate.doubleValue(), epsilon.doubleValue());
            }
        }
    }

    @Override
    public MLModel train(MLInput mlInput) {
        DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame();
        MutableDataset trainDataset = TribuoUtil.generateDatasetWithTarget(dataFrame, new LabelFactory(), "Logistic regression training data from OpenSearch", TribuoOutputType.LABEL, this.parameters.getTarget());
        LinearSGDTrainer logisticRegressionTrainer = new LinearSGDTrainer(this.objective, this.optimiser, this.epochs, this.loggingInterval, this.minibatchSize, this.seed);
        Model classificationModel = logisticRegressionTrainer.train(trainDataset);
        MLModel model = MLModel.builder().name(FunctionName.LOGISTIC_REGRESSION.name()).algorithm(FunctionName.LOGISTIC_REGRESSION).version(VERSION).content(ModelSerDeSer.serializeToBase64(classificationModel)).modelState(MLModelState.TRAINED).build();
        return model;
    }

    @Override
    public void initModel(MLModel model, Map<String, Object> params) {
        this.classificationModel = (Model)ModelSerDeSer.deserialize(model);
    }

    @Override
    public void close() {
        this.classificationModel = null;
    }

    @Override
    public MLOutput predict(MLInput mlInput) {
        DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame();
        MutableDataset predictionDataset = TribuoUtil.generateDataset(dataFrame, new LabelFactory(), "Logistic regression prediction data from OpenSearch", TribuoOutputType.LABEL);
        List predictions = this.classificationModel.predict(predictionDataset);
        ArrayList listPrediction = new ArrayList();
        predictions.forEach(e -> listPrediction.add(Collections.singletonMap("result", ((Label)e.getOutput()).getLabel())));
        return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(listPrediction)).build();
    }

    @Override
    public MLOutput predict(MLInput mlInput, MLModel model) {
        if (model == null) {
            throw new IllegalArgumentException("No model found for logistic regression prediction.");
        }
        this.classificationModel = (Model)ModelSerDeSer.deserialize(model);
        return this.predict(mlInput);
    }
}

