package edu.columbia.tjw.item.spark;

import edu.columbia.tjw.item.ItemParameters;
import edu.columbia.tjw.item.ItemRegressor;
import edu.columbia.tjw.item.ItemStatus;
import edu.columbia.tjw.item.base.StandardCurveType;
import edu.columbia.tjw.item.fit.ItemFitter;
import edu.columbia.tjw.item.optimize.ConvergenceException;
import edu.columbia.tjw.item.util.random.RandomTool;
import java.util.Objects;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.classification.ProbabilisticClassifier;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.sql.Dataset;

/* loaded from: input_file:edu/columbia/tjw/item/spark/ItemClassifier.class */
/**
 * Spark ML estimator that trains an {@link ItemClassificationModel} with an {@link ItemFitter}.
 *
 * <p>Training wraps the input dataset (using this estimator's label and features columns) in a
 * {@code SparkGridAdapter}. It optionally seeds the fitter with starting parameters, then runs up
 * to {@value #MAX_FIT_ROUNDS} rounds of fit / add-coefficients / expand / calibrate / trim /
 * anneal. Training stops early once the parameter budget from the settings is nearly used up.
 *
 * @param <S> the status type of the model
 * @param <R> the regressor type of the model
 */
public class ItemClassifier<S extends ItemStatus<S>, R extends ItemRegressor<R>>
        extends ProbabilisticClassifier<Vector, ItemClassifier<S, R>, ItemClassificationModel<S, R>>
        implements Cloneable {
    private static final long serialVersionUID = 8990051165227355116L;

    /** Maximum number of fit/expand rounds performed by {@link #train(Dataset)}. */
    private static final int MAX_FIT_ROUNDS = 3;

    /**
     * Once coefficients have been added, training stops (without curve expansion) when no more
     * than this many parameters remain in the budget.
     */
    private static final int MIN_PARAMS_FOR_EXPANSION = 3;

    /** Length of the randomly generated identifier returned by {@link #uid()}. */
    private static final int UID_LENGTH = 64;

    private final ItemClassifierSettings<S, R> _settings;

    /** Optional starting point; when non-null it is pushed into the fitter before training. */
    private final ItemParameters<S, R, StandardCurveType> _startingParams;

    /** Lazily generated identifier; see {@link #uid()}. */
    private String _uid;

    /**
     * Creates a classifier with no starting parameters.
     *
     * @param itemClassifierSettings fitting settings; must not be null
     * @throws NullPointerException if {@code itemClassifierSettings} is null
     */
    public ItemClassifier(ItemClassifierSettings<S, R> itemClassifierSettings) {
        this(itemClassifierSettings, null);
    }

    /**
     * Creates a classifier, optionally seeded with starting parameters.
     *
     * @param itemClassifierSettings fitting settings; must not be null
     * @param itemParameters         starting parameters, or null to start without them
     * @throws NullPointerException if {@code itemClassifierSettings} is null
     */
    public ItemClassifier(ItemClassifierSettings<S, R> itemClassifierSettings,
            ItemParameters<S, R, StandardCurveType> itemParameters) {
        this._settings = Objects.requireNonNull(itemClassifierSettings, "Settings cannot be null.");
        this._startingParams = itemParameters;
    }

    /**
     * Implements Spark's {@code copy(ParamMap)} contract by delegating to {@code defaultCopy}.
     *
     * @param paramMap extra parameter values to apply to the copy
     * @return the copied classifier
     */
    public ItemClassifier<S, R> copy(ParamMap paramMap) {
        // NOTE(review): Spark's Params.defaultCopy builds the copy reflectively through a
        // single-String (uid) constructor, which this class does not declare. Confirm copy()
        // actually works at runtime (e.g. when used inside a Pipeline or CrossValidator).
        return defaultCopy(paramMap);
    }

    /**
     * Alias of {@link #copy(ParamMap)} kept for callers of the decompiler-generated name.
     *
     * @param paramMap extra parameter values to apply to the copy
     * @return the copied classifier
     * @deprecated use {@link #copy(ParamMap)}
     */
    @Deprecated
    public ItemClassifier<S, R> m32copy(ParamMap paramMap) {
        return copy(paramMap);
    }

    /**
     * Fits an ITEM model to {@code dataset}.
     *
     * @param dataset training data containing this estimator's label and features columns
     * @return a model built from the fitter's best parameters and the settings' regressors
     * @throws RuntimeException wrapping any {@link ConvergenceException} raised while fitting
     */
    public ItemClassificationModel<S, R> train(Dataset<?> dataset) {
        final ItemFitter itemFitter = createFitter(dataset);
        final int maxParamCount = this._settings.getMaxParamCount();

        for (int round = 0; round < MAX_FIT_ROUNDS; round++) {
            final boolean keepGoing;
            try {
                keepGoing = runFitRound(itemFitter, maxParamCount);
            } catch (ConvergenceException e) {
                throw new RuntimeException("ITEM fitting failed in round " + round + ".", e);
            }
            if (!keepGoing) {
                break;
            }
        }
        return new ItemClassificationModel<>(itemFitter.getBestParameters(), this._settings.getRegressors());
    }

    /** Builds the fitter over {@code dataset} and pushes the starting parameters, if any. */
    private ItemFitter createFitter(Dataset<?> dataset) {
        // Constructor arguments are evaluated in the same order as before the refactor.
        final ItemFitter itemFitter = new ItemFitter(this._settings.getFactory(), this._settings.getIntercept(),
                this._settings.getFromStatus(),
                new SparkGridAdapter(dataset, getLabelCol(), getFeaturesCol(), this._settings.getRegressors(),
                        this._settings.getFromStatus(), this._settings.getIntercept()));

        if (null != this._startingParams) {
            try {
                itemFitter.pushParameters("InitialParams", this._startingParams);
            } catch (ConvergenceException e) {
                throw new RuntimeException("Unable to push the starting parameters into the fitter.", e);
            }
        }
        return itemFitter;
    }

    /**
     * Runs a single fit/expand round.
     *
     * @return false when the parameter budget is (nearly) exhausted and training should stop
     */
    private boolean runFitRound(ItemFitter itemFitter, int maxParamCount) throws ConvergenceException {
        if (maxParamCount - itemFitter.getBestParameters().getEffectiveParamCount() < 1) {
            return false;
        }
        itemFitter.fitCoefficients(null);
        itemFitter.addCoefficients(null, this._settings.getNonCurveRegressors());

        // Budget left after adding coefficients; it also caps how far the model may expand.
        final int remainingParams = maxParamCount - itemFitter.getBestParameters().getEffectiveParamCount();
        if (remainingParams <= MIN_PARAMS_FOR_EXPANSION) {
            return false;
        }
        itemFitter.expandModel(this._settings.getCurveRegressors(), null, remainingParams);
        itemFitter.calibrateCurves();
        itemFitter.fitCoefficients(null);
        itemFitter.trim(true);
        itemFitter.runAnnealingByEntry(this._settings.getCurveRegressors(), false);
        return true;
    }

    /**
     * Returns this estimator's identifier, generating a random one on first call.
     * Synchronized so concurrent first calls agree on a single value.
     *
     * @return the identifier
     */
    public synchronized String uid() {
        // NOTE(review): presumably lazy (rather than a field initializer) because Spark's param
        // traits may call uid() during superclass construction; confirm before making it eager.
        if (null == this._uid) {
            this._uid = RandomTool.randomString(UID_LENGTH);
        }
        return this._uid;
    }

    /**
     * Raw-typed variant of {@link #train(Dataset)} emitted by the decompiler for the generic bridge.
     *
     * @param dataset training data
     * @return the trained model
     * @deprecated use {@link #train(Dataset)}
     */
    @Deprecated
    @SuppressWarnings({"rawtypes", "unchecked"})
    public PredictionModel m29train(Dataset dataset) {
        return train((Dataset<?>) dataset);
    }
}
