/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.randomforest;

import java.util.List;
import java.util.PrimitiveIterator;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.decisiontree.DecisionTreePredictor;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.randomforest.ImmutableRandomForestClassifierData;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierData;

/**
 * Classifier that aggregates the predictions of a list of decision trees by voting.
 *
 * <p>Every tree predicts a single class index for a feature vector. The fraction of
 * trees that voted for a class is reported as the probability of that class.
 */
public class RandomForestClassifier implements Classifier {
    private final RandomForestClassifierData data;

    /**
     * Creates a classifier from its individual parts by wrapping them in an
     * {@link ImmutableRandomForestClassifierData}.
     *
     * @param decisionTrees    the trees whose votes are aggregated
     * @param numberOfClasses  the number of classes stored in the model data
     * @param featureDimension the feature dimension stored in the model data
     */
    public RandomForestClassifier(List<DecisionTreePredictor<Integer>> decisionTrees, int numberOfClasses, int featureDimension) {
        this(ImmutableRandomForestClassifierData.of(featureDimension, numberOfClasses, decisionTrees));
    }

    /**
     * Creates a classifier backed by the given model data.
     *
     * @param data the model data; stored as-is, without copying
     */
    public RandomForestClassifier(RandomForestClassifierData data) {
        this.data = data;
    }

    /**
     * Estimates the memory needed at prediction time: one instance of this class plus
     * one {@code double[]} and one {@code int[]}, each of length {@code numberOfClasses}.
     *
     * @param numberOfClasses the length used for both arrays in the estimate
     * @return the summed memory estimate
     */
    public static MemoryRange runtimeOverheadMemoryEstimation(int numberOfClasses) {
        return MemoryRange
            .of((long)MemoryUsage.sizeOfInstance(RandomForestClassifier.class))
            .add(MemoryUsage.sizeOfDoubleArray((long)numberOfClasses))
            .add(MemoryUsage.sizeOfIntArray((long)numberOfClasses));
    }

    /**
     * @return the model data this classifier was constructed with
     */
    @Override
    public Classifier.ClassifierData data() {
        return this.data;
    }

    /**
     * Computes per-class probabilities for a single feature vector.
     *
     * @param features the feature vector handed to every decision tree
     * @return an array of length {@code numberOfClasses()} where entry {@code i} is the
     *         number of trees that predicted class {@code i} divided by the total number of trees
     */
    @Override
    public double[] predictProbabilities(double[] features) {
        int[] votes = this.gatherTreePredictions(features);
        // Held as a double so the division below is a floating-point division.
        double treeCount = this.data.decisionTrees().size();
        double[] result = new double[this.numberOfClasses()];
        int classIdx = 0;
        for (int votesForClass : votes) {
            result[classIdx++] = votesForClass / treeCount;
        }
        return result;
    }

    /**
     * Computes per-class probabilities for every element of a batch.
     *
     * @param batch    supplies the element ids to predict for, in iteration order
     * @param features maps an element id to its feature vector
     * @return a matrix with {@code batch.size()} rows and {@code numberOfClasses()} columns;
     *         row {@code k} holds the probabilities of the {@code k}-th id returned by the batch iterator
     */
    @Override
    public Matrix predictProbabilities(Batch batch, Features features) {
        Matrix result = new Matrix(batch.size(), this.numberOfClasses());
        PrimitiveIterator.OfLong ids = batch.elementIds();
        // Rows are filled in iteration order, independent of the id values themselves.
        for (int row = 0; ids.hasNext(); row++) {
            double[] rowProbabilities = this.predictProbabilities(features.get(ids.nextLong()));
            result.setRow(row, rowProbabilities);
        }
        return result;
    }

    /**
     * Asks every decision tree for its prediction and counts the votes per class.
     *
     * @param features the feature vector handed to every decision tree
     * @return an array of length {@code numberOfClasses()} where entry {@code i} is the
     *         number of trees that predicted class {@code i}
     */
    int[] gatherTreePredictions(double[] features) {
        int[] votes = new int[this.numberOfClasses()];
        for (DecisionTreePredictor<Integer> tree : this.data.decisionTrees()) {
            // The predicted class is used directly as the index into the vote counts.
            int predictedClass = tree.predict(features);
            votes[predictedClass]++;
        }
        return votes;
    }
}

