/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.logisticregression;

import java.util.List;
import java.util.PrimitiveIterator;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.functions.Constant;
import org.neo4j.gds.ml.core.functions.ConstantScale;
import org.neo4j.gds.ml.core.functions.ElementSum;
import org.neo4j.gds.ml.core.functions.L2NormSquared;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.ReducedCrossEntropyLoss;
import org.neo4j.gds.ml.core.functions.ReducedFocalLoss;
import org.neo4j.gds.ml.core.functions.Softmax;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.ml.gradientdescent.Objective;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionClassifier;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionData;

/**
 * Gradient-descent {@link Objective} for a {@link LogisticRegressionClassifier}.
 *
 * <p>The loss of a batch is the sum of an unpenalized classification loss and an L2 penalty on
 * the model weights (the bias is not penalized). The classification loss is a
 * {@link ReducedCrossEntropyLoss} when {@code focusWeight == 0.0}, otherwise a
 * {@link ReducedFocalLoss} using {@code focusWeight}.
 */
public class LogisticRegressionObjective implements Objective<LogisticRegressionData> {
    private final LogisticRegressionClassifier classifier;
    // L2 regularization strength; rescaled per batch in penaltyForBatch.
    private final double penalty;
    private final Features features;
    // Labels looked up by element id (see batchLabelVector).
    private final HugeIntArray labels;
    // 0.0 selects cross-entropy loss; any other value selects focal loss with this weight.
    private final double focusWeight;
    // Forwarded unchanged to the loss function.
    private final double[] classWeights;

    /**
     * Estimates the memory, in bytes, needed by the computation graph of a single batch.
     *
     * <p>Two graphs are estimated: one used when training an epoch and one used when only
     * evaluating the loss. The larger of the two is returned.
     *
     * @param isReduced        if {@code true}, {@code numberOfClasses - 1} is used when sizing the
     *                         weight gradient and the weighted features
     * @param batchSize        number of elements in a batch
     * @param numberOfFeatures number of features per element
     * @param numberOfClasses  number of classes
     * @return the larger of the train-epoch and loss-evaluation graph size estimates
     */
    public static long sizeOfBatchInBytes(boolean isReduced, int batchSize, int numberOfFeatures, int numberOfClasses) {
        int normalizedNumberOfClasses = isReduced ? numberOfClasses - 1 : numberOfClasses;

        long batchLocalWeightGradient = Weights.sizeInBytes(normalizedNumberOfClasses, numberOfFeatures);
        long targets = Matrix.sizeInBytes(batchSize, 1);
        long weightedFeatures = MatrixMultiplyWithTransposedSecondOperand.sizeInBytes(batchSize, normalizedNumberOfClasses);
        long softMax = Softmax.sizeInBytes(batchSize, numberOfClasses);
        long unpenalizedLoss = ReducedCrossEntropyLoss.sizeInBytes();
        // The L2 norm, its ConstantScale and the final ElementSum are all scalar-valued in loss(),
        // so the same L2NormSquared estimate is reused for each of them.
        long l2norm = L2NormSquared.sizeInBytesOfApply();
        long constantScale = l2norm;
        long elementSum = l2norm;

        long sizeOfPredictionsVariableInBytes = LogisticRegressionClassifier.sizeOfPredictionsVariableInBytes(
            batchSize,
            numberOfFeatures,
            numberOfClasses,
            normalizedNumberOfClasses
        );

        // The train-epoch estimate counts the loss-related scalars twice and additionally
        // includes a batch-local weight gradient.
        long sizeOfComputationGraphForTrainEpoch = targets
            + weightedFeatures
            + softMax
            + 2L * unpenalizedLoss
            + 2L * l2norm
            + 2L * constantScale
            + 2L * elementSum
            + sizeOfPredictionsVariableInBytes
            + batchLocalWeightGradient;
        long sizeOfComputationGraphForEvaluateLoss = targets
            + weightedFeatures
            + softMax
            + unpenalizedLoss
            + l2norm
            + constantScale
            + elementSum
            + sizeOfPredictionsVariableInBytes;

        return Math.max(sizeOfComputationGraphForTrainEpoch, sizeOfComputationGraphForEvaluateLoss);
    }

