package org.tribuo.common.tree;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.SplittableRandom;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.SparseModel;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/common/tree/AbstractCARTTrainer.class */
/**
 * Base class for trainers that grow a single CART-style decision tree.
 *
 * <p>Subclasses provide the concrete training node through
 * {@link #mkTrainingNode(Dataset, AbstractTrainingNode.LeafDeterminer)}. This class holds the
 * shared hyperparameters, validates them in {@link #postConfig()}, manages the RNG used for
 * feature sampling and random split points, and runs the tree-growing loop.
 *
 * @param <T> the output type of the trained tree
 */
public abstract class AbstractCARTTrainer<T extends Output<T>> implements DecisionTreeTrainer<T> {
    /** Constant equal to 5; also the initial value of {@link #minChildWeight}. */
    public static final int MIN_EXAMPLES = 5;

    @Config(description = "The minimum weight allowed in a child node.")
    protected float minChildWeight = MIN_EXAMPLES;

    @Config(description = "The maximum depth of the tree.")
    protected int maxDepth = Integer.MAX_VALUE;

    @Config(description = "The decrease in impurity needed in order to split the node.")
    protected float minImpurityDecrease = 0.0f;

    @Config(description = "The fraction of features to consider in each split. 1.0f indicates all features are considered.")
    protected float fractionFeaturesInSplit = 1.0f;

    @Config(description = "Whether to choose split points for features at random.")
    protected boolean useRandomSplitPoints = false;

    @Config(description = "The RNG seed to use when sampling features in a split.")
    protected long seed = 12345L;

    /**
     * Master RNG seeded from {@link #seed}. It is (re)created by {@link #postConfig()} and
     * {@link #setInvocationCount(int)}, and each training run splits a child generator off it.
     */
    protected SplittableRandom rng;

    /** Number of training runs started so far, or the value last set by {@link #setInvocationCount(int)}. */
    protected int trainInvocationCounter;

    /**
     * Provenance base class that only forwards to the {@link SkeletalTrainerProvenance}
     * constructors.
     *
     * @deprecated This class is marked deprecated and adds no state of its own.
     */
    @Deprecated
    /* loaded from: input_file:org/tribuo/common/tree/AbstractCARTTrainer$AbstractCARTTrainerProvenance.class */
    protected static abstract class AbstractCARTTrainerProvenance extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1;

        /**
         * Builds provenance from a live trainer.
         *
         * @param abstractCARTTrainer the trainer to record
         * @param <T> the trainer's output type
         */
        protected <T extends Output<T>> AbstractCARTTrainerProvenance(AbstractCARTTrainer<T> abstractCARTTrainer) {
            super(abstractCARTTrainer);
        }

