/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.decisiontree;

import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.decisiontree.DecisionTreeTrainer;
import org.neo4j.gds.ml.decisiontree.DecisionTreeTrainerConfig;
import org.neo4j.gds.ml.decisiontree.FeatureBagger;
import org.neo4j.gds.ml.decisiontree.Group;
import org.neo4j.gds.ml.decisiontree.ImpurityCriterion;
import org.neo4j.gds.ml.decisiontree.SplitMeanSquaredError;
import org.neo4j.gds.ml.decisiontree.TreeNode;
import org.neo4j.gds.ml.models.Features;

/**
 * Decision tree trainer for regression: each leaf predicts the mean of the
 * regression targets of the training samples that reach it.
 */
public class DecisionTreeRegressorTrainer extends DecisionTreeTrainer<Double> {

    /** Regression target per training sample, indexed by sample id. */
    private final HugeDoubleArray targets;

    /**
     * Creates a regression trainer.
     *
     * @param impurityCriterion criterion used to score candidate splits
     * @param features          training feature vectors
     * @param targets           regression target per training sample; expected to have
     *                          the same size as {@code features} (checked only when
     *                          assertions are enabled)
     * @param config            trainer configuration
     * @param featureBagger     feature sampler passed through to the base trainer
     */
    public DecisionTreeRegressorTrainer(
        ImpurityCriterion impurityCriterion,
        Features features,
        HugeDoubleArray targets,
        DecisionTreeTrainerConfig config,
        FeatureBagger featureBagger
    ) {
        super(features, config, impurityCriterion, featureBagger);
        // One target per feature row; only enforced under -ea.
        assert targets.size() == features.size();
        this.targets = targets;
    }

    /**
     * Estimates memory for a trainer instance plus the tree it builds.
     *
     * @param config                  trainer configuration forwarded to the tree estimate
     * @param numberOfTrainingSamples number of samples the tree is trained on
     * @return the size of this trainer instance added to the tree estimate, which uses
     *         {@code Double} leaves and MSE impurity data
     */
    public static MemoryRange memoryEstimation(DecisionTreeTrainerConfig config, long numberOfTrainingSamples) {
        long trainerInstanceBytes = MemoryUsage.sizeOfInstance(DecisionTreeRegressorTrainer.class);
        return MemoryRange.of(trainerInstanceBytes).add(
            DecisionTreeTrainer.estimateTree(
                config,
                numberOfTrainingSamples,
                TreeNode.leafMemoryEstimation(Double.class),
                SplitMeanSquaredError.MSEImpurityData.memoryEstimation()
            )
        );
    }

    /**
     * Computes the leaf value for a group as the arithmetic mean of the targets of the
     * samples it covers. Those are the sample ids stored in {@code group.array()} at
     * positions {@code [startIdx, startIdx + size)}.
     *
     * @param group the samples that reached this leaf
     * @return the mean target value; {@code NaN} if the group is empty
     */
    @Override
    protected Double toTerminal(Group group) {
        HugeLongArray sampleIds = group.array();
        long first = group.startIdx();
        long end = first + group.size();

        double targetSum = 0.0;
        for (long pos = first; pos < end; pos++) {
            targetSum += targets.get(sampleIds.get(pos));
        }
        // double / integral promotes the divisor to double (same as an explicit cast).
        return targetSum / group.size();
    }
}

