package org.jpmml.h2o;

import hex.genmodel.MojoModel;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.dmg.pmml.MissingValueTreatmentMethod;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;

/* loaded from: input_file:org/jpmml/h2o/GlmMojoModelBaseConverter.class */
/**
 * Base converter for H2O GLM MOJO models backed by {@code hex.genmodel.algos.glm.GlmMojoModelBase}.
 *
 * <p>The internal state of {@code GlmMojoModelBase} (coefficients, categorical offsets, imputation
 * statistics, ...) is resolved reflectively once, in the static initializer, and read via
 * {@code getFieldValue}.
 *
 * @param <M> the concrete MOJO model type
 */
public abstract class GlmMojoModelBaseConverter<M extends MojoModel> extends Converter<M> {
    private static final Class<?> CLASS_GLMMOJOMODELBASE;
    private static final Field FIELD_BETA;
    private static final Field FIELD_CATS;
    private static final Field FIELD_CATMODES;
    private static final Field FIELD_CATOFFSETS;
    private static final Field FIELD_FAMILY;
    private static final Field FIELD_MEANIMPUTATION;
    private static final Field FIELD_NUMS;
    private static final Field FIELD_NUMMEANS;
    private static final Field FIELD_USEALLFACTORLEVELS;

    public GlmMojoModelBaseConverter(M m) {
        super(m);
    }

    /**
     * Rewrites the input schema into the feature layout used by the GLM model.
     *
     * <p>All categorical features come first, followed by every remaining feature converted to a
     * continuous feature. Each categorical feature is then expanded into one {@link BinaryFeature}
     * per level; when the model does not use all factor levels, the first level of every
     * categorical feature is skipped.
     *
     * <p>When the model has mean imputation enabled, each categorical feature's mode value and each
     * continuous feature's mean value are passed to {@code ImputerUtil.encodeFeature}.
     *
     * @param schema the input schema
     * @return a new schema with the same label and the expanded feature list
     * @throws IllegalArgumentException if the categorical offsets, mode values or mean values
     *     stored in the model do not match the number of categorical/numeric columns
     */
    @Override // org.jpmml.h2o.Converter
    public Schema toMojoModelSchema(Schema schema) {
        M model = getModel();

        int cats = getCats(model);
        int[] catOffsets = getCatOffsets(model);
        int nums = getNums(model);
        boolean meanImputation = getMeanImputation(model);
        boolean useAllFactorLevels = getUseAllFactorLevels(model);

        // catOffsets[i + 1] - catOffsets[i] is used below for every categorical column i,
        // so the array must hold at least cats + 1 entries
        if (catOffsets.length < cats + 1) {
            throw new IllegalArgumentException("Expected " + (cats + 1) + " categorical offsets, got " + catOffsets.length + " categorical offsets");
        }

        Label label = schema.getLabel();
        List<? extends Feature> features = schema.getFeatures();

        List<CategoricalFeature> categoricalFeatures = features.stream()
            .filter(CategoricalFeature.class::isInstance)
            .map(CategoricalFeature.class::cast)
            .collect(Collectors.toList());

        SchemaUtil.checkSize(cats, categoricalFeatures);

        for (int i = 0; i < cats; i++) {
            // Without useAllFactorLevels the offset range is one shorter than the level count
            int levels = (catOffsets[i + 1] - catOffsets[i]) + (useAllFactorLevels ? 0 : 1);

            SchemaUtil.checkSize(levels, categoricalFeatures.get(i));
        }

        List<ContinuousFeature> continuousFeatures = features.stream()
            .filter(feature -> !(feature instanceof CategoricalFeature))
            .map(Feature::toContinuousFeature)
            .collect(Collectors.toList());

        SchemaUtil.checkSize(nums, continuousFeatures);

        if (meanImputation) {
            encodeMeanImputation(model, cats, nums, categoricalFeatures, continuousFeatures);
        }

        List<Feature> orderedFeatures = new ArrayList<>(categoricalFeatures.size() + continuousFeatures.size());
        orderedFeatures.addAll(categoricalFeatures);
        orderedFeatures.addAll(continuousFeatures);

        List<Feature> encodedFeatures = orderedFeatures.stream()
            .flatMap(feature -> expandFeature(feature, useAllFactorLevels))
            .collect(Collectors.toList());

        return new Schema(label, encodedFeatures);
    }

    /**
     * Passes the model's categorical mode values and numeric mean values to
     * {@code ImputerUtil.encodeFeature} for the corresponding features.
     */
    private static void encodeMeanImputation(MojoModel model, int cats, int nums, List<CategoricalFeature> categoricalFeatures, List<ContinuousFeature> continuousFeatures) {
        int[] catModes = getCatModes(model);
        double[] numMeans = getNumMeans(model);

        if (catModes.length != cats) {
            throw new IllegalArgumentException("Expected " + cats + " mode values, got " + catModes.length + " mode values");
        }

        if (numMeans.length != nums) {
            throw new IllegalArgumentException("Expected " + nums + " mean values, got " + numMeans.length + " mean values");
        }

        for (int i = 0; i < cats; i++) {
            CategoricalFeature categoricalFeature = categoricalFeatures.get(i);

            // catModes holds an index into the feature's value list, not the value itself
            ImputerUtil.encodeFeature(categoricalFeature, categoricalFeature.getValues().get(catModes[i]), MissingValueTreatmentMethod.AS_MODE);
        }

        for (int i = 0; i < nums; i++) {
            ImputerUtil.encodeFeature(continuousFeatures.get(i), Double.valueOf(numMeans[i]), MissingValueTreatmentMethod.AS_MEAN);
        }
    }

