package org.jpmml.sparkml;

import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.param.shared.HasProbabilityCol;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OutputField;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.IndexFeature;
import org.jpmml.converter.Label;
import org.jpmml.converter.LabelUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.mining.MiningModelUtil;

/**
 * Base class for converters that translate a fitted Spark ML {@link Model} into a PMML model.
 *
 * <p>Subclasses declare the PMML mining function and implement the model encoding itself.
 * This class derives the {@link Schema} (label plus feature list) from the encoder state and
 * attaches any output fields contributed by {@link #registerOutputFields} to the encoded model.
 *
 * @param <T> the Spark ML model type; it must expose a features column and a prediction column
 */
public abstract class ModelConverter<T extends Model<T> & HasFeaturesCol & HasPredictionCol> extends TransformerConverter<T> {

    public ModelConverter(T model) {
        super(model);
    }

    /**
     * Returns the PMML mining function of the converted model.
     *
     * <p>When a label column is set, {@link #encodeSchema} only accepts
     * {@link MiningFunction#CLASSIFICATION} and {@link MiningFunction#REGRESSION}.
     */
    public abstract MiningFunction getMiningFunction();

    /**
     * Encodes the Spark ML model as a PMML model for the given schema.
     *
     * <p>This method was named {@code encodeModel} in the original sources. The decompiler
     * renamed it, and the name is kept so that existing overrides continue to match.
     *
     * @param schema the schema produced by {@link #encodeSchema}
     * @return the encoded PMML model
     */
    public abstract org.dmg.pmml.Model mo12encodeModel(Schema schema);

    /**
     * Builds the schema for this model from the current encoder state.
     *
     * <p>When the model has an explicitly set label column, the label is derived from the feature
     * registered under that column. Features are looked up under the model's features column.
     * Size checks are delegated to {@code SchemaUtil.checkSize}:
     * <ul>
     *   <li>for classification models, the label categories are checked against {@code numClasses()};</li>
     *   <li>for prediction models with a known feature count, the features are checked against
     *       {@code numFeatures()}.</li>
     * </ul>
     *
     * @param encoder the encoder holding previously registered features
     * @return the validated schema
     * @throws IllegalArgumentException if the label feature has an unsupported kind, the mining
     *     function is unsupported, or the label column also appears among the feature columns
     */
    public Schema encodeSchema(SparkMLEncoder encoder) {
        T model = getTransformer();

        Label label = null;
        if (hasLabelCol(model)) {
            label = encodeLabel(encoder, model);
        }

        if (model instanceof ClassificationModel) {
            ClassificationModel<?, ?> classificationModel = (ClassificationModel<?, ?>) model;
            // A classification model is expected to carry a categorical label (or none at all).
            SchemaUtil.checkSize(classificationModel.numClasses(), (CategoricalLabel) label);
        }

        List<Feature> features = encoder.getFeatures(model.getFeaturesCol());

        if (model instanceof PredictionModel) {
            int numFeatures = ((PredictionModel<?, ?>) model).numFeatures();
            // Spark reports -1 when the number of features is unknown; nothing to check then.
            if (numFeatures != -1) {
                SchemaUtil.checkSize(numFeatures, features);
            }
        }

        Schema schema = new Schema(label, features);
        checkSchema(schema);
        return schema;
    }

