package sklearn;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import numpy.core.ScalarUtil;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.Value;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.VisitorAction;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.MultiLabel;
import org.jpmml.converter.OrdinalLabel;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.visitors.AbstractExtender;
import org.jpmml.python.CastFunction;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.HasArray;
import org.jpmml.sklearn.SkLearnEncoder;

/* loaded from: input_file:sklearn/Classifier.class */
/**
 * Base class for scikit-learn classifiers.
 *
 * <p>Fixes the mining function to {@link MiningFunction#CLASSIFICATION}, reports the estimator as
 * supervised with a probability distribution, and encodes target field(s) as categorical or ordinal
 * labels whose values come from the {@code SkLearnFields.CLASSES} attribute.
 */
public abstract class Classifier extends Estimator implements HasClasses {
    /** The constant {@code "probability"}; not referenced in this class (presumably used for probability outputs elsewhere). */
    public static final String FIELD_PROBABILITY = "probability";

    /** Forwards both arguments unchanged to the {@link Estimator} constructor. */
    public Classifier(String str, String str2) {
        super(str, str2);
    }

    /** Always {@link MiningFunction#CLASSIFICATION}. */
    @Override // sklearn.Estimator
    public MiningFunction getMiningFunction() {
        return MiningFunction.CLASSIFICATION;
    }

    /** Always {@code true}: classifiers are trained against a target. */
    @Override // sklearn.Estimator
    public boolean isSupervised() {
        return true;
    }

    /**
     * Returns the number of outputs reported by {@link Estimator}, mapping the sentinel {@code -1}
     * to a single output.
     */
    @Override // sklearn.Estimator, sklearn.HasNumberOfOutputs
    public int getNumberOfOutputs() {
        int numberOfOutputs = super.getNumberOfOutputs();
        // -1 presumably means "unknown"; a classifier always has at least one output.
        return numberOfOutputs == -1 ? 1 : numberOfOutputs;
    }

    /**
     * Returns the class values with every {@code Long} narrowed to {@code Integer}.
     *
     * <p>Entries that are {@link HasArray} instances are unpacked to their array content and
     * narrowed the same way. The multi-output {@link #encodeLabel(List, SkLearnEncoder)} treats
     * such entries as one class list per target.
     *
     * @throws ArithmeticException if a {@code Long} class value does not fit in an {@code int}
     */
    @Override // sklearn.HasClasses
    public List<?> getClasses() {
        List<?> values = getListLike(SkLearnFields.CLASSES);
        List<Object> unpacked = new ArrayList<>(values.size());
        for (Object value : values) {
            unpacked.add(value instanceof HasArray ? canonicalizeValues(((HasArray) value).getArrayContent()) : value);
        }
        return canonicalizeValues(unpacked);
    }

    /** Always {@code true}: classifiers expose class probabilities. */
    @Override // sklearn.HasClasses
    public boolean hasProbabilityDistribution() {
        return true;
    }

