/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.mlp;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.SplittableRandom;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.mlp.ImmutableMLPClassifierData;

/**
 * Trainable parameters of a multilayer perceptron (MLP) classifier.
 *
 * <p>Layer {@code i} is described by {@code weights().get(i)}, a matrix with
 * {@code outputSize} rows and {@code inputSize} columns, and {@code biases().get(i)}, a
 * vector of length {@code outputSize}. When built via {@link #create}, the first layer
 * consumes {@code featureCount} inputs, and the last layer emits {@code classCount} values.
 */
@ValueClass
public interface MLPClassifierData
extends Classifier.ClassifierData,
Serializable {

    /** Weight matrices, one per layer, ordered from the layer nearest the input to the output layer. */
    List<Weights<Matrix>> weights();

    /** Bias vectors, one per layer, in the same order as {@link #weights()}. */
    List<Weights<Vector>> biases();

    /**
     * Returns the number of bias vectors plus one.
     *
     * <p>For data produced by {@link #create}, this counts the input layer, every hidden
     * layer and the output layer.
     */
    @Value.Derived
    default int depth() {
        return this.biases().size() + 1;
    }

    /** Returns the number of classes, read as the length of the last (output) bias vector. */
    @Override
    @Value.Derived
    default int numberOfClasses() {
        return this.biases().get(this.biases().size() - 1).dimension(0);
    }

    /** Returns the expected feature dimension, read as dimension 1 (columns) of the first weight matrix. */
    @Override
    @Value.Derived
    default int featureDimension() {
        return this.weights().get(0).dimension(1);
    }

    @Override
    default TrainingMethod trainerMethod() {
        return TrainingMethod.MLPClassification;
    }

    /**
     * Creates randomly initialised MLP parameters.
     *
     * <p>The network topology is {@code featureCount -> hiddenLayerSizes... -> classCount}.
     * For every consecutive pair of sizes {@code (in, out)}, an {@code out x in} weight matrix
     * and a bias vector of length {@code out} are generated. Each one is seeded from a fresh
     * {@code random.nextLong()}, drawn in weights-then-bias order per layer. The same
     * {@code random} state therefore always yields the same parameters.
     *
     * @param classCount       number of output classes; must be positive
     * @param featureCount     number of input features; must be positive
     * @param hiddenLayerSizes sizes of the hidden layers, at least one; every entry must be positive
     * @param random           source of per-tensor seeds; advanced twice per layer
     * @return the freshly initialised classifier data
     * @throws IllegalArgumentException if {@code hiddenLayerSizes} is empty or any size is not positive
     */
    static MLPClassifierData create(int classCount, int featureCount, List<Integer> hiddenLayerSizes, SplittableRandom random) {
        if (hiddenLayerSizes.isEmpty()) {
            throw new IllegalArgumentException("hiddenLayerSizes must contain at least one layer size");
        }

        // Full chain of layer widths: input features, each hidden layer, then the output classes.
        int[] layerSizes = new int[hiddenLayerSizes.size() + 2];
        layerSizes[0] = requirePositive(featureCount, "featureCount");
        for (int i = 0; i < hiddenLayerSizes.size(); ++i) {
            layerSizes[i + 1] = requirePositive(hiddenLayerSizes.get(i), "hiddenLayerSizes[" + i + "]");
        }
        layerSizes[layerSizes.length - 1] = requirePositive(classCount, "classCount");

        int layerCount = layerSizes.length - 1;
        List<Weights<Matrix>> weights = new ArrayList<>(layerCount);
        List<Weights<Vector>> biases = new ArrayList<>(layerCount);
        for (int layer = 0; layer < layerCount; ++layer) {
            int inputSize = layerSizes[layer];
            int outputSize = layerSizes[layer + 1];
            // Seed order (weights, then bias, per layer) is part of the reproducibility contract.
            weights.add(generateWeights(outputSize, inputSize, random.nextLong()));
            biases.add(generateBias(outputSize, random.nextLong()));
        }
        return ImmutableMLPClassifierData.builder().weights(weights).biases(biases).build();
    }

    /**
     * Returns {@code value} unchanged if it is positive.
     *
     * @throws IllegalArgumentException otherwise, naming the offending parameter
     */
    private static int requirePositive(int value, String name) {
        if (value <= 0) {
            throw new IllegalArgumentException(name + " must be positive, but got " + value);
        }
        return value;
    }

    /**
     * Creates a {@code rows x cols} weight matrix with entries drawn uniformly from
     * {@code [-sqrt(2 / cols), sqrt(2 / cols))}, using a {@link Random} seeded with {@code randomSeed}.
     */
    private static Weights<Matrix> generateWeights(int rows, int cols, long randomSeed) {
        double weightBound = Math.sqrt(2.0 / (double) cols);
        // multiplyExact fails loudly instead of silently overflowing for huge layers.
        double[] data = new Random(randomSeed).doubles(Math.multiplyExact(rows, cols), -weightBound, weightBound).toArray();
        return new Weights<>(new Matrix(data, rows, cols));
    }

    /**
     * Creates a bias vector of length {@code dim} with entries drawn uniformly from
     * {@code [-sqrt(2 / dim), sqrt(2 / dim))}, using a {@link Random} seeded with {@code randomSeed}.
     *
     * <p>NOTE(review): the bound is scaled by the vector's own length rather than by the layer's
     * fan-in. This is kept as-is to preserve existing initialisations; confirm it is intended.
     */
    private static Weights<Vector> generateBias(int dim, long randomSeed) {
        double weightBound = Math.sqrt(2.0 / (double) dim);
        double[] data = new Random(randomSeed).doubles(dim, -weightBound, weightBound).toArray();
        return new Weights<>(new Vector(data));
    }

    /** Returns a builder for the generated immutable implementation. */
    static ImmutableMLPClassifierData.Builder builder() {
        return ImmutableMLPClassifierData.builder();
    }
}

