package sklearn.ensemble.gradient_boosting;

import java.util.ArrayList;
import java.util.List;
import java.util.function.IntFunction;
import org.dmg.pmml.DataType;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import sklearn.Estimator;
import sklearn.HasEstimatorEnsemble;
import sklearn.HasPriorProbability;
import sklearn.SkLearnClassifier;
import sklearn.VersionUtil;
import sklearn.tree.HasTreeOptions;
import sklearn.tree.TreeRegressor;
import sklearn2pmml.EstimatorProxy;

/* loaded from: input_file:sklearn/ensemble/gradient_boosting/GradientBoostingClassifier.class */
/**
 * Converter for a fitted gradient boosting classifier (presumably scikit-learn's
 * {@code GradientBoostingClassifier}; the fitted attributes read below follow its naming).
 *
 * <p>The model is encoded as a PMML {@link MiningModel}. For the binary case
 * ({@code K == 1}), a single boosted regression ensemble is wrapped in a logistic
 * classification. For the multi-class case ({@code K >= 3}), one boosted ensemble is
 * built per class, and the per-class outputs are combined with
 * {@code SIMPLEMAX} normalization.
 */
public class GradientBoostingClassifier extends SkLearnClassifier implements HasEstimatorEnsemble<TreeRegressor>, HasTreeOptions {

    /**
     * View over the enclosing classifier that reports the enclosing classifier as its
     * estimator.
     *
     * <p>Concrete (anonymous) subclasses supply {@link #getEstimators()} to expose only a
     * subset of the fitted trees. In {@code mo7encodeModel}, that subset is the trees of a
     * single class.
     */
    private abstract class GradientBoostingClassifierProxy extends EstimatorProxy implements HasEstimatorEnsemble<TreeRegressor>, HasTreeOptions {

        private GradientBoostingClassifierProxy() {
        }

        @Override
        public Estimator getEstimator() {
            return GradientBoostingClassifier.this;
        }
    }

    public GradientBoostingClassifier(String str, String str2) {
        super(str, str2);
    }

    /**
     * Returns the number of input features.
     *
     * @return the {@code n_features} attribute when present; otherwise the superclass value
     */
    @Override
    public int getNumberOfFeatures() {
        return containsKey("n_features") ? getInteger("n_features").intValue() : super.getNumberOfFeatures();
    }

    /**
     * Returns the data type of this estimator.
     *
     * @return always {@link DataType#FLOAT}
     */
    @Override
    public DataType getDataType() {
        return DataType.FLOAT;
    }

    /**
     * Encodes this classifier as a PMML {@link MiningModel}.
     *
     * <p>The method name is the decompiler-generated form of {@code encodeModel}. It is kept
     * unchanged so that the override and any existing callers still resolve.
     *
     * <p>Initial (per-class) predictions come from {@code init_.getPriorProbability(i)}.
     * When the recorded scikit-learn version is known and is at least {@code 0.21}, they come
     * from {@code loss.computeInitialPredictions(init_)} instead.
     *
     * @param schema the input schema; its label must be categorical
     * @return the classification model, with predict-proba output fields attached
     * @throws IllegalArgumentException if the loss reports {@code K} other than 1 or at least 3,
     *         or if the number of estimators is not a multiple of the number of classes
     * @throws NullPointerException if {@code init_} is {@code null} (binding the method
     *         reference dereferences it)
     */
    @Override
    public MiningModel mo7encodeModel(Schema schema) {
        String skLearnVersion = getSkLearnVersion();
        LossFunction loss = getLoss();
        int k = loss.getK().intValue();
        HasPriorProbability init = getInit();
        Number learningRate = getLearningRate();

        // Evaluating a bound method reference null-checks its receiver, which replaces the
        // decompiler's explicit init.getClass() call.
        IntFunction<Number> initialPrediction = init::getPriorProbability;
        if (skLearnVersion != null && VersionUtil.compareVersion(skLearnVersion, "0.21") >= 0) {
            List<? extends Number> initialPredictions = loss.computeInitialPredictions(init);
            initialPrediction = initialPredictions::get;
        }

        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);
        CategoricalLabel label = schema.getLabel();

