/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.logisticregression;

import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.functions.Constant;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.MatrixVectorSum;
import org.neo4j.gds.ml.core.functions.ReducedSoftmax;
import org.neo4j.gds.ml.core.functions.Sigmoid;
import org.neo4j.gds.ml.core.functions.Softmax;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.ml.gradientdescent.Objective;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionData;

public final class LogisticRegressionClassifier
implements Classifier {
    private final LogisticRegressionData data;
    private final LogisticRegressionPredictionStrategy predictionStrategy;

    private LogisticRegressionClassifier(LogisticRegressionData data, LogisticRegressionPredictionStrategy predictionStrategy) {
        this.data = data;
        this.predictionStrategy = predictionStrategy;
    }

    public static LogisticRegressionClassifier from(LogisticRegressionData data) {
        LogisticRegressionPredictionStrategy predictionStrategy = data.numberOfClasses() == 2 && ((Matrix)data.weights().data()).rows() == 1 ? LogisticRegressionPredictionStrategy.binary() : LogisticRegressionPredictionStrategy.multiClass();
        return new LogisticRegressionClassifier(data, predictionStrategy);
    }

    public static long sizeOfPredictionsVariableInBytes(int batchSize, int numberOfFeatures, int numberOfClasses, int normalizedNumberOfClasses) {
        int[] dimensionsOfFirstMatrix = Dimensions.matrix((int)batchSize, (int)numberOfFeatures);
        long softmaxSize = numberOfClasses == normalizedNumberOfClasses ? Softmax.sizeInBytes((int)batchSize, (int)numberOfClasses) : ReducedSoftmax.sizeInBytes((int)batchSize, (int)numberOfClasses);
        return LogisticRegressionClassifier.sizeOfFeatureExtractorsInBytes(numberOfFeatures) + Constant.sizeInBytes((int[])dimensionsOfFirstMatrix) + MatrixMultiplyWithTransposedSecondOperand.sizeInBytes((int)batchSize, (int)normalizedNumberOfClasses) + softmaxSize;
    }

    public static MemoryRange runtimeOverheadMemoryEstimation(int batchSize, int featureDimension, int numberOfClasses, boolean isReduced) {
        int normalizedNumberOfClasses = isReduced ? numberOfClasses - 1 : numberOfClasses;
        return MemoryRange.of((long)LogisticRegressionClassifier.sizeOfPredictionsVariableInBytes(batchSize, featureDimension, numberOfClasses, normalizedNumberOfClasses));
    }

    private static long sizeOfFeatureExtractorsInBytes(int numberOfFeatures) {
        return FeatureExtraction.memoryUsageInBytes((int)numberOfFeatures);
    }

    @Override
    public double[] predictProbabilities(double[] features) {
        return this.predictionStrategy.predictProbabilities(features, this);
    }

    @Override
    public Matrix predictProbabilities(Batch batch, Features features) {
        ComputationContext ctx = new ComputationContext();
        return (Matrix)ctx.forward(this.predictionsVariable(Objective.batchFeatureMatrix(batch, features)));
    }

    Variable<Matrix> predictionsVariable(Constant<Matrix> batchFeatures) {
        Weights<Matrix> weights = this.data.weights();
        MatrixMultiplyWithTransposedSecondOperand weightedFeatures = MatrixMultiplyWithTransposedSecondOperand.of(batchFeatures, weights);
        MatrixVectorSum softmaxInput = new MatrixVectorSum((Variable)weightedFeatures, this.data.bias());
        return ((Matrix)weights.data()).rows() == this.numberOfClasses() ? new Softmax((Variable)softmaxInput) : new ReducedSoftmax((Variable)softmaxInput);
    }

    @Override
    public LogisticRegressionData data() {
        return this.data;
    }

    static interface LogisticRegressionPredictionStrategy {
        public double[] predictProbabilities(double[] var1, LogisticRegressionClassifier var2);

        public static LogisticRegressionPredictionStrategy binary() {
            return (features, classifier) -> {
                double affinity = 0.0;
                Matrix weights = (Matrix)classifier.data().weights().data();
                for (int i = 0; i < features.length; ++i) {
                    affinity += weights.dataAt(i) * features[i];
                }
                double sigmoid = Sigmoid.sigmoid((double)(affinity + ((Vector)classifier.data().bias().data()).dataAt(0)));
                return new double[]{sigmoid, 1.0 - sigmoid};
            };
        }

        public static LogisticRegressionPredictionStrategy multiClass() {
            return (features, classifier) -> {
                ComputationContext ctx = new ComputationContext();
                Constant featuresVariable = Constant.matrix((double[])features, (int)1, (int)features.length);
                Variable<Matrix> predictionsVariable = classifier.predictionsVariable((Constant<Matrix>)featuresVariable);
                return ((Matrix)ctx.forward(predictionsVariable)).data();
            };
        }
    }
}