        /**
         * Rebuilds provenance from its serialized map form.
         *
         * @param map the provenance entries
         */
        protected AbstractCARTTrainerProvenance(Map<String, Provenance> map) {
            super(map);
        }
    }

    /**
     * Stores the supplied hyperparameters.
     *
     * <p>This constructor neither validates its arguments nor creates {@link #rng}; both happen in
     * {@link #postConfig()}. NOTE(review): presumably subclass constructors call it — confirm.
     *
     * @param maxDepth maximum tree depth
     * @param minChildWeight minimum weight a node needs in order to be split further
     * @param minImpurityDecrease required impurity decrease, scaled by total example weight at train time
     * @param fractionFeaturesInSplit fraction of features sampled at each split, in (0, 1]
     * @param useRandomSplitPoints whether split points are chosen at random
     * @param seed RNG seed
     */
    protected AbstractCARTTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease,
                                  float fractionFeaturesInSplit, boolean useRandomSplitPoints, long seed) {
        this.maxDepth = maxDepth;
        this.minChildWeight = minChildWeight;
        this.minImpurityDecrease = minImpurityDecrease;
        this.fractionFeaturesInSplit = fractionFeaturesInSplit;
        this.useRandomSplitPoints = useRandomSplitPoints;
        this.seed = seed;
    }

    /**
     * Creates the RNG from {@link #seed} and validates the hyperparameters.
     *
     * @throws IllegalArgumentException if any hyperparameter is out of range or NaN
     */
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(this.seed);
        // Each check is written as "not inside the valid range" so NaN, which fails every
        // ordered comparison, is rejected instead of slipping through.
        if (!(this.fractionFeaturesInSplit > 0.0f && this.fractionFeaturesInSplit <= 1.0f)) {
            throw new IllegalArgumentException("fractionFeaturesInSplit must be greater than 0 and less than or equal to 1");
        }
        if (!(this.minImpurityDecrease >= 0.0f)) {
            throw new IllegalArgumentException("minImpurityDecrease must be greater than or equal to 0");
        }
        if (this.maxDepth < 0) {
            throw new IllegalArgumentException("maxDepth must be non-negative");
        }
        if (!(this.minChildWeight > 0.0f)) {
            throw new IllegalArgumentException("minChildWeight must be greater than 0");
        }
    }

    /**
     * Returns the number of training runs started so far.
     *
     * @return the invocation counter
     */
    public int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Rewinds the RNG to the state it would have after {@code invocationCount} training runs.
     *
     * <p>The RNG is re-seeded and split once per simulated run.
     *
     * @param invocationCount the number of prior runs to simulate
     * @throws IllegalArgumentException if {@code invocationCount} is negative
     */
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.rng = new SplittableRandom(this.seed);
        for (int run = 0; run < invocationCount; run++) {
            this.rng.split();
        }
        this.trainInvocationCounter = invocationCount;
    }

    @Override // org.tribuo.common.tree.DecisionTreeTrainer
    public float getFractionFeaturesInSplit() {
        return this.fractionFeaturesInSplit;
    }

    @Override // org.tribuo.common.tree.DecisionTreeTrainer
    public boolean getUseRandomSplitPoints() {
        return this.useRandomSplitPoints;
    }

    @Override // org.tribuo.common.tree.DecisionTreeTrainer
    public float getMinImpurityDecrease() {
        return this.minImpurityDecrease;
    }

    /**
     * Trains a tree with empty run provenance and the next invocation count.
     *
     * @param dataset the training data
     * @return the trained tree
     */
    /* renamed from: train, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public TreeModel<T> m5train(Dataset<T> dataset) {
        return train(dataset, Collections.emptyMap());
    }

    /**
     * Trains a tree using the next invocation count.
     *
     * @param dataset the training data
     * @param map run provenance recorded in the model provenance
     * @return the trained tree
     */
    public TreeModel<T> train(Dataset<T> dataset, Map<String, Provenance> map) {
        // -1 is the sentinel for "use the current invocation count and advance it".
        return train(dataset, map, -1);
    }

    /**
     * Grows a decision tree on {@code dataset}.
     *
     * <p>Nodes are expanded depth-first. A node is split only if its impurity is positive, its
     * depth is below {@link #maxDepth}, and its weight sum is at least {@link #minChildWeight}.
     *
     * @param dataset the training data
     * @param map run provenance recorded in the model provenance
     * @param i the invocation count to rewind the RNG to, or -1 to use and advance the current one
     * @return the trained tree
     * @throws IllegalArgumentException if the dataset contains unknown outputs, or if {@code i} is
     *     negative and not -1
     */
    public TreeModel<T> train(Dataset<T> dataset, Map<String, Provenance> map, int i) {
        if (dataset.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }

        // Under the lock, take this run's own RNG, record provenance, and bump the counter.
        SplittableRandom localRng;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            if (i != -1) {
                setInvocationCount(i);
            }
            localRng = this.rng.split();
            trainerProvenance = getProvenance();
            this.trainInvocationCounter++;
        }

        ImmutableFeatureMap featureIDMap = dataset.getFeatureIDMap();
        ImmutableOutputInfo<T> outputIDInfo = dataset.getOutputIDInfo();
        int numFeatures = featureIDMap.size();
        int numFeaturesInSplit = computeNumFeaturesInSplit(numFeatures);

        int[] allIndices = new int[numFeatures];
        for (int idx = 0; idx < numFeatures; idx++) {
            allIndices[idx] = idx;
        }
        // When sampling a subset, a separate buffer receives the sampled prefix for each split.
        int[] splitIndices = numFeaturesInSplit != numFeatures ? new int[numFeaturesInSplit] : allIndices;

        float weightSum = 0.0f;
        for (Example<T> example : dataset) {
            weightSum += example.getWeight();
        }

        AbstractTrainingNode.LeafDeterminer leafDeterminer = new AbstractTrainingNode.LeafDeterminer(
                this.maxDepth, this.minChildWeight, getMinImpurityDecrease() * weightSum);
        AbstractTrainingNode<T> root = mkTrainingNode(dataset, leafDeterminer);

        ArrayDeque<AbstractTrainingNode<T>> pending = new ArrayDeque<>();
        pending.add(root);
        while (!pending.isEmpty()) {
            AbstractTrainingNode<T> node = pending.poll();
            boolean splittable = node.getImpurity() > 0.0d
                    && node.getDepth() < this.maxDepth
                    && node.getWeightSum() >= this.minChildWeight;
            if (!splittable) {
                continue;
            }
            if (numFeaturesInSplit != numFeatures) {
                Util.randpermInPlace(allIndices, localRng);
                System.arraycopy(allIndices, 0, splitIndices, 0, numFeaturesInSplit);
            }
            // Pushing children on the front makes the deque act as a stack (depth-first growth).
            for (AbstractTrainingNode<T> child : node.buildTree(splitIndices, localRng, getUseRandomSplitPoints())) {
                pending.addFirst(child);
            }
        }

        ModelProvenance modelProvenance = new ModelProvenance(TreeModel.class.getName(), OffsetDateTime.now(),
                dataset.getProvenance(), trainerProvenance, map);
        return new TreeModel<>("cart-tree", modelProvenance, featureIDMap, outputIDInfo, false, root.convertTree());
    }

    /**
     * Computes how many features to consider at each split: {@link #fractionFeaturesInSplit} of
     * {@code numFeatures}, rounded, then clamped to {@code [1, numFeatures]}. The result is 0 only
     * when there are no features.
     *
     * @param numFeatures total number of features
     * @return the per-split feature count
     */
    private int computeNumFeaturesInSplit(int numFeatures) {
        int rounded = Math.round(this.fractionFeaturesInSplit * numFeatures);
        // A small positive fraction can round to 0, which would give every split zero candidate
        // features; always consider at least one when any exist.
        return Math.min(Math.max(rounded, 1), numFeatures);
    }

    /**
     * Creates the root training node for {@code dataset}.
     *
     * @param dataset the training data
     * @param leafDeterminer the leaf-stopping criteria
     * @return the root node
     */
    protected abstract AbstractTrainingNode<T> mkTrainingNode(Dataset<T> dataset, AbstractTrainingNode.LeafDeterminer leafDeterminer);

    /** Bridge overload returning {@link SparseModel}; delegates to the three-argument train. */
    /* renamed from: train, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ SparseModel m0train(Dataset dataset, Map map, int i) {
        return train(dataset, (Map<String, Provenance>) map, i);
    }

    /** Bridge overload returning {@link SparseModel}; delegates to the two-argument train. */
    /* renamed from: train, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ SparseModel m1train(Dataset dataset, Map map) {
        return train(dataset, (Map<String, Provenance>) map);
    }

    /** Bridge overload returning {@link Model}; delegates to the three-argument train. */
    /* renamed from: train, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Model m3train(Dataset dataset, Map map, int i) {
        return train(dataset, (Map<String, Provenance>) map, i);
    }

    /** Bridge overload returning {@link Model}; delegates to the two-argument train. */
    /* renamed from: train, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Model m4train(Dataset dataset, Map map) {
        return train(dataset, (Map<String, Provenance>) map);
    }
}