        MiningModel result;
        if (k == 1) {
            SchemaUtil.checkSize(2, label);

            // Single ensemble scoring the second class (index 1) of the label.
            MiningModel decisionFunction = GradientBoostingUtil.encodeGradientBoosting(this, initialPrediction.apply(1), learningRate, segmentSchema)
                    .setOutput(ModelUtil.createPredictedOutput(FieldNameUtil.create(Estimator.FIELD_DECISION_FUNCTION, new Object[]{label.getValue(1)}), OpType.CONTINUOUS, DataType.DOUBLE, new Transformation[]{loss.mo20createTransformation()}));

            result = MiningModelUtil.createBinaryLogisticClassification(decisionFunction, 1.0d, 0.0d, RegressionModel.NormalizationMethod.NONE, false, schema);
        } else {
            if (k < 3) {
                throw new IllegalArgumentException("Expected K == 1 (binary) or K >= 3 (multi-class), got K = " + k);
            }
            SchemaUtil.checkSize(k, label);

            List<? extends TreeRegressor> estimators = getEstimators();
            int columns = label.size();
            if (estimators.size() % columns != 0) {
                throw new IllegalArgumentException("Expected the number of estimators (" + estimators.size() + ") to be a multiple of the number of classes (" + columns + ")");
            }
            int rows = estimators.size() / columns;

            List<MiningModel> miningModels = new ArrayList<>(columns);
            for (int i = 0; i < columns; i++) {
                // Trees of class i are column i of the (rows x columns) estimator matrix.
                List<? extends TreeRegressor> columnEstimators = CMatrixUtil.getColumn(estimators, rows, columns, i);

                GradientBoostingClassifierProxy classEnsemble = new GradientBoostingClassifierProxy() {

                    @Override
                    public List<? extends TreeRegressor> getEstimators() {
                        return columnEstimators;
                    }
                };

                MiningModel miningModel = GradientBoostingUtil.encodeGradientBoosting(classEnsemble, initialPrediction.apply(i), learningRate, segmentSchema)
                        .setOutput(ModelUtil.createPredictedOutput(FieldNameUtil.create(Estimator.FIELD_DECISION_FUNCTION, new Object[]{label.getValue(i)}), OpType.CONTINUOUS, DataType.DOUBLE, new Transformation[]{loss.mo20createTransformation()}));

                miningModels.add(miningModel);
            }

            result = MiningModelUtil.createClassification(miningModels, RegressionModel.NormalizationMethod.SIMPLEMAX, false, schema);
        }

        encodePredictProbaOutput(result, DataType.DOUBLE, label);
        return result;
    }

    /**
     * Returns the fitted loss function.
     *
     * @return the {@code loss_} attribute when present; otherwise the {@code _loss} attribute
     *         (presumably the name differs across scikit-learn versions; verify)
     */
    public LossFunction getLoss() {
        return containsKey("loss_") ? (LossFunction) get("loss_", LossFunction.class) : (LossFunction) get("_loss", LossFunction.class);
    }

    /**
     * Returns the fitted initial estimator.
     *
     * @return the {@code init_} attribute
     */
    public HasPriorProbability getInit() {
        return (HasPriorProbability) get("init_", HasPriorProbability.class);
    }

    /**
     * Returns the learning rate.
     *
     * @return the {@code learning_rate} attribute
     */
    public Number getLearningRate() {
        return getNumber("learning_rate");
    }

    /**
     * Returns all fitted regression trees.
     *
     * @return the {@code estimators_} attribute as a list of {@link TreeRegressor}
     */
    @Override
    public List<? extends TreeRegressor> getEstimators() {
        return getArray("estimators_", TreeRegressor.class);
    }
}