    /**
     * Encodes the label of a model whose label column is set, according to the mining function.
     *
     * @param encoder the encoder holding the label feature
     * @param model the model being converted
     * @return a categorical label for classification, or a continuous double label for regression
     * @throws IllegalArgumentException if the mining function is neither classification nor regression
     */
    private Label encodeLabel(SparkMLEncoder encoder, T model) {
        String labelCol = ((HasLabelCol) model).getLabelCol();
        Feature feature = encoder.getOnlyFeature(labelCol);

        MiningFunction miningFunction = getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION:
                return encodeCategoricalLabel(encoder, model, labelCol, feature);
            case REGRESSION: {
                // Regression targets are always exposed as continuous doubles.
                Field<?> field = encoder.toContinuous(feature.getName());
                field.setDataType(DataType.DOUBLE);
                return new ContinuousLabel(field.getName(), field.getDataType());
            }
            default:
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
        }
    }

    /**
     * Builds a categorical label from the feature registered under the label column.
     *
     * <p>Boolean and categorical features are used directly. A continuous feature is re-encoded
     * as categorical, using the target categories from {@code LabelUtil.createTargetCategories}.
     * The number of categories is taken from {@code numClasses()} for classification models and
     * defaults to 2 otherwise. The re-encoded feature replaces the original under the label column.
     */
    private static CategoricalLabel encodeCategoricalLabel(SparkMLEncoder encoder, Model<?> model, String labelCol, Feature feature) {
        // BooleanFeature is checked first, preserving the original branch order.
        if (feature instanceof BooleanFeature) {
            BooleanFeature booleanFeature = (BooleanFeature) feature;
            return new CategoricalLabel(booleanFeature.getName(), booleanFeature.getDataType(), booleanFeature.getValues());
        }

        if (feature instanceof CategoricalFeature) {
            return new CategoricalLabel(((CategoricalFeature) feature).getField());
        }

        if (feature instanceof ContinuousFeature) {
            ContinuousFeature continuousFeature = (ContinuousFeature) feature;

            int numClasses = 2;
            if (model instanceof ClassificationModel) {
                numClasses = ((ClassificationModel<?, ?>) model).numClasses();
            }

            List<Integer> categories = LabelUtil.createTargetCategories(numClasses);
            Field<?> field = encoder.toCategorical(continuousFeature.getName(), categories);
            encoder.putOnlyFeature(labelCol, new IndexFeature(encoder, field, categories));
            return new CategoricalLabel(field.getName(), field.getDataType(), categories);
        }

        throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + feature);
    }

    /**
     * Returns additional output fields to attach to the encoded model.
     *
     * <p>The default implementation returns {@code null}, meaning "no output fields".
     * {@link #registerModel} treats {@code null} and an empty list alike.
     *
     * @param label the schema label (may be {@code null})
     * @param model the encoded PMML model
     * @param encoder the active encoder
     * @return output fields to add, or {@code null}
     */
    public List<OutputField> registerOutputFields(Label label, org.dmg.pmml.Model model, SparkMLEncoder encoder) {
        return null;
    }

    /**
     * Encodes the schema and the model, then attaches any output fields from
     * {@link #registerOutputFields}.
     *
     * @param encoder the active encoder
     * @return the encoded PMML model
     */
    public org.dmg.pmml.Model registerModel(SparkMLEncoder encoder) {
        Schema schema = encodeSchema(encoder);
        Label label = schema.getLabel();

        org.dmg.pmml.Model pmmlModel = mo12encodeModel(schema);

        List<OutputField> outputFields = registerOutputFields(label, pmmlModel, encoder);
        if (outputFields != null && !outputFields.isEmpty()) {
            // Fields go onto the model returned by MiningModelUtil.getFinalModel.
            // Presumably this is the last member when pmmlModel is a model chain (confirm).
            ModelUtil.ensureOutput(MiningModelUtil.getFinalModel(pmmlModel)).getOutputFields().addAll(outputFields);
        }

        return pmmlModel;
    }

    /**
     * Returns {@code true} if the model has a label column parameter and it is explicitly set.
     */
    public static boolean hasLabelCol(Model<?> model) {
        if (model instanceof HasLabelCol) {
            return model.isSet(((HasLabelCol) model).labelCol());
        }
        return false;
    }

    /**
     * Returns {@code true} if the model has a prediction column parameter and it is explicitly set.
     */
    public static boolean hasPredictionCol(Model<?> model) {
        if (model instanceof HasPredictionCol) {
            return model.isSet(((HasPredictionCol) model).predictionCol());
        }
        return false;
    }

    /**
     * Returns {@code true} if the model has a probability column parameter and it is explicitly set.
     */
    public static boolean hasProbabilityCol(Model<?> model) {
        if (model instanceof HasProbabilityCol) {
            return model.isSet(((HasProbabilityCol) model).probabilityCol());
        }
        return false;
    }

    /**
     * Rejects schemas in which the label name also appears as a feature name.
     *
     * @throws IllegalArgumentException if the label column is among the feature columns
     */
    private static void checkSchema(Schema schema) {
        Label label = schema.getLabel();
        if (label == null) {
            return;
        }

        for (Feature feature : schema.getFeatures()) {
            if (Objects.equals(label.getName(), feature.getName())) {
                throw new IllegalArgumentException("Label column '" + label.getName() + "' is contained in the list of feature columns");
            }
        }
    }
}
