package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;

/* loaded from: input_file:org/jpmml/rexp/RangerConverter.class */
/**
 * Converts a ranger random forest model object (an R list with a {@code forest} element)
 * into a PMML {@link MiningModel} whose segments are binary-split {@link TreeModel}s.
 *
 * <p>Three {@code treetype} values are supported: {@code "Regression"} (segments averaged),
 * {@code "Classification"} (majority vote) and {@code "Probability estimation"}
 * (per-leaf class distributions, segments averaged).
 */
public class RangerConverter extends TreeModelConverter<RGenericVector> {

    /**
     * True when the forest's {@code is.ordered} vector has exactly one more entry than there
     * are independent variable names. In that case the first {@code is.ordered} entry is
     * skipped, and split variable IDs are shifted down by one before indexing the schema's
     * features. NOTE(review): presumably the extra entry belongs to the dependent variable;
     * confirm against ranger's forest layout.
     */
    boolean hasDependentVar;

    /**
     * Strategy that turns a freshly created leaf node into its final, scored form.
     */
    @FunctionalInterface
    public interface ScoreEncoder {

        /**
         * @param node the leaf node created for a terminal ranger node
         * @param splitValue the {@code split.values} entry recorded for that terminal node
         * @param terminalClassCount the node's {@code terminal.class.counts} entry, or {@code null}
         *     when the forest does not carry terminal class counts
         * @return the node to place in the tree; either {@code node} itself or a replacement
         */
        Node encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount);
    }

    public RangerConverter(RGenericVector rGenericVector) {
        super(rGenericVector);
        this.hasDependentVar = false;
    }

    /**
     * Registers the target field ({@code _target}) and one input field per independent
     * variable with the encoder.
     *
     * <p>The target is a continuous double for {@code "Regression"}. For the two classification
     * tree types it is categorical over {@code forest$levels}. An independent variable that has
     * an entry in {@code variable.levels} becomes a categorical string field with those levels;
     * any other independent variable becomes a continuous double field.
     *
     * @throws IllegalArgumentException if the model has no {@code forest} element, the tree type
     *     is not supported, or an independent variable is flagged as not ordered
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector ranger = (RGenericVector) getObject();

        RGenericVector forest = ranger.getGenericElement("forest", false);
        if (forest == null) {
            throw new IllegalArgumentException("Missing 'forest' element. Please re-train the model object with 'write.forest' argument set to TRUE");
        }

        RStringVector treetype = ranger.getStringElement("treetype");
        RGenericVector variableLevels = DecorationUtil.getGenericElement(ranger, "variable.levels");

        FieldName targetName = FieldName.create("_target");

        DataField targetField;
        switch (treetype.asScalar()) {
            case "Regression":
                targetField = encoder.createDataField(targetName, OpType.CONTINUOUS, DataType.DOUBLE);
                break;
            case "Classification":
            case "Probability estimation":
                targetField = encoder.createDataField(targetName, OpType.CATEGORICAL, null, forest.getStringElement("levels").getValues());
                break;
            default:
                throw new IllegalArgumentException("Unsupported tree type: " + treetype.asScalar());
        }
        encoder.setLabel(targetField);

        RBooleanVector isOrdered = forest.getBooleanElement("is.ordered");
        RStringVector independentVariableNames = forest.getStringElement("independent.variable.names");

        this.hasDependentVar = (isOrdered.size() == independentVariableNames.size() + 1);

        for (int i = 0; i < independentVariableNames.size(); i++) {
            String variableName = independentVariableNames.getValue(i);

            // Only variables flagged as ordered can be expressed by the splits encodeNode builds
            // (continuous thresholds, or a prefix/suffix partition of the category levels).
            // NOTE(review): presumably corresponds to ranger's respect.unordered.factors = "partition"; confirm.
            int orderedIndex = this.hasDependentVar ? (i + 1) : i;
            if (!isOrdered.getValue(orderedIndex)) {
                throw new IllegalArgumentException("Independent variable '" + variableName + "' is not ordered; unordered variables are not supported");
            }

            FieldName fieldName = FieldName.create(variableName);

            DataField dataField;
            if (variableLevels.hasElement(variableName)) {
                dataField = encoder.createDataField(fieldName, OpType.CATEGORICAL, DataType.STRING, variableLevels.getStringElement(variableName).getValues());
            } else {
                dataField = encoder.createDataField(fieldName, OpType.CONTINUOUS, DataType.DOUBLE);
            }
            encoder.addFeature(dataField);
        }
    }

    /**
     * Encodes the forest as a {@link MiningModel}, dispatching on the model's {@code treetype}.
     *
     * @throws IllegalArgumentException if the tree type is not supported
     */
    @Override
    public MiningModel encodeModel(Schema schema) {
        RGenericVector ranger = (RGenericVector) getObject();

        String treetype = ranger.getStringElement("treetype").asScalar();
        switch (treetype) {
            case "Regression":
                return encodeRegression(ranger, schema);
            case "Classification":
                return encodeClassification(ranger, schema);
            case "Probability estimation":
                return encodeProbabilityForest(ranger, schema);
            default:
                throw new IllegalArgumentException("Unsupported tree type: " + treetype);
        }
    }

    /**
     * Regression forest: each leaf scores its {@code split.values} entry, and the segment
     * predictions are averaged.
     */
    private MiningModel encodeRegression(RGenericVector ranger, Schema schema) {
        RGenericVector forest = ranger.getGenericElement("forest");

        ScoreEncoder scoreEncoder = (node, splitValue, terminalClassCount) -> {
            node.setScore(splitValue);
            return node;
        };

        MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()));
        List<TreeModel> treeModels = encodeForest(forest, MiningFunction.REGRESSION, scoreEncoder, schema);

        return miningModel.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels));
    }

    /**
     * Classification forest: each leaf's {@code split.values} entry is treated as a 1-based index
     * into {@code forest$levels}, and the segments are combined by majority vote.
     */
    private MiningModel encodeClassification(RGenericVector ranger, Schema schema) {
        RGenericVector forest = ranger.getGenericElement("forest");
        RStringVector levels = forest.getStringElement("levels");

        ScoreEncoder scoreEncoder = (node, splitValue, terminalClassCount) -> {
            int levelIndex = ValueUtil.asInt(splitValue);

            // A plain classification forest is not expected to carry per-leaf class counts
            if (terminalClassCount != null) {
                throw new IllegalArgumentException();
            }

            node.setScore(levels.getValue(levelIndex - 1));
            return node;
        };

        MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()));
        List<TreeModel> treeModels = encodeForest(forest, MiningFunction.CLASSIFICATION, scoreEncoder, schema);

        return miningModel.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, treeModels));
    }

    /**
     * Probability forest: each leaf becomes a {@link ClassifierNode} carrying one score
     * distribution per class level, taken from {@code terminal.class.counts}. Its score is the
     * first level with the largest value. Segments are averaged, and a probability output is
     * attached.
     */
    private MiningModel encodeProbabilityForest(RGenericVector ranger, Schema schema) {
        RGenericVector forest = ranger.getGenericElement("forest");
        RStringVector levels = forest.getStringElement("levels");
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        ScoreEncoder scoreEncoder = (node, splitValue, terminalClassCount) -> {
            if (splitValue.doubleValue() != 0d || terminalClassCount == null || terminalClassCount.size() != levels.size()) {
                throw new IllegalArgumentException();
            }

            ClassifierNode result = new ClassifierNode(node);
            List<ScoreDistribution> scoreDistributions = result.getScoreDistributions();

            Number maxProbability = null;
            for (int i = 0; i < terminalClassCount.size(); i++) {
                String value = levels.getValue(i);
                Number probability = (Number) terminalClassCount.getValue(i);

                // Strictly greater: ties keep the earliest level as the node's score
                if (maxProbability == null || Double.compare(maxProbability.doubleValue(), probability.doubleValue()) < 0) {
                    result.setScore(value);
                    maxProbability = probability;
                }

                scoreDistributions.add(new ScoreDistribution(value, probability));
            }
            return result;
        };

        MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel));
        List<TreeModel> treeModels = encodeForest(forest, MiningFunction.CLASSIFICATION, scoreEncoder, schema);

        return miningModel
                .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels))
                .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
    }

    /**
     * Encodes the first {@code num.trees} trees of the forest, each against the anonymous form
     * of {@code schema}.
     */
    private List<TreeModel> encodeForest(RGenericVector forest, MiningFunction miningFunction, ScoreEncoder scoreEncoder, Schema schema) {
        RNumberVector<?> numTrees = forest.getNumericElement("num.trees");
        RGenericVector childNodeIds = forest.getGenericElement("child.nodeIDs");
        RGenericVector splitVarIds = forest.getGenericElement("split.varIDs");
        RGenericVector splitValues = forest.getGenericElement("split.values");
        // Optional: present only for some forests (used by the probability encoder)
        RGenericVector terminalClassCounts = forest.getGenericElement("terminal.class.counts", false);

        Schema segmentSchema = schema.toAnonymousSchema();

        int treeCount = ValueUtil.asInt((Number) numTrees.asScalar());

        List<TreeModel> treeModels = new ArrayList<>(Math.max(treeCount, 0));
        for (int i = 0; i < treeCount; i++) {
            RGenericVector treeChildNodeIds = (RGenericVector) childNodeIds.getValue(i);
            RNumberVector<?> treeSplitVarIds = (RNumberVector<?>) splitVarIds.getValue(i);
            RNumberVector<?> treeSplitValues = (RNumberVector<?>) splitValues.getValue(i);
            RGenericVector treeTerminalClassCounts = (terminalClassCounts != null ? (RGenericVector) terminalClassCounts.getValue(i) : null);

            treeModels.add(encodeTreeModel(miningFunction, scoreEncoder, treeChildNodeIds, treeSplitVarIds, treeSplitValues, treeTerminalClassCounts, segmentSchema));
        }
        return treeModels;
    }

    /**
     * Encodes one tree, starting from node index 0. Element 0 of {@code childNodeIds} holds the
     * left-child IDs and element 1 the right-child IDs.
     */
    private TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder scoreEncoder, RGenericVector childNodeIds, RNumberVector<?> splitVarIds, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, Schema schema) {
        RNumberVector<?> leftChildIds = (RNumberVector<?>) childNodeIds.getValue(0);
        RNumberVector<?> rightChildIds = (RNumberVector<?>) childNodeIds.getValue(1);

        Node root = encodeNode(True.INSTANCE, 0, scoreEncoder, leftChildIds, rightChildIds, splitVarIds, splitValues, terminalClassCounts, new CategoryManager(), schema);

        return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
                .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
    }

    /**
     * Recursively encodes the node at {@code index} and its subtree.
     *
     * <p>A node whose left and right child IDs are both 0 is a leaf, and is delegated to
     * {@code scoreEncoder}. Otherwise the split variable decides the split:
     * <ul>
     *   <li>Categorical: the levels before {@code floor(splitValue)} go left and the rest go
     *       right. Both sides are first filtered through {@code categoryManager}, and each child
     *       gets a forked manager restricted to its side's levels.</li>
     *   <li>Continuous: {@code <= splitValue} goes left and {@code > splitValue} goes right.</li>
     * </ul>
     */
    private Node encodeNode(Predicate predicate, int index, ScoreEncoder scoreEncoder, RNumberVector<?> leftChildIds, RNumberVector<?> rightChildIds, RNumberVector<?> splitVarIds, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, CategoryManager categoryManager, Schema schema) {
        int leftIndex = ValueUtil.asInt((Number) leftChildIds.getValue(index));
        int rightIndex = ValueUtil.asInt((Number) rightChildIds.getValue(index));

        Number splitValue = (Number) splitValues.getValue(index);
        RNumberVector<?> terminalClassCount = (terminalClassCounts != null ? (RNumberVector<?>) terminalClassCounts.getValue(index) : null);

        if (leftIndex == 0 && rightIndex == 0) {
            return scoreEncoder.encode(new LeafNode((Object) null, predicate), splitValue, terminalClassCount);
        }

        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;

        Predicate leftPredicate;
        Predicate rightPredicate;

        int splitVarId = ValueUtil.asInt((Number) splitVarIds.getValue(index));
        Feature feature = schema.getFeature(this.hasDependentVar ? (splitVarId - 1) : splitVarId);

        if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

            int splitLevelIndex = ValueUtil.asInt(Double.valueOf(Math.floor(splitValue.doubleValue())));

            FieldName name = categoricalFeature.getName();
            List<?> values = categoricalFeature.getValues();

            java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

            List<Object> leftValues = filterValues(values.subList(0, splitLevelIndex), valueFilter);
            List<Object> rightValues = filterValues(values.subList(splitLevelIndex, values.size()), valueFilter);

            leftCategoryManager = leftCategoryManager.fork(name, leftValues);
            rightCategoryManager = rightCategoryManager.fork(name, rightValues);

            leftPredicate = createSimpleSetPredicate(categoricalFeature, leftValues);
            rightPredicate = createSimpleSetPredicate(categoricalFeature, rightValues);
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();

            leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, splitValue);
            rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, splitValue);
        }

        Node leftChild = encodeNode(leftPredicate, leftIndex, scoreEncoder, leftChildIds, rightChildIds, splitVarIds, splitValues, terminalClassCounts, leftCategoryManager, schema);
        Node rightChild = encodeNode(rightPredicate, rightIndex, scoreEncoder, leftChildIds, rightChildIds, splitVarIds, splitValues, terminalClassCounts, rightCategoryManager, schema);

        return new BranchNode((Object) null, predicate).addNodes(leftChild, rightChild);
    }

    /**
     * Returns a new list of the elements of {@code values} accepted by {@code filter},
     * in their original order.
     */
    private static List<Object> filterValues(List<?> values, java.util.function.Predicate<Object> filter) {
        return values.stream()
                .filter(filter)
                .collect(Collectors.<Object>toList());
    }
}
