package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FlagManager;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;

/**
 * Converts an R {@code gbm} model object into a PMML {@link MiningModel} whose
 * segments are the boosted regression trees, summed together.
 *
 * <p>Supported values of {@code distribution$name}:
 * <ul>
 *   <li>"gaussian": regression;</li>
 *   <li>"adaboost" and "bernoulli": binary classification over {@code {0, 1}};</li>
 *   <li>"multinomial": multi-class classification over {@code classes}.</li>
 * </ul>
 * Any other distribution is rejected with an {@link IllegalArgumentException}.
 */
public class GBMConverter extends TreeModelConverter<RGenericVector> {

    /** Target categories used by the "adaboost" and "bernoulli" distributions. */
    private static final List<Integer> BINARY_CLASSES = Arrays.asList(0, 1);

    public GBMConverter(RGenericVector gbm) {
        super(gbm);
    }

    /**
     * Declares the target field and one data field per predictor variable.
     *
     * @param encoder receives the label and the predictor fields
     * @throws IllegalArgumentException if the distribution is unsupported, or if a
     *         "multinomial" model carries no {@code classes} element
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector gbm = getObject();

        RGenericVector distribution = gbm.getGenericElement("distribution");
        RStringVector responseName = gbm.getStringElement("response.name", false);
        RGenericVector varLevels = gbm.getGenericElement("var.levels");
        RStringVector varNames = gbm.getStringElement("var.names");
        RNumberVector<?> varType = gbm.getNumericElement("var.type");
        RStringVector classes = gbm.getStringElement("classes", false);

        String distributionName = distribution.getStringElement("name").asScalar();

        // "response.name" is optional; fall back to "y" when it is absent.
        FieldName targetName = FieldName.create(responseName != null ? responseName.asScalar() : "y");

        DataField targetField;

        switch (distributionName) {
            case "gaussian":
                targetField = encoder.createDataField(targetName, OpType.CONTINUOUS, DataType.DOUBLE);
                break;
            case "adaboost":
            case "bernoulli":
                targetField = encoder.createDataField(targetName, OpType.CATEGORICAL, DataType.INTEGER, BINARY_CLASSES);
                break;
            case "multinomial":
                if (classes == null) {
                    throw new IllegalArgumentException("Distribution \"multinomial\" requires the \"classes\" element");
                }
                targetField = encoder.createDataField(targetName, OpType.CATEGORICAL, DataType.STRING, classes.getValues());
                break;
            default:
                throw new IllegalArgumentException("Unsupported distribution: " + distributionName);
        }

        encoder.setLabel(targetField);

        for (int i = 0; i < varNames.size(); i++) {
            FieldName name = FieldName.create(varNames.getValue(i));

            // A positive var.type marks a categorical (factor) predictor whose levels are in var.levels.
            boolean categorical = ValueUtil.asInt(varType.getValue(i)) > 0;

            DataField dataField;
            if (categorical) {
                RStringVector levels = (RStringVector) varLevels.getValue(i);

                dataField = encoder.createDataField(name, OpType.CATEGORICAL, DataType.STRING, levels.getValues());
            } else {
                dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
            }

            encoder.addFeature(dataField);
        }
    }

    /**
     * Encodes every tree of the ensemble and combines them according to the distribution.
     *
     * @param schema the schema produced by {@link #encodeSchema(RExpEncoder)}
     * @return the top-level mining model
     * @throws IllegalArgumentException if the distribution is unsupported
     */
    @Override
    public MiningModel encodeModel(Schema schema) {
        RGenericVector gbm = getObject();

        RDoubleVector initF = gbm.getDoubleElement("initF");
        RGenericVector trees = gbm.getGenericElement("trees");
        RGenericVector cSplits = gbm.getGenericElement("c.splits");
        RStringVector distributionName = gbm.getGenericElement("distribution").getStringElement("name");

        // Each tree is encoded against an anonymous, double-valued regressor schema.
        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);

        List<TreeModel> treeModels = new ArrayList<>(trees.size());
        for (int i = 0; i < trees.size(); i++) {
            RGenericVector tree = (RGenericVector) trees.getValue(i);

            treeModels.add(encodeTreeModel(MiningFunction.REGRESSION, tree, cSplits, segmentSchema));
        }