    /**
     * Encodes the target field(s) of this classifier.
     *
     * <p>A single name is encoded against the full {@link #getClasses()} list. Two or more names
     * produce a {@link MultiLabel} in which the i-th name is encoded against the i-th entry of
     * {@link #getClasses()}, which must be a list.
     *
     * @param list target field names; a {@code null} element denotes an un-named target
     * @param skLearnEncoder encoder used to create the target data fields
     * @return the encoded label
     * @throws IllegalArgumentException if {@code list} is empty, or if {@link #getClasses()} has
     *     fewer entries than {@code list} has names
     */
    @Override // sklearn.Estimator
    public Label encodeLabel(List<String> list, SkLearnEncoder skLearnEncoder) {
        List<?> classes = getClasses();
        int count = list.size();
        if (count == 1) {
            return encodeLabel(list.get(0), classes, skLearnEncoder);
        }
        if (count == 0) {
            throw new IllegalArgumentException("Expected at least one target field name, got none");
        }
        if (classes.size() < count) {
            // Fail with a descriptive message instead of an IndexOutOfBoundsException below.
            throw new IllegalArgumentException("Expected " + count + " class lists (one per target field), got " + classes.size());
        }
        List<Label> labels = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            final String name = list.get(i);
            // The raw class literal is required: List.class is a Class<List>, not a Class<List<?>>.
            @SuppressWarnings({"rawtypes", "unchecked"})
            CastFunction<List<?>> castFunction = new CastFunction<List<?>>((Class) List.class) { // from class: sklearn.Classifier.1
                public String formatMessage(Object obj) {
                    return "The categories object of the " + (name != null ? "'" + name + "'" : "<un-named>") + " target field (" + ClassDictUtil.formatClass(obj) + ") is not supported";
                }
            };
            List<?> categories = castFunction.apply(classes.get(i));
            labels.add(encodeLabel(name, categories, skLearnEncoder));
        }
        return new MultiLabel(labels);
    }

    /** Encodes a single target as a categorical label; see the four-argument overload. */
    protected DiscreteLabel encodeLabel(String str, List<?> list, SkLearnEncoder skLearnEncoder) {
        return encodeLabel(str, OpType.CATEGORICAL, list, skLearnEncoder);
    }

    /**
     * Encodes a single target as a categorical or ordinal label.
     *
     * <p>The data type is inferred from {@code list}, defaulting to {@link DataType#STRING}. An
     * un-named target ({@code str == null}) becomes a label that holds the values directly. A named
     * target gets a {@link DataField} created through {@code skLearnEncoder}. Any configured
     * {@code OPTION_CLASS_EXTENSIONS} are attached to that field.
     *
     * @param str target field name, or {@code null} for an un-named target
     * @param opType must be {@link OpType#CATEGORICAL} or {@link OpType#ORDINAL}
     * @param list the class values of the target
     * @param skLearnEncoder encoder used to create the data field of a named target
     * @return a {@link CategoricalLabel} or an {@link OrdinalLabel}, matching {@code opType}
     * @throws IllegalArgumentException if {@code opType} is neither categorical nor ordinal
     */
    public DiscreteLabel encodeLabel(String str, OpType opType, List<?> list, SkLearnEncoder skLearnEncoder) {
        // Validate first so that an unsupported op type never creates a DataField on the encoder.
        if (opType != OpType.CATEGORICAL && opType != OpType.ORDINAL) {
            throw new IllegalArgumentException("Expected a categorical or ordinal target, got " + opType);
        }
        DataType dataType = TypeUtil.getDataType(list, DataType.STRING);
        if (str == null) {
            if (opType == OpType.ORDINAL) {
                return new OrdinalLabel(dataType, list);
            }
            return new CategoricalLabel(dataType, list);
        }
        DataField dataField = skLearnEncoder.createDataField(str, opType, dataType, list);
        @SuppressWarnings("unchecked")
        Map<String, Map<String, ?>> classExtensions = (Map) getOption(HasClassifierOptions.OPTION_CLASS_EXTENSIONS, null);
        if (classExtensions != null) {
            addClassExtensions(dataField, classExtensions);
        }
        if (opType == OpType.ORDINAL) {
            return new OrdinalLabel(dataField);
        }
        return new CategoricalLabel(dataField);
    }

    /**
     * Attaches per-class extensions to the {@link Value} elements of {@code dataField}.
     *
     * <p>Each outer key is passed to an {@link AbstractExtender} (presumably as the extension
     * name). The inner map is looked up with each {@code Value}'s value. A non-null hit is decoded
     * via {@link ScalarUtil#decode}, converted to a string, and added as an extension.
     *
     * @param dataField the target field whose values are visited
     * @param map extension key to (class value to extension value); {@code null} is a no-op
     */
    private void addClassExtensions(DataField dataField, Map<String, Map<String, ?>> map) {
        if (map == null) {
            return;
        }
        for (Map.Entry<String, Map<String, ?>> entry : map.entrySet()) {
            final Map<String, ?> extensionValues = entry.getValue();
            Visitor extender = new AbstractExtender(entry.getKey()) { // from class: sklearn.Classifier.2
                public VisitorAction visit(Value pmmlValue) {
                    Object extensionValue = extensionValues.get(pmmlValue.requireValue());
                    if (extensionValue != null) {
                        addExtension(pmmlValue, ValueUtil.asString(ScalarUtil.decode(extensionValue)));
                    }
                    return super.visit(pmmlValue);
                }
            };
            extender.applyTo(dataField);
        }
    }

    /**
     * Creates the probability output fields for {@code discreteLabel} and appends them to the
     * output of the final model of {@code model}.
     *
     * @return the created output fields
     */
    public List<OutputField> encodePredictProbaOutput(Model model, DataType dataType, DiscreteLabel discreteLabel) {
        List<OutputField> probaFields = createPredictProbaFields(dataType, discreteLabel);
        ModelUtil.ensureOutput(MiningModelUtil.getFinalModel(model)).getOutputFields().addAll(probaFields);
        return probaFields;
    }

    /** Returns a new list in which every {@code Long} element is narrowed to {@code Integer}. */
    private static List<?> canonicalizeValues(List<?> list) {
        return list.stream().map(Classifier::canonicalizeValue).collect(Collectors.toList());
    }

    /**
     * Narrows a {@code Long} to an {@code Integer}; other values are returned unchanged.
     *
     * @throws ArithmeticException if the {@code Long} does not fit in an {@code int}
     */
    private static Object canonicalizeValue(Object value) {
        return value instanceof Long ? Integer.valueOf(Math.toIntExact((Long) value)) : value;
    }
}
