/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.mlp;

import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.functions.Constant;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.MatrixVectorSum;
import org.neo4j.gds.ml.core.functions.Relu;
import org.neo4j.gds.ml.core.functions.Softmax;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.gradientdescent.Objective;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.mlp.MLPClassifierData;

/**
 * {@link Classifier} that evaluates a layered computation graph built from the
 * weights and biases held in an {@link MLPClassifierData}.
 *
 * <p>Each prediction call creates its own {@link ComputationContext}; the only
 * state kept by this class is the final reference to the model data.
 */
public final class MLPClassifier
implements Classifier {
    private final MLPClassifierData data;

    /**
     * @param data model data supplying {@code depth()}, {@code weights()} and {@code biases()};
     *             stored by reference, not copied
     */
    public MLPClassifier(MLPClassifierData data) {
        this.data = data;
    }

    /**
     * Runs a forward pass for a single feature vector.
     *
     * @param features feature values; wrapped as a {@code 1 x features.length} matrix constant
     * @return the {@code data()} array of the matrix produced by the forward pass
     *         (the output of the final {@link Softmax})
     */
    @Override
    public double[] predictProbabilities(double[] features) {
        ComputationContext ctx = new ComputationContext();
        Constant<Matrix> featuresVariable = Constant.matrix(features, 1, features.length);
        Variable<Matrix> predictionsVariable = this.predictionsVariable(featuresVariable);
        // Cast retained: the declared return type of forward(...) is not visible in this file.
        return ((Matrix) ctx.forward(predictionsVariable)).data();
    }

    /**
     * Runs a forward pass for a batch of feature rows.
     *
     * @param batch    batch handed to {@link Objective#batchFeatureMatrix} to build the input matrix
     * @param features feature source handed to {@link Objective#batchFeatureMatrix}
     * @return the matrix produced by the forward pass (the output of the final {@link Softmax})
     */
    @Override
    public Matrix predictProbabilities(Batch batch, Features features) {
        ComputationContext ctx = new ComputationContext();
        Variable<Matrix> predictionsVariable = this.predictionsVariable(Objective.batchFeatureMatrix(batch, features));
        // Cast retained: the declared return type of forward(...) is not visible in this file.
        return (Matrix) ctx.forward(predictionsVariable);
    }

    /**
     * @return the model data this classifier was constructed with
     */
    @Override
    public MLPClassifierData data() {
        return this.data;
    }

    /**
     * Builds the prediction graph on top of the given input features.
     *
     * <p>For every index {@code i} in {@code [0, data.depth() - 1)} the current input is
     * multiplied with {@code data.weights().get(i)} (via
     * {@link MatrixMultiplyWithTransposedSecondOperand}), {@code data.biases().get(i)} is added
     * (via {@link MatrixVectorSum}) and the result is wrapped in a {@link Relu}. The output of
     * the last such layer, or the input itself when the loop does not run, is wrapped in a
     * {@link Softmax}.
     *
     * @param batchFeatures constant holding the input feature matrix
     * @return the {@link Softmax} variable at the top of the graph
     */
    // Raw Variable casts are retained at call sites whose generic signatures are not visible here.
    @SuppressWarnings({"rawtypes", "unchecked"})
    Variable<Matrix> predictionsVariable(Constant<Matrix> batchFeatures) {
        // Must be typed as the common supertype: it starts out as the Constant input and is
        // only replaced by a Relu once the first layer has been applied.
        Variable<Matrix> inputToNextLayer = batchFeatures;
        for (int i = 0; i < this.data.depth() - 1; ++i) {
            Variable<Matrix> outputFromPrevLayer = inputToNextLayer;
            Variable weighted = (Variable) MatrixMultiplyWithTransposedSecondOperand.of(
                outputFromPrevLayer,
                (Variable) this.data.weights().get(i)
            );
            Variable bias = (Variable) this.data.biases().get(i);
            Variable biased = new MatrixVectorSum(weighted, bias);
            // Second Relu argument is 0.0 as in the original; presumably the slope applied to
            // negative inputs -- TODO confirm against Relu's constructor documentation.
            inputToNextLayer = new Relu(biased, 0.0);
        }
        return new Softmax(inputToNextLayer);
    }
}

