/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.linearregression;

import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.EWiseAddMatrixScalar;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.models.Regressor;
import org.neo4j.gds.ml.models.linearregression.LinearRegressionData;

/**
 * {@link Regressor} that evaluates a linear model held in {@link LinearRegressionData}:
 * the prediction is the weighted sum of the features plus a bias term.
 */
public class LinearRegressor
implements Regressor {
    private final LinearRegressionData data;

    /**
     * @param data the model weights and bias used for prediction; kept by reference, not copied
     */
    public LinearRegressor(LinearRegressionData data) {
        this.data = data;
    }

    /**
     * Predicts the target value for a single feature vector.
     *
     * <p>Only the first {@code data().featureDimension()} entries of {@code features} are read;
     * an array shorter than that results in an {@link ArrayIndexOutOfBoundsException}.
     *
     * @param features feature values of the example to predict
     * @return {@code sum_i weights[i] * features[i] + bias}
     */
    @Override
    public double predict(double[] features) {
        Matrix weights = this.data.weights().data();
        // The loop bound does not change during the loop; read it once.
        int featureDimension = this.data.featureDimension();
        double prediction = 0.0;
        for (int i = 0; i < featureDimension; ++i) {
            prediction += weights.dataAt(i) * features[i];
        }
        return prediction + this.data.bias().data().value();
    }

    /**
     * Builds the computation-graph expression producing predictions for a batch of features.
     *
     * <p>NOTE(review): presumably one example per row of {@code features} — confirm against callers.
     *
     * @param features variable holding the feature matrix
     * @return a variable that multiplies {@code features} by the transposed weights and then adds
     *     the bias through {@link EWiseAddMatrixScalar}
     */
    Variable<Matrix> predictionsVariable(Variable<Matrix> features) {
        // Pass the typed variable straight through; the former raw (Variable) cast was a
        // decompilation artifact that only produced an unchecked conversion.
        MatrixMultiplyWithTransposedSecondOperand weightedFeatures =
            new MatrixMultiplyWithTransposedSecondOperand(features, this.data.weights());
        return new EWiseAddMatrixScalar(weightedFeatures, this.data.bias());
    }

    /**
     * @return the model data this regressor was constructed with
     */
    @Override
    public LinearRegressionData data() {
        return this.data;
    }
}

