/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.linearregression;

import java.util.List;
import java.util.PrimitiveIterator;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.functions.Constant;
import org.neo4j.gds.ml.core.functions.ConstantScale;
import org.neo4j.gds.ml.core.functions.ElementSum;
import org.neo4j.gds.ml.core.functions.L2NormSquared;
import org.neo4j.gds.ml.core.functions.MeanSquareError;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.ml.gradientdescent.Objective;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.linearregression.LinearRegressionData;
import org.neo4j.gds.ml.models.linearregression.LinearRegressor;

/**
 * Training objective for linear regression.
 *
 * <p>For a batch, the loss is the {@link MeanSquareError} between the model predictions and the
 * batch targets, plus an L2 penalty ({@link L2NormSquared}) on the weight matrix. The bias is
 * returned by {@link #weights()} (so it is trained) but is not part of the penalty term.
 */
public class LinearRegressionObjective
implements Objective<LinearRegressionData> {
    private final Features features;
    private final HugeDoubleArray targets;
    private final LinearRegressionData modelData;
    private final double penalty;

    /**
     * Creates an objective whose model data is obtained from
     * {@code LinearRegressionData.of(features.featureDimension())}.
     *
     * @param features training features (presumably addressed by the same element ids as
     *                 {@code targets} — confirm against {@code Objective.batchFeatureMatrix})
     * @param targets  regression target per element id
     * @param penalty  L2 regularisation strength applied to the weight matrix
     */
    LinearRegressionObjective(Features features, HugeDoubleArray targets, double penalty) {
        this.features = features;
        this.targets = targets;
        this.modelData = LinearRegressionData.of(features.featureDimension());
        this.penalty = penalty;
    }

    /**
     * Returns the trainable parameters: the weight matrix followed by the bias.
     */
    @Override
    public List<Weights<? extends Tensor<?>>> weights() {
        return List.of(this.modelData.weights(), this.modelData.bias());
    }

    /**
     * Builds the loss variable for one batch: mean squared error of the predictions plus the
     * batch's share of the L2 weight penalty.
     *
     * @param batch     element ids making up this batch
     * @param trainSize total number of training elements, used to scale the penalty
     * @return a scalar variable summing the data loss and the penalty
     */
    @Override
    public Variable<Scalar> loss(Batch batch, long trainSize) {
        LinearRegressor regressor = new LinearRegressor(this.modelData);
        Constant<Matrix> batchFeatures = Objective.batchFeatureMatrix(batch, this.features);
        // Constant<Matrix> is already a Variable<Matrix>; no cast is needed.
        Variable<Matrix> predictions = regressor.predictionsVariable(batchFeatures);
        Constant<Vector> batchTargets = this.batchTargets(batch);
        return new ElementSum(List.of(
            new MeanSquareError(predictions, batchTargets),
            this.penaltyForBatch(batch, trainSize)
        ));
    }

    /**
     * L2 penalty on the weight matrix, scaled by {@code batch.size() * penalty / trainSize}.
     * If the batches of an epoch partition the training set (assumption — not enforced here),
     * the per-batch shares add up to {@code penalty * ||W||^2} over the epoch.
     */
    private Variable<Scalar> penaltyForBatch(Batch batch, long trainSize) {
        double scale = (double) batch.size() * this.penalty / (double) trainSize;
        return new ConstantScale<>(new L2NormSquared<>(this.modelData.weights()), scale);
    }

    /**
     * Gathers the targets of the batch's elements into a vector, in batch iteration order.
     */
    private Constant<Vector> batchTargets(Batch batch) {
        Vector batchedTargets = new Vector(batch.size());
        int offset = 0;
        PrimitiveIterator.OfLong elementIds = batch.elementIds();
        while (elementIds.hasNext()) {
            batchedTargets.setDataAt(offset++, this.targets.get(elementIds.nextLong()));
        }
        return new Constant<>(batchedTargets);
    }

    /**
     * Returns the model data being optimised by this objective.
     */
    @Override
    public LinearRegressionData modelData() {
        return this.modelData;
    }
}

