package sklearn.linear_model.logistic;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.python.ClassDictUtil;
import sklearn.Estimator;
import sklearn.VersionUtil;
import sklearn.linear_model.LinearClassifier;

/* loaded from: input_file:sklearn/linear_model/logistic/LogisticRegression.class */
public class LogisticRegression extends LinearClassifier {

    public LogisticRegression(String str, String str2) {
        super(str, str2);
    }

    /**
     * Encodes this logistic regression as a PMML model.
     *
     * <p>Dispatches on the {@code multi_class} attribute:
     * <ul>
     *   <li>{@code "ovr"} is delegated to {@link LinearClassifier}.</li>
     *   <li>{@code "multinomial"} is encoded as softmax-normalized regression model(s).</li>
     *   <li>{@code "auto"} is resolved to one of the above only when the recorded
     *       scikit-learn version is 0.22 or newer. Otherwise it is rejected.</li>
     * </ul>
     *
     * @param schema the schema; its label is expected to be categorical
     * @return the encoded model
     * @throws IllegalArgumentException if {@code multi_class} is unresolved {@code "auto"}
     *     or an unsupported value
     */
    @Override // sklearn.linear_model.LinearClassifier, sklearn.Estimator
    /* renamed from: encodeModel */
    public Model mo7encodeModel(Schema schema) {
        String sklearnVersion = getSkLearnVersion();
        String multiClass = getMultiClass();
        String solver = getSolver();

        boolean canResolveAuto = sklearnVersion != null && VersionUtil.compareVersion(sklearnVersion, "0.22") >= 0;
        if ("auto".equals(multiClass) && canResolveAuto) {
            multiClass = getAutoMultiClass(solver, getCoefShape());
        }

        if ("auto".equals(multiClass)) {
            throw new IllegalArgumentException("Attribute '" + ClassDictUtil.formatMember(this, "multi_class") + "' must be explicitly set to the 'ovr' or 'multinomial' value");
        }
        if ("multinomial".equals(multiClass)) {
            return encodeMultinomialModel(schema);
        }
        if ("ovr".equals(multiClass)) {
            return encodeOvRModel(schema);
        }

        throw new IllegalArgumentException("Attribute '" + ClassDictUtil.formatMember(this, "multi_class") + "' has unsupported value '" + multiClass + "'");
    }

    /**
     * Encodes a multinomial model, choosing the layout from the number of coefficient rows.
     *
     * <p>A single row denotes a binary problem. Three or more rows denote a multiclass
     * problem. Any other row count is rejected.
     */
    private Model encodeMultinomialModel(Schema schema) {
        int[] coefShape = getCoefShape();
        int numberOfRows = coefShape[0];
        int numberOfFeatures = coefShape[1];

        if (numberOfRows == 1) {
            return encodeBinaryMultinomialModel(schema, numberOfFeatures);
        }
        if (numberOfRows >= 3) {
            return encodeMulticlassMultinomialModel(schema, numberOfRows, numberOfFeatures);
        }

        throw new IllegalArgumentException("Expected 1 or at least 3 coefficient rows for a multinomial model, got " + numberOfRows);
    }

    /** Builds one regression table per class and normalizes them with softmax. */
    private Model encodeMulticlassMultinomialModel(Schema schema, int numberOfClasses, int numberOfFeatures) {
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
        SchemaUtil.checkSize(numberOfClasses, categoricalLabel);

        List<? extends Number> coef = getCoef();
        List<? extends Number> intercept = getIntercept();

        List<RegressionTable> regressionTables = new ArrayList<>(numberOfClasses);
        for (int row = 0; row < numberOfClasses; row++) {
            RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(schema.getFeatures(), CMatrixUtil.getRow(coef, numberOfClasses, numberOfFeatures, row), intercept.get(row))
                .setTargetCategory(categoricalLabel.getValue(row));

            regressionTables.add(regressionTable);
        }

        RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
            .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);

