package org.jpmml.rexp;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.FortranMatrixUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.model.ValueUtil;

/* loaded from: input_file:org/jpmml/rexp/MultNetConverter.class */
/**
 * Converts a multinomial {@code glmnet} model to a PMML classification {@link RegressionModel}.
 *
 * <p>A two-category target is encoded as a single logistic regression. Three or more categories
 * are encoded as a softmax-normalized regression model with one regression table per category.
 *
 * <p>NOTE(review): presumably this converter is registered for R objects of class
 * {@code multnet}; confirm against the converter registry.
 */
public class MultNetConverter extends GLMNetConverter {

    /**
     * Multiplies a value by two.
     *
     * <p>Applied to the coefficients and the intercept in the two-category case, where only the
     * second category's parameters are encoded. Presumably this is because glmnet's symmetric
     * multinomial parameterization yields beta_1 = -beta_2 for two classes. The log-odds would
     * then be beta_2 - beta_1 = 2 * beta_2. NOTE(review): confirm.
     *
     * <p>Kept as a stateless constant rather than an anonymous class created on every call.
     */
    private static final Function<Double, Double> TIMES_TWO = new Function<Double, Double>() {

        @Override
        public Double apply(Double value) {
            return 2d * value;
        }
    };

    public MultNetConverter(RGenericVector rGenericVector) {
        super(rGenericVector);
    }

    /**
     * Encodes one column of the fitted multinomial model as a PMML model.
     *
     * @param rDoubleVector intercept matrix carrying a {@code dim} attribute. It is read through
     *     {@link FortranMatrixUtil}, with rows indexed by target category position and columns
     *     by {@code i}. Presumably this is glmnet's {@code a0} (classes x lambdas); confirm.
     * @param rExp generic vector of per-category coefficient objects ({@link S4Object}). These
     *     are looked up by position in the two-category case and by category name otherwise.
     * @param i column index to encode (presumably the selected lambda)
     * @param schema schema whose label must be a {@link CategoricalLabel}
     * @return a logistic regression model for two categories, or a softmax regression model with
     *     one table per category for three or more
     * @throws IllegalArgumentException if the label has fewer than two categories
     */
    @Override
    public Model encodeModel(RDoubleVector rDoubleVector, RExp rExp, int i, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        RIntegerVector dim = rDoubleVector.dim();
        int rows = dim.getValue(0).intValue();
        int columns = dim.getValue(1).intValue();

        RGenericVector classBetas = (RGenericVector) rExp;

        int size = categoricalLabel.size();
        if (size == 2) {
            // Only the second category's intercept and coefficients are used, both scaled by two.
            List<Double> intercepts = FortranMatrixUtil.getRow(rDoubleVector.getValues(), rows, columns, 1);
            S4Object classBeta = (S4Object) classBetas.getValue(1);

            return RegressionModelUtil.createBinaryLogisticClassification(
                schema.getFeatures(),
                Lists.transform(getCoefficients(classBeta, i), TIMES_TWO),
                TIMES_TWO.apply(intercepts.get(i)),
                RegressionModel.NormalizationMethod.LOGIT,
                true,
                schema);
        } else if (size > 2) {
            // Hoisted: the intercept matrix values do not change between categories.
            List<Double> values = rDoubleVector.getValues();

            RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), (List<RegressionTable>) null)
                .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX)
                .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));

            // One regression table per category. The intercept comes from the category's row of
            // the matrix; the coefficients come from the element named after the category.
            for (int row = 0; row < size; row++) {
                Object value = categoricalLabel.getValue(row);

                List<Double> intercepts = FortranMatrixUtil.getRow(values, rows, columns, row);
                S4Object classBeta = (S4Object) classBetas.getElement(ValueUtil.toString(value));

                RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(schema.getFeatures(), getCoefficients(classBeta, i), intercepts.get(i))
                    .setTargetCategory(value);

                regressionModel.addRegressionTables(regressionTable);
            }

            return regressionModel;
        } else {
            throw new IllegalArgumentException("Expected 2 or more target categories, got " + size);
        }
    }
}
