package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.sparkml.MatrixUtil;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.sparkml.VectorUtil;

/* loaded from: input_file:org/jpmml/sparkml/model/LinearModelUtil.class */
public class LinearModelUtil {

    /**
     * Builds a PMML {@link RegressionModel} from a linear model's coefficient vector and intercept.
     *
     * <p>The schema's features and the coefficients are copied into fresh lists before being handed
     * to {@code RegressionTableUtil.simplify}, so neither the schema nor the Spark vector is touched.
     *
     * @param converter the converter; its regression-table options are consulted by
     *     {@code RegressionTableUtil.simplify}
     * @param coefficients coefficient vector, paired positionally with {@code schema.getFeatures()}
     * @param intercept the model intercept
     * @param schema the feature/label schema of the model
     * @return a regression model with normalization method {@code NONE}
     */
    public static <C extends ModelConverter<?> & HasRegressionTableOptions> RegressionModel createRegression(
            C converter, Vector coefficients, double intercept, Schema schema) {
        // Mutable copies: simplify(...) operates on these lists, which are then used to build the model.
        List<Feature> features = new ArrayList<>(schema.getFeatures());
        List<Double> featureCoefficients = new ArrayList<>(VectorUtil.toList(coefficients));

        RegressionTableUtil.simplify(converter, null, features, featureCoefficients);

        return RegressionModelUtil.createRegression(
                features, featureCoefficients, intercept, RegressionModel.NormalizationMethod.NONE, schema);
    }

    /**
     * Builds a PMML binary logistic classification {@link RegressionModel} from a linear model's
     * coefficient vector and intercept, using {@code LOGIT} normalization.
     *
     * @param converter the converter; its regression-table options are consulted by
     *     {@code RegressionTableUtil.simplify}
     * @param coefficients coefficient vector, paired positionally with {@code schema.getFeatures()}
     * @param intercept the model intercept
     * @param schema the feature/label schema of the model
     * @return a classification regression model with normalization method {@code LOGIT}
     */
    public static <C extends ModelConverter<?> & HasRegressionTableOptions> RegressionModel createBinaryLogisticClassification(
            C converter, Vector coefficients, double intercept, Schema schema) {
        List<Feature> features = new ArrayList<>(schema.getFeatures());
        List<Double> featureCoefficients = new ArrayList<>(VectorUtil.toList(coefficients));

        RegressionTableUtil.simplify(converter, null, features, featureCoefficients);

        // NOTE(review): the boolean argument is passed as true, as before; presumably it enables
        // probability output in RegressionModelUtil — confirm against that helper's signature.
        return RegressionModelUtil.createBinaryLogisticClassification(
                features, featureCoefficients, intercept, RegressionModel.NormalizationMethod.LOGIT, true, schema);
    }

    /**
     * Builds a PMML multinomial (softmax) classification {@link RegressionModel}.
     *
     * <p>One regression table is produced per target category: row {@code i} of the coefficient
     * matrix and entry {@code i} of the intercept vector belong to the category at index {@code i}
     * of the categorical label.
     *
     * @param converter the converter; its regression-table options are consulted by
     *     {@code RegressionTableUtil.simplify} for each category
     * @param coefficients coefficient matrix with one row per target category
     * @param intercepts intercept vector with one entry per target category
     * @param schema the feature/label schema; its label must be a {@link CategoricalLabel}
     * @return a classification regression model with normalization method {@code SOFTMAX}
     * @throws ClassCastException if the schema label is not a {@link CategoricalLabel}
     */
    public static <C extends ModelConverter<?> & HasRegressionTableOptions> RegressionModel createSoftmaxClassification(
            C converter, Matrix coefficients, Vector intercepts, Schema schema) {
        // Schema.getLabel() is declared with the Label supertype; the cast is required to compile.
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        // Fail fast if the matrix does not have exactly one row per target category.
        MatrixUtil.checkRows(categoricalLabel.size(), coefficients);

        List<RegressionTable> regressionTables = new ArrayList<>(categoricalLabel.size());
        for (int i = 0; i < categoricalLabel.size(); i++) {
            Object targetCategory = categoricalLabel.getValue(i);

            // Fresh copies per category, since simplify(...) is applied independently to each table.
            List<Feature> features = new ArrayList<>(schema.getFeatures());
            List<Double> rowCoefficients = new ArrayList<>(MatrixUtil.getRow(coefficients, i));

            RegressionTableUtil.simplify(converter, targetCategory, features, rowCoefficients);

            RegressionTable regressionTable = RegressionModelUtil
                    .createRegressionTable(features, rowCoefficients, intercepts.apply(i))
                    .setTargetCategory(targetCategory);
            regressionTables.add(regressionTable);
        }

        return new RegressionModel(
                        MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
                .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);
    }
}