        encodePredictProbaOutput(regressionModel, DataType.DOUBLE, categoricalLabel);

        return regressionModel;
    }

    /**
     * Encodes a binary multinomial model.
     *
     * <p>For scikit-learn 0.20 and newer (version known), this is a two-stage model chain.
     * The first stage computes the decision value {@code d}. The second stage applies
     * softmax to {@code [-d, +d]} for the first and second class respectively. For an
     * older or unknown version, it falls back to the one-vs-rest encoding.
     */
    private Model encodeBinaryMultinomialModel(Schema schema, int numberOfFeatures) {
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
        SchemaUtil.checkSize(2, categoricalLabel);

        String sklearnVersion = getSkLearnVersion();
        if (sklearnVersion == null || VersionUtil.compareVersion(sklearnVersion, "0.20") < 0) {
            return encodeOvRModel(schema);
        }

        List<? extends Number> coef = getCoef();
        List<? extends Number> intercept = getIntercept();

        // Stage 1: raw decision function, exposed as a continuous output field.
        Model decisionFunctionModel = RegressionModelUtil.createRegression(schema.getFeatures(), CMatrixUtil.getRow(coef, 1, numberOfFeatures, 0), intercept.get(0), (RegressionModel.NormalizationMethod) null, schema.toRelabeledSchema((Label) null))
            .setOutput(ModelUtil.createPredictedOutput(Estimator.FIELD_DECISION_FUNCTION, OpType.CONTINUOUS, DataType.DOUBLE, new Transformation[0]));

        // Stage 2: softmax over (-d, +d), reading d back through the output field above.
        PMMLEncoder encoder = schema.getEncoder();
        ContinuousFeature decisionFeature = new ContinuousFeature(encoder, Estimator.FIELD_DECISION_FUNCTION, DataType.DOUBLE);

        List<RegressionTable> regressionTables = new ArrayList<>(2);
        regressionTables.add(
            RegressionModelUtil.createRegressionTable(Collections.singletonList(decisionFeature), Collections.singletonList(Double.valueOf(-1.0d)), Double.valueOf(0.0d))
                .setTargetCategory(categoricalLabel.getValue(0))
        );
        regressionTables.add(
            RegressionModelUtil.createRegressionTable(Collections.singletonList(decisionFeature), Collections.singletonList(Double.valueOf(1.0d)), Double.valueOf(0.0d))
                .setTargetCategory(categoricalLabel.getValue(1))
        );

        RegressionModel softmaxModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
            .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);

        MiningModel miningModel = MiningModelUtil.createModelChain(Arrays.asList(decisionFunctionModel, softmaxModel), Segmentation.MissingPredictionTreatment.RETURN_MISSING);

        encodePredictProbaOutput(miningModel, DataType.DOUBLE, categoricalLabel);

        return miningModel;
    }

    /** Encodes the one-vs-rest layout by delegating to {@link LinearClassifier}. */
    private Model encodeOvRModel(Schema schema) {
        return super.mo7encodeModel(schema);
    }

    /**
     * Returns the {@code multi_class} attribute.
     *
     * <p>The legacy {@code "warn"} value is mapped to {@code "ovr"}.
     */
    public String getMultiClass() {
        String multiClass = getString("multi_class");

        return "warn".equals(multiClass) ? "ovr" : multiClass;
    }

    /** Returns the {@code solver} attribute. */
    public String getSolver() {
        return getString("solver");
    }

    /**
     * Resolves {@code multi_class="auto"} to a concrete strategy.
     *
     * <p>The rule is {@code "ovr"} when the solver is {@code "liblinear"} or the coefficient
     * matrix has a single row (binary problem), and {@code "multinomial"} otherwise. This
     * is intended to mirror scikit-learn's documented {@code "auto"} rule.
     */
    private static String getAutoMultiClass(String solver, int[] coefShape) {
        boolean binary = coefShape[0] == 1;

        return ("liblinear".equals(solver) || binary) ? "ovr" : "multinomial";
    }
}