    /**
     * Expands a categorical feature into one binary feature per level (optionally skipping the first
     * level); any other feature is passed through unchanged.
     */
    private static Stream<Feature> expandFeature(Feature feature, boolean useAllFactorLevels) {
        if (!(feature instanceof CategoricalFeature)) {
            return Stream.of(feature);
        }

        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        PMMLEncoder encoder = categoricalFeature.getEncoder();

        List<?> values = categoricalFeature.getValues();
        if (!useAllFactorLevels) {
            values = values.subList(1, values.size());
        }

        return values.stream()
            .<Feature>map(value -> new BinaryFeature(encoder, categoricalFeature.getName(), categoricalFeature.getDataType(), value));
    }

    /** Returns the {@code _beta} coefficient array of the given GLM MOJO model. */
    public static double[] getBeta(MojoModel mojoModel) {
        return (double[]) getFieldValue(FIELD_BETA, mojoModel);
    }

    /** Returns the {@code _cats} count of categorical columns of the given GLM MOJO model. */
    public static int getCats(MojoModel mojoModel) {
        return ((Integer) getFieldValue(FIELD_CATS, mojoModel)).intValue();
    }

    /** Returns the {@code _catModes} array (per-column level indices) of the given GLM MOJO model. */
    public static int[] getCatModes(MojoModel mojoModel) {
        return (int[]) getFieldValue(FIELD_CATMODES, mojoModel);
    }

    /** Returns the {@code _catOffsets} array of the given GLM MOJO model. */
    public static int[] getCatOffsets(MojoModel mojoModel) {
        return (int[]) getFieldValue(FIELD_CATOFFSETS, mojoModel);
    }

    /** Returns the {@code _family} name of the given GLM MOJO model. */
    public static String getFamily(MojoModel mojoModel) {
        return (String) getFieldValue(FIELD_FAMILY, mojoModel);
    }

    /** Returns the {@code _meanImputation} flag of the given GLM MOJO model. */
    public static boolean getMeanImputation(MojoModel mojoModel) {
        return ((Boolean) getFieldValue(FIELD_MEANIMPUTATION, mojoModel)).booleanValue();
    }

    /** Returns the {@code _nums} count of numeric columns of the given GLM MOJO model. */
    public static int getNums(MojoModel mojoModel) {
        return ((Integer) getFieldValue(FIELD_NUMS, mojoModel)).intValue();
    }

    /** Returns the {@code _numMeans} array of the given GLM MOJO model. */
    public static double[] getNumMeans(MojoModel mojoModel) {
        return (double[]) getFieldValue(FIELD_NUMMEANS, mojoModel);
    }

    /** Returns the {@code _useAllFactorLevels} flag of the given GLM MOJO model. */
    public static boolean getUseAllFactorLevels(MojoModel mojoModel) {
        return ((Boolean) getFieldValue(FIELD_USEALLFACTORLEVELS, mojoModel)).booleanValue();
    }

    static {
        // Class.forName and getDeclaredField both throw ReflectiveOperationException subtypes,
        // so a single handler covers the whole lookup
        try {
            CLASS_GLMMOJOMODELBASE = Class.forName("hex.genmodel.algos.glm.GlmMojoModelBase");

            FIELD_BETA = CLASS_GLMMOJOMODELBASE.getDeclaredField("_beta");
            FIELD_CATS = CLASS_GLMMOJOMODELBASE.getDeclaredField("_cats");
            FIELD_CATMODES = CLASS_GLMMOJOMODELBASE.getDeclaredField("_catModes");
            FIELD_CATOFFSETS = CLASS_GLMMOJOMODELBASE.getDeclaredField("_catOffsets");
            FIELD_FAMILY = CLASS_GLMMOJOMODELBASE.getDeclaredField("_family");
            FIELD_MEANIMPUTATION = CLASS_GLMMOJOMODELBASE.getDeclaredField("_meanImputation");
            FIELD_NUMS = CLASS_GLMMOJOMODELBASE.getDeclaredField("_nums");
            FIELD_NUMMEANS = CLASS_GLMMOJOMODELBASE.getDeclaredField("_numMeans");
            FIELD_USEALLFACTORLEVELS = CLASS_GLMMOJOMODELBASE.getDeclaredField("_useAllFactorLevels");
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException("Failed to resolve hex.genmodel.algos.glm.GlmMojoModelBase members", e);
        }
    }
}
