package org.jpmml.h2o;

import hex.genmodel.algos.tree.NaSplitDir;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.utils.ByteBufferWrapper;
import hex.genmodel.utils.GenmodelBitSet;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;

/* loaded from: input_file:org/jpmml/h2o/SharedTreeMojoModelConverter.class */
/**
 * Base converter for H2O MOJO models that extend {@link SharedTreeMojoModel}.
 *
 * <p>Each compressed tree stored in the MOJO is decoded into a PMML regression {@link TreeModel}
 * whose missing values are routed through default children.
 *
 * @param <M> the concrete MOJO model type
 */
public abstract class SharedTreeMojoModelConverter<M extends SharedTreeMojoModel> extends Converter<M> {
    private static final Field FIELD_COMPRESSEDTREES;
    private static final Field FIELD_NTREEGROUPS;
    private static final Field FIELD_NTREESPERGROUP;

    /** Column index that marks a node as a leaf; such a node carries only a 4-byte float value. */
    private static final int LEAF_COLUMN = 65535;

    /** Bit that, in a left or (realigned) right child mask, flags the child as an inline 4-byte leaf value. */
    private static final int LEAF_FLAG = 16;

    /** Oldest MOJO format version whose tree encoding this converter accepts. */
    private static final double MIN_MOJO_VERSION = 1.2d;

    public SharedTreeMojoModelConverter(M m) {
        super(m);
    }

    /**
     * Decodes every compressed tree of the wrapped model into a PMML tree model.
     *
     * @param schema the schema whose features are addressed by the trees' column indices
     * @return one {@link TreeModel} per compressed tree, in storage order
     * @throws IllegalArgumentException if the MOJO version is older than 1.2, or a tree uses an
     *     unsupported node encoding
     */
    public List<TreeModel> encodeTreeModels(Schema schema) {
        M model = getModel();

        if (model._mojo_version < MIN_MOJO_VERSION) {
            throw new IllegalArgumentException("Version " + model._mojo_version + " is not supported");
        }

        byte[][] compressedTrees = getCompressedTrees(model);

        // One PredicateManager instance is shared by all trees of the model
        PredicateManager predicateManager = new PredicateManager();

        return Stream.of(compressedTrees)
            .map(compressedTree -> encodeTreeModel(compressedTree, predicateManager, schema))
            .collect(Collectors.toList());
    }