    /**
     * Creates the objective.
     *
     * @param classifier   classifier whose model data is optimized
     * @param penalty      L2 regularization strength applied to the model weights
     * @param features     feature source used to build batch feature matrices; must be non-empty
     * @param labels       labels indexed by element id
     * @param focusWeight  {@code 0.0} for cross-entropy loss, otherwise the focal-loss focus weight
     * @param classWeights class weights forwarded to the loss function
     */
    public LogisticRegressionObjective(LogisticRegressionClassifier classifier, double penalty, Features features, HugeIntArray labels, double focusWeight, double[] classWeights) {
        this.classifier = classifier;
        this.penalty = penalty;
        this.features = features;
        this.labels = labels;
        this.focusWeight = focusWeight;
        this.classWeights = classWeights;
        assert features.size() > 0L;
    }

    /**
     * Returns the trainable parameters: the weights followed by the bias.
     */
    @Override
    public List<Weights<? extends Tensor<?>>> weights() {
        LogisticRegressionData data = this.classifier.data();
        return List.of(data.weights(), data.bias());
    }

    /**
     * Builds the total loss for a batch: unpenalized loss plus the batch share of the L2 penalty.
     *
     * @param batch     the batch to compute the loss for
     * @param trainSize presumably the size of the training set the batch is drawn from; used to
     *                  scale the penalty (see {@link #penaltyForBatch(Batch, long)})
     * @return a scalar variable holding the sum of both terms
     */
    @Override
    public Variable<Scalar> loss(Batch batch, long trainSize) {
        ReducedCrossEntropyLoss unpenalizedLoss = this.crossEntropyLoss(batch);
        ConstantScale<Scalar> penaltyVariable = this.penaltyForBatch(batch, trainSize);
        return new ElementSum(List.of(unpenalizedLoss, penaltyVariable));
    }

    /**
     * L2 penalty on the model weights (bias excluded), scaled by
     * {@code batch.size() * penalty / trainSize}.
     *
     * <p>If the batches of an epoch partition a training set of size {@code trainSize}, the
     * per-batch penalties add up to {@code penalty * ||weights||^2} over the epoch.
     */
    ConstantScale<Scalar> penaltyForBatch(Batch batch, long trainSize) {
        Variable<Scalar> weightsNormSquared = new L2NormSquared<>(this.modelData().weights());
        double scale = (double) batch.size() * this.penalty / (double) trainSize;
        return new ConstantScale<>(weightsNormSquared, scale);
    }

    /**
     * Builds the unpenalized classification loss for a batch.
     *
     * @return a {@link ReducedCrossEntropyLoss} when {@code focusWeight == 0.0},
     *         otherwise a {@link ReducedFocalLoss} using {@code focusWeight}
     */
    ReducedCrossEntropyLoss crossEntropyLoss(Batch batch) {
        Constant<Vector> batchLabels = this.batchLabelVector(batch);
        Constant<Matrix> batchFeatures = Objective.batchFeatureMatrix(batch, this.features);
        Variable<Matrix> predictions = this.classifier.predictionsVariable(batchFeatures);
        LogisticRegressionData data = this.classifier.data();
        if (this.focusWeight == 0.0) {
            return new ReducedCrossEntropyLoss(predictions, data.weights(), data.bias(), batchFeatures, batchLabels, this.classWeights);
        }
        return new ReducedFocalLoss(predictions, data.weights(), data.bias(), batchFeatures, batchLabels, this.focusWeight, this.classWeights);
    }

    /**
     * Returns the classifier's current model data.
     */
    @Override
    public LogisticRegressionData modelData() {
        return this.classifier.data();
    }

    /**
     * Collects the labels of the batch elements into a constant vector.
     *
     * <p>Entry {@code i} holds, as a {@code double}, the label of the {@code i}-th element id
     * produced by {@code batch.elementIds()}.
     */
    Constant<Vector> batchLabelVector(Batch batch) {
        Vector batchedTargets = new Vector(batch.size());
        PrimitiveIterator.OfLong elementIds = batch.elementIds();
        for (int offset = 0; elementIds.hasNext(); offset++) {
            long elementId = elementIds.nextLong();
            batchedTargets.setDataAt(offset, (double) this.labels.get(elementId));
        }
        return new Constant<>(batchedTargets);
    }
}

