/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.logisticregression;

import java.io.Serializable;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.logisticregression.ImmutableLogisticRegressionData;

/**
 * Parameters of a logistic regression classifier: a weight matrix plus a bias vector.
 *
 * <p>When created through {@link #standard(int, int)}, the weight matrix has one row per class
 * and one column per feature, and the bias vector has one entry per class. When created through
 * {@link #withReducedClassCount(int, int)}, the last class gets no row, so both the matrix and
 * the bias vector have {@code numberOfClasses - 1} rows.
 * NOTE(review): presumably the omitted class acts as an implicit reference class during
 * training/prediction — confirm against the trainer and predictor.
 */
@ValueClass
public interface LogisticRegressionData extends Classifier.ClassifierData, Serializable {

    /**
     * The weight matrix. For data built by the factory methods of this interface, its first
     * dimension is the number of weight rows (classes, possibly reduced by one) and its second
     * dimension is the feature count.
     */
    Weights<Matrix> weights();

    /**
     * The bias vector. For data built by the factory methods of this interface, it has one entry
     * per weight row.
     */
    Weights<Vector> bias();

    @Override
    @Value.Derived
    default TrainingMethod trainerMethod() {
        return TrainingMethod.LogisticRegression;
    }

    /**
     * Returns the number of input features, taken from the second dimension of {@link #weights()}.
     */
    @Override
    @Value.Derived
    default int featureDimension() {
        return weights().dimension(1);
    }

    /**
     * Creates zero-initialised-by-construction model data with one weight row per class.
     *
     * @param featureCount    number of input features (weight matrix columns)
     * @param numberOfClasses number of classes; also stored as {@code numberOfClasses()}
     * @return new model data with {@code numberOfClasses} weight rows
     */
    static LogisticRegressionData standard(int featureCount, int numberOfClasses) {
        return create(featureCount, numberOfClasses, false);
    }

    /**
     * Creates model data that omits the weight row and bias entry of the last class.
     *
     * @param featureCount    number of input features (weight matrix columns)
     * @param numberOfClasses total number of classes; stored unchanged as {@code numberOfClasses()}
     * @return new model data with {@code numberOfClasses - 1} weight rows
     */
    static LogisticRegressionData withReducedClassCount(int featureCount, int numberOfClasses) {
        return create(featureCount, numberOfClasses, true);
    }

    /**
     * Estimates the memory footprint of the weight matrix and bias vector.
     *
     * @param isReduced        whether the last class is omitted (see {@link #withReducedClassCount})
     * @param numberOfClasses  total number of classes
     * @param featureDimension range of possible feature counts
     * @return an estimation with fixed entries for {@code "weights"} and {@code "bias"}
     */
    static MemoryEstimation memoryEstimation(
        boolean isReduced,
        int numberOfClasses,
        MemoryRange featureDimension
    ) {
        int rows = weightRows(numberOfClasses, isReduced);
        return MemoryEstimations.builder("Logistic regression model data")
            .fixed(
                "weights",
                featureDimension.apply(featureDim -> Weights.sizeInBytes(rows, Math.toIntExact(featureDim)))
            )
            .fixed("bias", Weights.sizeInBytes(rows, 1))
            .build();
    }

    /**
     * Returns a builder for the generated immutable implementation.
     */
    static ImmutableLogisticRegressionData.Builder builder() {
        return ImmutableLogisticRegressionData.builder();
    }

    /**
     * Allocates the weight matrix and bias vector and assembles the immutable data object.
     * Parameter order matches the public factories to avoid accidental swaps.
     */
    private static LogisticRegressionData create(int featureCount, int classCount, boolean skipLastClass) {
        int rows = weightRows(classCount, skipLastClass);
        Weights<Matrix> weights = Weights.ofMatrix(rows, featureCount);
        Weights<Vector> bias = new Weights<>(new Vector(rows));
        return ImmutableLogisticRegressionData.builder()
            .weights(weights)
            .numberOfClasses(classCount)
            .bias(bias)
            .build();
    }

    /**
     * Number of weight rows actually stored: one per class, or one fewer when the last class is
     * skipped. Shared by allocation and memory estimation so the two cannot disagree.
     */
    private static int weightRows(int classCount, boolean skipLastClass) {
        return skipLastClass ? classCount - 1 : classCount;
    }
}