    /**
     * Decodes a single compressed tree into a regression tree model.
     *
     * <p>Node ids are assigned in pre-order (node, left subtree, right subtree), starting at 1.
     *
     * @param bArr the compressed tree bytes
     * @param predicateManager factory for the split predicates
     * @param schema the schema whose features are addressed by the tree's column indices
     * @return a tree model using the {@code defaultChild} missing value strategy
     */
    public static TreeModel encodeTreeModel(byte[] bArr, PredicateManager predicateManager, Schema schema) {
        ContinuousLabel label = new ContinuousLabel((FieldName) null, DataType.DOUBLE);

        Node root = encodeNode(new ByteBufferWrapper(bArr), True.INSTANCE, bArr, new AtomicInteger(1), new CategoryManager(), predicateManager, schema);

        return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(label), root)
            .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);
    }

    /**
     * Recursively decodes the node at the current position of {@code byteBufferWrapper}.
     *
     * <p>A node starts with a 1-byte node type and a 2-byte column index. A column index of 65535
     * denotes a leaf followed by its 4-byte float value. Otherwise the node continues with:
     * <ul>
     *   <li>a 1-byte NA split direction;</li>
     *   <li>the split payload: nothing for an NA-vs-rest split, a bitset for a categorical split,
     *       or a 4-byte float threshold for a numeric split;</li>
     *   <li>the left child, which is preceded by a 1-4 byte size field unless it is an inline leaf;</li>
     *   <li>the right child.</li>
     * </ul>
     *
     * @param byteBufferWrapper reader positioned at the start of the node; it is advanced past the
     *     split header
     * @param predicate predicate under which this node is reached
     * @param bArr the whole compressed tree
     * @param atomicInteger source of node ids
     * @param categoryManager the categories still reachable at this node, per feature
     * @param predicateManager factory for the split predicates
     * @param schema the schema whose features are addressed by column index
     * @return the decoded leaf or branch node
     * @throws IllegalArgumentException if the node uses an unsupported encoding
     */
    public static Node encodeNode(ByteBufferWrapper byteBufferWrapper, Predicate predicate, byte[] bArr, AtomicInteger atomicInteger, CategoryManager categoryManager, PredicateManager predicateManager, Schema schema) {
        Integer id = nextId(atomicInteger);

        int nodeType = byteBufferWrapper.get1U();

        // Bits 0, 1, 4, 5 describe the left child, and bits 2, 3 select the split kind.
        // Bits 6, 7 describe the right child; they are shifted into the leaf-flag positions of the left mask.
        int leftMask = nodeType & 0x33;
        int rightMask = (nodeType & 0xC0) >> 2;
        int splitKind = nodeType & 0x0C;

        int columnIndex = byteBufferWrapper.get2();
        if (columnIndex == LEAF_COLUMN) {
            return new LeafNode(Double.valueOf(byteBufferWrapper.get4f()), predicate).setId(id);
        }

        int naSplitDir = byteBufferWrapper.get1U();
        boolean naVsRest = (naSplitDir == NaSplitDir.NAvsREST.value());
        boolean leftward = (naSplitDir == NaSplitDir.NALeft.value()) || (naSplitDir == NaSplitDir.Left.value());

        Feature feature = schema.getFeature(columnIndex);

        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;

        Predicate leftPredicate;
        Predicate rightPredicate;

        if (naVsRest) {
            // Non-missing values go left, missing values go right
            leftPredicate = predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.IS_NOT_MISSING, (Object) null);
            rightPredicate = predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, (Object) null);
        } else if (splitKind != 0) {
            CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
            GenmodelBitSet bitSet = readBitSet(splitKind, bArr, byteBufferWrapper);

            FieldName name = categoricalFeature.getName();
            List<?> values = categoricalFeature.getValues();
            java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

            List<Object> leftValues = new ArrayList<>();
            List<Object> rightValues = new ArrayList<>();

            for (int i = 0; i < values.size(); i++) {
                Object value = values.get(i);

                // Categories already excluded by an ancestor split cannot reach this node
                if (!valueFilter.test(value)) {
                    continue;
                }

                // The bitset only encodes indices within its range. H2O's scorer sends
                // out-of-range category indices in the NA direction, so mirror that here
                // instead of probing the bitset out of range.
                boolean right = bitSet.isInRange(i) ? bitSet.contains(i) : !leftward;

                (right ? rightValues : leftValues).add(value);
            }

            leftCategoryManager = categoryManager.fork(name, leftValues);
            rightCategoryManager = categoryManager.fork(name, rightValues);

            leftPredicate = predicateManager.createPredicate(categoricalFeature, leftValues);
            rightPredicate = predicateManager.createPredicate(categoricalFeature, rightValues);
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();
            Double splitValue = Double.valueOf(byteBufferWrapper.get4f());

            leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, splitValue);
            rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, splitValue);
        }

        int childrenPosition = byteBufferWrapper.position();

        // Resolve the right child first so that an unsupported left mask is rejected before any
        // recursion. Buffer positioning does not consume node ids, so the numbering is unaffected.
        ByteBufferWrapper rightBuffer = rightChildBuffer(bArr, childrenPosition, leftMask);
        ByteBufferWrapper leftBuffer = leftChildBuffer(bArr, childrenPosition, leftMask);

        Node leftChild = encodeChild(leftBuffer, leftMask, leftPredicate, bArr, atomicInteger, leftCategoryManager, predicateManager, schema);
        Node rightChild = encodeChild(rightBuffer, rightMask, rightPredicate, bArr, atomicInteger, rightCategoryManager, predicateManager, schema);

        // Missing values follow the NA split direction recorded for this node
        return new BranchNode((Object) null, predicate)
            .setId(id)
            .setDefaultChild(leftward ? leftChild.getId() : rightChild.getId())
            .addNodes(leftChild, rightChild);
    }

    /**
     * Reads the inline categorical bitset that follows the split header.
     *
     * @throws IllegalArgumentException if {@code splitKind} is neither 8 nor 12
     */
    private static GenmodelBitSet readBitSet(int splitKind, byte[] compressedTree, ByteBufferWrapper byteBuffer) {
        GenmodelBitSet bitSet = new GenmodelBitSet(0);

        switch (splitKind) {
            case 8:
                bitSet.fill2(compressedTree, byteBuffer);
                break;
            case 12:
                bitSet.fill3(compressedTree, byteBuffer);
                break;
            default:
                throw new IllegalArgumentException("Node type " + splitKind + " is not supported");
        }

        return bitSet;
    }

    /**
     * Returns a fresh reader positioned at the left child.
     *
     * <p>For left masks 0-3 a (mask + 1)-byte size field precedes the left child and is skipped.
     */
    private static ByteBufferWrapper leftChildBuffer(byte[] compressedTree, int childrenPosition, int leftMask) {
        ByteBufferWrapper result = new ByteBufferWrapper(compressedTree);
        result.skip(childrenPosition);

        if (leftMask <= 3) {
            result.skip(leftMask + 1);
        }

        return result;
    }

    /**
     * Returns a fresh reader positioned at the right child.
     *
     * <p>The left child is skipped, either through its 1-4 byte size field or, for an inline
     * leaf (mask 48), through its 4-byte value.
     *
     * @throws IllegalArgumentException if {@code leftMask} is not one of 0, 1, 2, 3 or 48
     */
    private static ByteBufferWrapper rightChildBuffer(byte[] compressedTree, int childrenPosition, int leftMask) {
        ByteBufferWrapper result = new ByteBufferWrapper(compressedTree);
        result.skip(childrenPosition);

        switch (leftMask) {
            case 0:
                result.skip(result.get1U());
                break;
            case 1:
                result.skip(result.get2());
                break;
            case 2:
                result.skip(result.get3());
                break;
            case 3:
                result.skip(result.get4());
                break;
            case 48:
                result.skip(4);
                break;
            default:
                throw new IllegalArgumentException("Node type " + leftMask + " is not supported");
        }

        return result;
    }

    /**
     * Decodes one child: an inline leaf value when the mask carries the leaf flag, otherwise a full
     * subtree.
     */
    private static Node encodeChild(ByteBufferWrapper childBuffer, int childMask, Predicate predicate, byte[] compressedTree, AtomicInteger idSequence, CategoryManager categoryManager, PredicateManager predicateManager, Schema schema) {
        if ((childMask & LEAF_FLAG) != 0) {
            return new LeafNode(Double.valueOf(childBuffer.get4f()), predicate).setId(nextId(idSequence));
        }

        return encodeNode(childBuffer, predicate, compressedTree, idSequence, categoryManager, predicateManager, schema);
    }

    /** Reads the private {@code _compressed_trees} field of the model via reflection. */
    public static byte[][] getCompressedTrees(SharedTreeMojoModel sharedTreeMojoModel) {
        return (byte[][]) getFieldValue(FIELD_COMPRESSEDTREES, sharedTreeMojoModel);
    }

    /** Reads the private {@code _ntree_groups} field of the model via reflection. */
    public static int getNTreeGroups(SharedTreeMojoModel sharedTreeMojoModel) {
        return ((Integer) getFieldValue(FIELD_NTREEGROUPS, sharedTreeMojoModel)).intValue();
    }

    /** Reads the private {@code _ntrees_per_group} field of the model via reflection. */
    public static int getNTreesPerGroup(SharedTreeMojoModel sharedTreeMojoModel) {
        return ((Integer) getFieldValue(FIELD_NTREESPERGROUP, sharedTreeMojoModel)).intValue();
    }

    /** Returns the current value of the id sequence and advances it by one. */
    private static Integer nextId(AtomicInteger atomicInteger) {
        return Integer.valueOf(atomicInteger.getAndIncrement());
    }

    static {
        try {
            FIELD_COMPRESSEDTREES = SharedTreeMojoModel.class.getDeclaredField("_compressed_trees");
            FIELD_NTREEGROUPS = SharedTreeMojoModel.class.getDeclaredField("_ntree_groups");
            FIELD_NTREESPERGROUP = SharedTreeMojoModel.class.getDeclaredField("_ntrees_per_group");
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }
}
