package sklearn;

import com.google.common.collect.Iterables;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.jpmml.converter.Feature;
import org.jpmml.converter.HasNativeConfiguration;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.Schema;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.EncodableUtil;
import org.jpmml.sklearn.HasSkLearnOptions;
import org.jpmml.sklearn.SkLearnEncoder;

/**
 * Static helpers shared by the scikit-learn estimator converters.
 */
public class EstimatorUtil {

    private EstimatorUtil() {
    }

    /**
     * Determines the mining function shared by a collection of estimators.
     *
     * @param list the estimators to inspect
     * @return the common mining function when every estimator reports the same one;
     *     {@link MiningFunction#MIXED} otherwise (including when {@code list} is empty)
     */
    public static MiningFunction getMiningFunction(List<? extends Estimator> list) {
        Set<MiningFunction> miningFunctions = list.stream()
            .map(estimator -> estimator.getMiningFunction())
            .collect(Collectors.toSet());

        return miningFunctions.size() == 1 ? Iterables.getOnlyElement(miningFunctions) : MiningFunction.MIXED;
    }

    /**
     * Returns the class labels of a classifier.
     *
     * @throws IllegalArgumentException if the estimator does not implement {@link HasClasses}
     */
    public static List<?> getClasses(Estimator estimator) {
        return toHasClasses(estimator).getClasses();
    }

    /**
     * Tells whether a classifier exposes a probability distribution.
     *
     * @throws IllegalArgumentException if the estimator does not implement {@link HasClasses}
     */
    public static boolean hasProbabilityDistribution(Estimator estimator) {
        return toHasClasses(estimator).hasProbabilityDistribution();
    }

    /**
     * Exports the result fields that correspond to a scikit-learn estimator method.
     *
     * <p>Supported methods are {@code "apply"}, {@code decision_function}, {@code "predict"}
     * and {@code predict_proba}.
     *
     * @param estimator the estimator whose method is being exported
     * @param method the scikit-learn method name
     * @param schema the schema the model was encoded against
     * @param model the encoded model
     * @param encoder the encoder that performs the export
     * @return the exported features
     * @throws IllegalArgumentException if the method is unknown, or the estimator cannot support it
     */
    public static List<Feature> export(Estimator estimator, String method, Schema schema, Model model, SkLearnEncoder encoder) {
        switch (method) {
            case "apply":
                {
                    if (estimator instanceof HasApplyField) {
                        return encoder.export(model, ((HasApplyField) estimator).getApplyField());
                    }
                    if (estimator instanceof HasMultiApplyField) {
                        return encoder.export(model, ((HasMultiApplyField) estimator).getMultiApplyFields());
                    }
                    throw unsupportedMethod(estimator, method);
                }
            case SkLearnMethods.DECISION_FUNCTION:
                {
                    if (estimator instanceof HasDecisionFunctionField) {
                        return encoder.export(model, ((HasDecisionFunctionField) estimator).getDecisionFunctionField());
                    }
                    throw unsupportedMethod(estimator, method);
                }
            case "predict":
                return exportPredict(estimator, method, schema, model, encoder);
            case SkLearnMethods.PREDICT_PROBA:
                {
                    // The instanceof check happens inside toHasClasses; the call must go
                    // through the HasClasses view of the estimator.
                    HasClasses hasClasses = toHasClasses(estimator);

                    List<String> names = hasClasses.createPredictProbaFields(DataType.DOUBLE, schema.getLabel()).stream()
                        .map(outputField -> outputField.requireName())
                        .collect(Collectors.toList());

                    return encoder.export(model, names);
                }
            default:
                throw new IllegalArgumentException(method);
        }
    }

    /**
     * Resolves the "predict" export, in order of preference:
     * (1) a dedicated predict field, (2) the label prediction of a supervised estimator,
     * (3) the last prediction-typed output field declared on the model.
     */
    private static List<Feature> exportPredict(Estimator estimator, String method, Schema schema, Model model, SkLearnEncoder encoder) {
        if (estimator instanceof HasPredictField) {
            return encoder.export(model, ((HasPredictField) estimator).getPredictField());
        }

        if (estimator.isSupervised()) {
            return Collections.singletonList(encoder.exportPrediction(model, (ScalarLabel) schema.getLabel()));
        }

        Output output = model.getOutput();
        if (output == null || !output.hasOutputFields()) {
            throw unsupportedMethod(estimator, method);
        }

        List<OutputField> predictionFields = output.getOutputFields().stream()
            .filter(outputField -> SkLearnEncoder.isPrediction(outputField))
            .collect(Collectors.toList());

        if (predictionFields.isEmpty()) {
            throw unsupportedMethod(estimator, method);
        }

        return encoder.export(model, Iterables.getLast(predictionFields).getName());
    }

    /** Casts the estimator to {@link HasClasses}, or fails with a descriptive message. */
    private static HasClasses toHasClasses(Estimator estimator) {
        if (estimator instanceof HasClasses) {
            return (HasClasses) estimator;
        }
        throw new IllegalArgumentException("The estimator object (" + ClassDictUtil.formatClass(estimator) + ") is not a classifier");
    }

    /** Builds the exception reported when an estimator cannot export the requested method. */
    private static IllegalArgumentException unsupportedMethod(Estimator estimator, String method) {
        return new IllegalArgumentException("The estimator object (" + ClassDictUtil.formatClass(estimator) + ") does not support method '" + method + "'");
    }

    /**
     * Encodes a supervised or unsupervised estimator into a standalone PMML document.
     *
     * <p>For supervised estimators the label is initialized from
     * {@code EncodableUtil.generateOutputNames}. The features are initialized from
     * {@code EncodableUtil.getOrGenerateFeatureNames}.
     */
    public static <E extends Estimator & HasFeatureNamesIn & HasSkLearnOptions> PMML encodePMML(E e) {
        SkLearnEncoder encoder = new SkLearnEncoder();

        if (e.isSupervised()) {
            encoder.initLabel(e, EncodableUtil.generateOutputNames(e));
        }
        encoder.initFeatures(e, EncodableUtil.getOrGenerateFeatureNames(e));

        Model model = encodeNativeLike(e, encoder.createSchema());
        encoder.setModel(model);

        return encoder.encodePMML(model);
    }

    /**
     * Encodes the estimator. If it implements {@link HasNativeConfiguration}, encoding uses
     * its native configuration as the PMML options.
     *
     * <p>The estimator's previous PMML options are restored afterwards, whether encoding
     * succeeds or throws.
     */
    public static Model encodeNativeLike(Estimator estimator, Schema schema) {
        if (!(estimator instanceof HasNativeConfiguration)) {
            return estimator.encode(schema);
        }

        HasNativeConfiguration hasNativeConfiguration = (HasNativeConfiguration) estimator;

        Map<String, ?> pmmlOptions = estimator.getPMMLOptions();
        try {
            estimator.setPMMLOptions(hasNativeConfiguration.getNativeConfiguration());

            return estimator.encode(schema);
        } finally {
            estimator.setPMMLOptions(pmmlOptions);
        }
    }
}