        return encodeMiningModel(distributionName, treeModels, initF.asScalar(), schema);
    }

    /**
     * Dispatches on the gbm distribution name to build the final model.
     *
     * <p>For the binary distributions, the coefficient passed on is -2 for "adaboost"
     * and -1 for "bernoulli".
     */
    private MiningModel encodeMiningModel(RStringVector distributionName, List<TreeModel> treeModels, Double initF, Schema schema) {
        String name = distributionName.asScalar();

        switch (name) {
            case "gaussian":
                return encodeRegression(treeModels, initF, schema);
            case "adaboost":
                return encodeBinaryClassification(treeModels, initF, -2.0d, schema);
            case "bernoulli":
                return encodeBinaryClassification(treeModels, initF, -1.0d, schema);
            case "multinomial":
                return encodeMultinomialClassification(treeModels, initF, schema);
            default:
                throw new IllegalArgumentException("Unsupported distribution: " + name);
        }
    }

    /** Sums the trees and offsets the total by {@code initF}. */
    private MiningModel encodeRegression(List<TreeModel> treeModels, Double initF, Schema schema) {
        return createMiningModel(treeModels, initF, schema);
    }

    /**
     * Sums the trees into a "gbmValue" score.
     *
     * <p>The score is then passed to a LOGIT-normalized binary classifier with slope
     * {@code -coefficient} and intercept 0.
     */
    private MiningModel encodeBinaryClassification(List<TreeModel> treeModels, Double initF, double coefficient, Schema schema) {
        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);

        MiningModel miningModel = createMiningModel(treeModels, initF, segmentSchema)
            .setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbmValue"), OpType.CONTINUOUS, DataType.DOUBLE));

        return MiningModelUtil.createBinaryLogisticClassification(miningModel, -coefficient, 0.0d, RegressionModel.NormalizationMethod.LOGIT, true, schema);
    }

    /**
     * Builds one summed-tree regressor per class ("gbmValue(class)").
     *
     * <p>The per-class regressors are combined with SOFTMAX normalization.
     *
     * @throws IllegalArgumentException if the number of trees is not a positive
     *         multiple of the number of classes
     */
    private MiningModel encodeMultinomialClassification(List<TreeModel> treeModels, Double initF, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);

        int columns = categoricalLabel.size();
        if (columns == 0 || treeModels.size() % columns != 0) {
            throw new IllegalArgumentException("Expected the number of trees (" + treeModels.size() + ") to be a multiple of the number of classes (" + columns + ")");
        }
        int rows = treeModels.size() / columns;

        // treeModels is treated as a rows-by-columns matrix with one column per class.
        List<Model> miningModels = new ArrayList<>(columns);
        for (int i = 0; i < columns; i++) {
            MiningModel miningModel = createMiningModel(CMatrixUtil.getColumn(treeModels, rows, columns, i), initF, segmentSchema)
                .setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbmValue(" + categoricalLabel.getValue(i) + ")"), OpType.CONTINUOUS, DataType.DOUBLE));

            miningModels.add(miningModel);
        }

        return MiningModelUtil.createClassification(miningModels, RegressionModel.NormalizationMethod.SOFTMAX, true, schema);
    }

    /** Encodes a single gbm tree, starting from row 0, as a multi-split PMML tree. */
    private TreeModel encodeTreeModel(MiningFunction miningFunction, RGenericVector tree, RGenericVector cSplits, Schema schema) {
        Node root = encodeNode(True.INSTANCE, 0, tree, cSplits, new FlagManager(), new CategoryManager(), schema);

        return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.MULTI_SPLIT);
    }

    /**
     * Recursively encodes row {@code i} of a gbm tree and its descendants.
     *
     * <p>Tree columns used, by position. The names follow gbm's {@code pretty.gbm.tree};
     * confirm them against the gbm version in use.
     * <ul>
     *   <li>0: split variable (-1 for a leaf);</li>
     *   <li>1: split value, or a {@code c.splits} index for categorical splits;</li>
     *   <li>2: left child;</li>
     *   <li>3: right child;</li>
     *   <li>4: missing child;</li>
     *   <li>7: prediction.</li>
     * </ul>
     *
     * <p>The {@link FlagManager} records, per field, whether the current path already
     * decided the value is missing (TRUE) or present (FALSE). This lets unreachable
     * children be omitted. The {@link CategoryManager} narrows the categories still
     * reachable on this path.
     *
     * @param i zero-based row index; the PMML node id is {@code i + 1}
     */
    private Node encodeNode(Predicate predicate, int i, RGenericVector tree, RGenericVector cSplits, FlagManager flagManager, CategoryManager categoryManager, Schema schema) {
        RIntegerVector splitVars = (RIntegerVector) tree.getValue(0);
        RDoubleVector splitCodePreds = (RDoubleVector) tree.getValue(1);
        RIntegerVector leftNodes = (RIntegerVector) tree.getValue(2);
        RIntegerVector rightNodes = (RIntegerVector) tree.getValue(3);
        RIntegerVector missingNodes = (RIntegerVector) tree.getValue(4);
        RDoubleVector predictions = (RDoubleVector) tree.getValue(7);

        Integer id = Integer.valueOf(i + 1);

        int splitVar = splitVars.getValue(i);
        if (splitVar == -1) {
            return new LeafNode(predictions.getValue(i), predicate).setId(id);
        }

        Feature feature = schema.getFeature(splitVar);
        FieldName name = feature.getName();

        FlagManager missingFlagManager = flagManager;
        FlagManager nonMissingFlagManager = flagManager;

        Boolean isMissing = flagManager.getValue(name);
        if (isMissing == null) {
            missingFlagManager = missingFlagManager.fork(name, Boolean.TRUE);
            nonMissingFlagManager = nonMissingFlagManager.fork(name, Boolean.FALSE);
        }

        Predicate missingPredicate = createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);

        Predicate leftPredicate;
        Predicate rightPredicate;

        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;

        Double split = splitCodePreds.getValue(i);

        if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

            List<?> values = categoricalFeature.getValues();

            // For categorical splits the split value is an index into c.splits.
            RIntegerVector cSplit = (RIntegerVector) cSplits.getValue(ValueUtil.asInt(split));
            List<Integer> splitValues = cSplit.getValues();

            java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

            List<Object> leftValues = selectValues(values, valueFilter, splitValues, true);
            List<Object> rightValues = selectValues(values, valueFilter, splitValues, false);

            leftCategoryManager = leftCategoryManager.fork(name, leftValues);
            rightCategoryManager = rightCategoryManager.fork(name, rightValues);

            leftPredicate = createSimpleSetPredicate(categoricalFeature, leftValues);
            rightPredicate = createSimpleSetPredicate(categoricalFeature, rightValues);
        } else {
            Feature continuousFeature = feature.toContinuousFeature();

            leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, split);
            rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, split);
        }

        BranchNode branchNode = new BranchNode(null, predicate);
        branchNode.setId(id);

        List<Node> children = branchNode.getNodes();

        // The missing child is reachable only while the value may still be missing.
        int missing = missingNodes.getValue(i);
        if (missing != -1 && (isMissing == null || isMissing)) {
            children.add(encodeNode(missingPredicate, missing, tree, cSplits, missingFlagManager, categoryManager, schema));
        }

        // Left/right children are reachable only while the value may still be present.
        int left = leftNodes.getValue(i);
        if (left != -1 && (isMissing == null || !isMissing)) {
            children.add(encodeNode(leftPredicate, left, tree, cSplits, nonMissingFlagManager, leftCategoryManager, schema));
        }

        int right = rightNodes.getValue(i);
        if (right != -1 && (isMissing == null || !isMissing)) {
            children.add(encodeNode(rightPredicate, right, tree, cSplits, nonMissingFlagManager, rightCategoryManager, schema));
        }

        return branchNode;
    }

    /**
     * Wraps the trees in a SUM segmentation.
     *
     * <p>The result is shifted by {@code initF} via rescale targets on the schema's
     * continuous label.
     */
    private static MiningModel createMiningModel(List<TreeModel> treeModels, Double initF, Schema schema) {
        ContinuousLabel continuousLabel = (ContinuousLabel) schema.getLabel();

        return new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel))
            .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, treeModels))
            .setTargets(ModelUtil.createRescaleTargets(null, initF, continuousLabel));
    }

    /**
     * Picks the categories that a c.splits code vector sends to one side of a split.
     *
     * <p>Code -1 selects the left side and code 1 selects the right side. Any other code
     * excludes the category from both sides. Categories rejected by {@code valueFilter}
     * are also dropped.
     *
     * @param values all categories of the feature, in level order
     * @param valueFilter categories still reachable on the current path
     * @param splitValues one code per category, parallel to {@code values}
     * @param left {@code true} to select the left side, {@code false} for the right side
     * @throws IllegalArgumentException if {@code values} and {@code splitValues} differ in length
     */
    private static List<Object> selectValues(List<?> values, java.util.function.Predicate<Object> valueFilter, List<Integer> splitValues, boolean left) {
        if (values.size() != splitValues.size()) {
            throw new IllegalArgumentException("Expected " + values.size() + " split codes, got " + splitValues.size());
        }

        int wanted = left ? -1 : 1;

        List<Object> result = new ArrayList<>();
        for (int i = 0; i < values.size(); i++) {
            Object value = values.get(i);

            if (splitValues.get(i) == wanted && valueFilter.test(value)) {
                result.add(value);
            }
        }

        return result;
    }
}
