/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.regression;

import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.math.Field;
import breeze.math.MutableInnerProductModule;
import breeze.optimize.CachedDiffFunction;
import breeze.optimize.DiffFunction;
import breeze.optimize.FirstOrderMinimizer;
import breeze.optimize.LBFGS;
import breeze.optimize.StochasticDiffFunction;
import org.apache.spark.Logging;
import org.apache.spark.SparkException;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.DoubleArrayParam;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.shared.HasFeaturesCol$class;
import org.apache.spark.ml.param.shared.HasFitIntercept$class;
import org.apache.spark.ml.param.shared.HasLabelCol$class;
import org.apache.spark.ml.param.shared.HasMaxIter$class;
import org.apache.spark.ml.param.shared.HasPredictionCol$class;
import org.apache.spark.ml.param.shared.HasTol$class;
import org.apache.spark.ml.regression.AFTCostFun;
import org.apache.spark.ml.regression.AFTPoint;
import org.apache.spark.ml.regression.AFTSurvivalRegressionModel;
import org.apache.spark.ml.regression.AFTSurvivalRegressionParams;
import org.apache.spark.ml.regression.AFTSurvivalRegressionParams$class;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.StringContext;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.SeqLike;
import scala.collection.mutable.ArrayBuilder;
import scala.collection.mutable.ArrayBuilder$;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

@Experimental
@ScalaSignature(bytes="\u0006\u0001\u0005\rd\u0001B\u0001\u0003\u00015\u0011Q#\u0011$U'V\u0014h/\u001b<bYJ+wM]3tg&|gN\u0003\u0002\u0004\t\u0005Q!/Z4sKN\u001c\u0018n\u001c8\u000b\u0005\u00151\u0011AA7m\u0015\t9\u0001\"A\u0003ta\u0006\u00148N\u0003\u0002\n\u0015\u00051\u0011\r]1dQ\u0016T\u0011aC\u0001\u0004_J<7\u0001A\n\u0005\u000191\u0012\u0004E\u0002\u0010!Ii\u0011\u0001B\u0005\u0003#\u0011\u0011\u0011\"R:uS6\fGo\u001c:\u0011\u0005M!R\"\u0001\u0002\n\u0005U\u0011!AG!G)N+(O^5wC2\u0014Vm\u001a:fgNLwN\\'pI\u0016d\u0007CA\n\u0018\u0013\tA\"AA\u000eB\rR\u001bVO\u001d<jm\u0006d'+Z4sKN\u001c\u0018n\u001c8QCJ\fWn\u001d\t\u00035mi\u0011AB\u0005\u00039\u0019\u0011q\u0001T8hO&tw\r\u0003\u0005\u001f\u0001\t\u0015\r\u0011\"\u0011 \u0003\r)\u0018\u000eZ\u000b\u0002AA\u0011\u0011e\n\b\u0003E\u0015j\u0011a\t\u0006\u0002I\u0005)1oY1mC&\u0011aeI\u0001\u0007!J,G-\u001a4\n\u0005!J#AB*ue&twM\u0003\u0002'G!\u001aQdK\u0019\u0011\u00051zS\"A\u0017\u000b\u000592\u0011AC1o]>$\u0018\r^5p]&\u0011\u0001'\f\u0002\u0006'&t7-Z\u0011\u0002e\u0005)\u0011G\f\u001c/a!AA\u0007\u0001B\u0001B\u0003%\u0001%\u0001\u0003vS\u0012\u0004\u0003fA\u001a,c!)q\u0007\u0001C\u0001q\u00051A(\u001b8jiz\"\"!\u000f\u001e\u0011\u0005M\u0001\u0001\"\u0002\u00107\u0001\u0004\u0001\u0003f\u0001\u001e,c!\u001aagK\u0019\t\u000b]\u0002A\u0011\u0001 
\u0015\u0003eB3!P\u00162\u0011\u0015\t\u0005\u0001\"\u0001C\u00039\u0019X\r\u001e$fCR,(/Z:D_2$\"a\u0011#\u000e\u0003\u0001AQ!\u0012!A\u0002\u0001\nQA^1mk\u0016D3\u0001Q\u00162\u0011\u0015A\u0005\u0001\"\u0001J\u0003-\u0019X\r\u001e'bE\u0016d7i\u001c7\u0015\u0005\rS\u0005\"B#H\u0001\u0004\u0001\u0003fA$,c!)Q\n\u0001C\u0001\u001d\u0006a1/\u001a;DK:\u001cxN]\"pYR\u00111i\u0014\u0005\u0006\u000b2\u0003\r\u0001\t\u0015\u0004\u0019.\n\u0004\"\u0002*\u0001\t\u0003\u0019\u0016\u0001E:fiB\u0013X\rZ5di&|gnQ8m)\t\u0019E\u000bC\u0003F#\u0002\u0007\u0001\u0005K\u0002RWEBQa\u0016\u0001\u0005\u0002a\u000b\u0001d]3u#V\fg\u000e^5mKB\u0013xNY1cS2LG/[3t)\t\u0019\u0015\fC\u0003F-\u0002\u0007!\fE\u0002#7vK!\u0001X\u0012\u0003\u000b\u0005\u0013(/Y=\u0011\u0005\tr\u0016BA0$\u0005\u0019!u.\u001e2mK\"\u001aakK\u0019\t\u000b\t\u0004A\u0011A2\u0002\u001fM,G/U;b]RLG.Z:D_2$\"a\u00113\t\u000b\u0015\u000b\u0007\u0019\u0001\u0011)\u0007\u0005\\\u0013\u0007C\u0003h\u0001\u0011\u0005\u0001.A\btKR4\u0015\u000e^%oi\u0016\u00148-\u001a9u)\t\u0019\u0015\u000eC\u0003FM\u0002\u0007!\u000e\u0005\u0002#W&\u0011An\t\u0002\b\u0005>|G.Z1oQ\r17&\r\u0005\u0006_\u0002!\t\u0001]\u0001\u000bg\u0016$X*\u0019=Ji\u0016\u0014HCA\"r\u0011\u0015)e\u000e1\u0001s!\t\u00113/\u0003\u0002uG\t\u0019\u0011J\u001c;)\u00079\\\u0013\u0007C\u0003x\u0001\u0011\u0005\u00010\u0001\u0004tKR$v\u000e\u001c\u000b\u0003\u0007fDQ!\u0012<A\u0002uC3A^\u00162\u0011\u0019a\b\u0001\"\u0005\u0005{\u0006\u0001R\r\u001f;sC\u000e$\u0018I\u0012+Q_&tGo\u001d\u000b\u0004}\u0006=\u0001#B@\u0002\u0006\u0005%QBAA\u0001\u0015\r\t\u0019AB\u0001\u0004e\u0012$\u0017\u0002BA\u0004\u0003\u0003\u00111A\u0015#E!\r\u0019\u00121B\u0005\u0004\u0003\u001b\u0011!\u0001C!G)B{\u0017N\u001c;\t\u000f\u0005E1\u00101\u0001\u0002\u0014\u00059A-\u0019;bg\u0016$\b\u0003BA\u000b\u00037i!!a\u0006\u000b\u0007\u0005ea!A\u0002tc2LA!!\b\u0002\u0018\tIA)\u0019;b\rJ\fW.\u001a\u0005\b\u0003C\u0001A\u0011IA\u0012\u0003\r1\u0017\u000e\u001e\u000b\u0004%\u0005\u0015\u0002\u0002CA\t\u0003?
\u0001\r!a\u0005)\t\u0005}1&\r\u0005\b\u0003W\u0001A\u0011IA\u0017\u0003=!(/\u00198tM>\u0014XnU2iK6\fG\u0003BA\u0018\u0003w\u0001B!!\r\u000285\u0011\u00111\u0007\u0006\u0005\u0003k\t9\"A\u0003usB,7/\u0003\u0003\u0002:\u0005M\"AC*ueV\u001cG\u000fV=qK\"A\u0011QHA\u0015\u0001\u0004\ty#\u0001\u0004tG\",W.\u0019\u0015\u0005\u0003SY\u0013\u0007C\u0004\u0002D\u0001!\t%!\u0012\u0002\t\r|\u0007/\u001f\u000b\u0004s\u0005\u001d\u0003\u0002CA%\u0003\u0003\u0002\r!a\u0013\u0002\u000b\u0015DHO]1\u0011\t\u00055\u00131K\u0007\u0003\u0003\u001fR1!!\u0015\u0005\u0003\u0015\u0001\u0018M]1n\u0013\u0011\t)&a\u0014\u0003\u0011A\u000b'/Y7NCBDC!!\u0011,c!\u001a\u0001aK\u0019)\u0007\u0001\ti\u0006E\u0002-\u0003?J1!!\u0019.\u00051)\u0005\u0010]3sS6,g\u000e^1m\u0001")
/*
 * Estimator that fits an AFTSurvivalRegressionModel by minimizing AFTCostFun with breeze's
 * LBFGS optimizer.
 *
 * Training reads the features, label and censor columns (see extractAFTPoints) and produces a
 * model holding a coefficient vector, an intercept and a strictly positive scale (the scale is
 * exp of an optimized parameter). Defaults set in the constructor: fitIntercept = true,
 * maxIter = 100, tol = 1e-6.
 *
 * NOTE(review): this is CFR-decompiled Scala. The `$class` delegations, the
 * `_setter_$..._$eq` methods and the `$init$` calls mirror the compiled Scala-trait bytecode
 * rather than hand-written Java.
 */
public class AFTSurvivalRegression
extends Estimator<AFTSurvivalRegressionModel>
implements AFTSurvivalRegressionParams,
Logging {
    private final String uid;
    // The param fields below are assigned from the `..._setter_$..._$eq` methods (not directly
    // in a constructor body). Java forbids that for `final` fields, so they are non-final.
    private Param<String> censorCol;
    private DoubleArrayParam quantileProbabilities;
    private Param<String> quantilesCol;
    private BooleanParam fitIntercept;
    private DoubleParam tol;
    private IntParam maxIter;
    private Param<String> predictionCol;
    private Param<String> labelCol;
    private Param<String> featuresCol;

    // ---- AFTSurvivalRegressionParams: param accessors, trait setters and delegating getters ----

    @Override
    public final Param<String> censorCol() {
        return this.censorCol;
    }

    @Override
    public final DoubleArrayParam quantileProbabilities() {
        return this.quantileProbabilities;
    }

    @Override
    public final Param<String> quantilesCol() {
        return this.quantilesCol;
    }

    @Override
    public final void org$apache$spark$ml$regression$AFTSurvivalRegressionParams$_setter_$censorCol_$eq(Param x$1) {
        this.censorCol = x$1;
    }

    @Override
    public final void org$apache$spark$ml$regression$AFTSurvivalRegressionParams$_setter_$quantileProbabilities_$eq(DoubleArrayParam x$1) {
        this.quantileProbabilities = x$1;
    }

    @Override
    public final void org$apache$spark$ml$regression$AFTSurvivalRegressionParams$_setter_$quantilesCol_$eq(Param x$1) {
        this.quantilesCol = x$1;
    }

    @Override
    public String getCensorCol() {
        return AFTSurvivalRegressionParams$class.getCensorCol(this);
    }

    @Override
    public double[] getQuantileProbabilities() {
        return AFTSurvivalRegressionParams$class.getQuantileProbabilities(this);
    }

    @Override
    public String getQuantilesCol() {
        return AFTSurvivalRegressionParams$class.getQuantilesCol(this);
    }

    @Override
    public boolean hasQuantilesCol() {
        return AFTSurvivalRegressionParams$class.hasQuantilesCol(this);
    }

    /**
     * Validates {@code schema} and returns the transformed schema; delegates to the
     * AFTSurvivalRegressionParams trait implementation.
     *
     * @param schema  input schema to check
     * @param fitting true when called on the training path (fit / transformSchema here)
     * @return the schema produced by the trait implementation
     */
    @Override
    public StructType validateAndTransformSchema(StructType schema, boolean fitting) {
        return AFTSurvivalRegressionParams$class.validateAndTransformSchema(this, schema, fitting);
    }

    // ---- Shared params (fitIntercept, tol, maxIter, predictionCol, labelCol, featuresCol) ----

    @Override
    public final BooleanParam fitIntercept() {
        return this.fitIntercept;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasFitIntercept$_setter_$fitIntercept_$eq(BooleanParam x$1) {
        this.fitIntercept = x$1;
    }

    @Override
    public final boolean getFitIntercept() {
        return HasFitIntercept$class.getFitIntercept(this);
    }

    @Override
    public final DoubleParam tol() {
        return this.tol;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(DoubleParam x$1) {
        this.tol = x$1;
    }

    @Override
    public final double getTol() {
        return HasTol$class.getTol(this);
    }

    @Override
    public final IntParam maxIter() {
        return this.maxIter;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(IntParam x$1) {
        this.maxIter = x$1;
    }

    @Override
    public final int getMaxIter() {
        return HasMaxIter$class.getMaxIter(this);
    }

    @Override
    public final Param<String> predictionCol() {
        return this.predictionCol;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasPredictionCol$_setter_$predictionCol_$eq(Param x$1) {
        this.predictionCol = x$1;
    }

    @Override
    public final String getPredictionCol() {
        return HasPredictionCol$class.getPredictionCol(this);
    }

    @Override
    public final Param<String> labelCol() {
        return this.labelCol;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasLabelCol$_setter_$labelCol_$eq(Param x$1) {
        this.labelCol = x$1;
    }

    @Override
    public final String getLabelCol() {
        return HasLabelCol$class.getLabelCol(this);
    }

    @Override
    public final Param<String> featuresCol() {
        return this.featuresCol;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasFeaturesCol$_setter_$featuresCol_$eq(Param x$1) {
        this.featuresCol = x$1;
    }

    @Override
    public final String getFeaturesCol() {
        return HasFeaturesCol$class.getFeaturesCol(this);
    }

    /** Returns the identifier given at construction. */
    @Override
    public String uid() {
        return this.uid;
    }

    // ---- Fluent setters: each calls set(param, value) and returns the result for chaining ----

    public AFTSurvivalRegression setFeaturesCol(String value) {
        return (AFTSurvivalRegression)this.set(this.featuresCol(), value);
    }

    public AFTSurvivalRegression setLabelCol(String value) {
        return (AFTSurvivalRegression)this.set(this.labelCol(), value);
    }

    public AFTSurvivalRegression setCensorCol(String value) {
        return (AFTSurvivalRegression)this.set(this.censorCol(), value);
    }

    public AFTSurvivalRegression setPredictionCol(String value) {
        return (AFTSurvivalRegression)this.set(this.predictionCol(), value);
    }

    public AFTSurvivalRegression setQuantileProbabilities(double[] value) {
        return (AFTSurvivalRegression)this.set(this.quantileProbabilities(), value);
    }

    public AFTSurvivalRegression setQuantilesCol(String value) {
        return (AFTSurvivalRegression)this.set(this.quantilesCol(), value);
    }

    public AFTSurvivalRegression setFitIntercept(boolean value) {
        return (AFTSurvivalRegression)this.set(this.fitIntercept(), BoxesRunTime.boxToBoolean((boolean)value));
    }

    public AFTSurvivalRegression setMaxIter(int value) {
        return (AFTSurvivalRegression)this.set(this.maxIter(), BoxesRunTime.boxToInteger((int)value));
    }

    public AFTSurvivalRegression setTol(double value) {
        return (AFTSurvivalRegression)this.set(this.tol(), BoxesRunTime.boxToDouble((double)value));
    }

    /**
     * Converts {@code dataset} into an RDD of AFTPoint training instances.
     *
     * <p>Selects the features, label and censor columns (in that order) and maps each row to
     * {@code new AFTPoint(features, label, censor)}. The mapping is lazy. A row whose values
     * are not (Vector, Double, Double) throws {@code scala.MatchError} when the RDD is
     * evaluated.
     *
     * <p>NOTE(review): the {@code new Serializable(this){...}} below appears to be the
     * decompiled form of a Scala closure and is not valid hand-written Java as-is.
     *
     * @param dataset input DataFrame containing the three columns
     * @return lazily mapped RDD of training points
     */
    public RDD<AFTPoint> extractAFTPoints(DataFrame dataset) {
        return dataset.select(this.$(this.featuresCol()), (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{this.$(this.labelCol()), this.$(this.censorCol())})).map((Function1)new Serializable(this){
            public static final long serialVersionUID = 0L;

            public final AFTPoint apply(Row x0$1) {
                Row row = x0$1;
                Some some = Row$.MODULE$.unapplySeq(row);
                if (!some.isEmpty() && some.get() != null && ((SeqLike)some.get()).lengthCompare(3) == 0) {
                    Object features = ((SeqLike)some.get()).apply(0);
                    Object label = ((SeqLike)some.get()).apply(1);
                    Object censor = ((SeqLike)some.get()).apply(2);
                    if (features instanceof Vector) {
                        Vector vector = (Vector)features;
                        if (label instanceof Double) {
                            double d = BoxesRunTime.unboxToDouble((Object)label);
                            if (censor instanceof Double) {
                                double d2 = BoxesRunTime.unboxToDouble((Object)censor);
                                AFTPoint aFTPoint = new AFTPoint(vector, d, d2);
                                return aFTPoint;
                            }
                        }
                    }
                }
                throw new MatchError((Object)row);
            }
        }, ClassTag$.MODULE$.apply(AFTPoint.class));
    }

    /**
     * Fits an AFTSurvivalRegressionModel on {@code dataset}.
     *
     * <p>The extracted training points are persisted with MEMORY_AND_DISK only when the
     * input's RDD reports StorageLevel.NONE. They are unpersisted in a {@code finally} block,
     * so the cache is released even when fitting fails.
     *
     * @param dataset input with the features, label and censor columns
     * @return the fitted model, with its parent set to this estimator and this estimator's
     *     param values copied via copyValues
     * @throws SparkException if {@code dataset} has no rows or the optimizer yields no state
     */
    @Override
    public AFTSurvivalRegressionModel fit(DataFrame dataset) {
        this.validateAndTransformSchema(dataset.schema(), true);
        RDD<AFTPoint> instances = this.extractAFTPoints(dataset);
        // Cache the training points only if the input is not already cached.
        // NOTE(review): if DataFrame.rdd() builds a fresh RDD per call, this level may always be
        // NONE even for a cached DataFrame -- confirm against the Spark version in use.
        boolean handlePersistence = StorageLevel$.MODULE$.NONE().equals(dataset.rdd().getStorageLevel());
        if (handlePersistence) {
            instances.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK());
        }
        double[] parameters;
        try {
            parameters = this.optimizeParameters(dataset, instances);
        } finally {
            // Release the cache on both success and failure (empty input, optimizer failure),
            // so a failed fit does not leave cached blocks behind.
            if (handlePersistence) {
                instances.unpersist(instances.unpersist$default$1());
            }
        }
        // Parameter layout: [0] = log(scale), [1] = intercept, [2..] = coefficients.
        Vector coefficients = Vectors$.MODULE$.dense(java.util.Arrays.copyOfRange(parameters, 2, parameters.length));
        double intercept = parameters[1];
        double scale = package$.MODULE$.exp(parameters[0]);
        AFTSurvivalRegressionModel model = new AFTSurvivalRegressionModel(this.uid(), coefficients, intercept, scale);
        return this.copyValues(model.setParent(this), this.copyValues$default$2());
    }

    /**
     * Runs LBFGS on AFTCostFun over {@code instances}, starting from all-zero parameters.
     *
     * @param dataset   source DataFrame, used only to read the feature dimension from its first row
     * @param instances training points produced by extractAFTPoints
     * @return a copy of the final parameter vector, of length numFeatures + 2 and laid out as
     *     [log(scale), intercept, coefficients...]
     * @throws SparkException if {@code dataset} is empty or the optimizer produces no iterations
     */
    private double[] optimizeParameters(DataFrame dataset, RDD<AFTPoint> instances) {
        AFTCostFun costFun = new AFTCostFun(instances, BoxesRunTime.unboxToBoolean((Object)this.$(this.fitIntercept())));
        // NOTE(review): 10 is presumably breeze's LBFGS history size `m` -- confirm against breeze.
        LBFGS optimizer = new LBFGS(BoxesRunTime.unboxToInt((Object)this.$(this.maxIter())), 10, BoxesRunTime.unboxToDouble((Object)this.$(this.tol())), (MutableInnerProductModule)DenseVector$.MODULE$.space((Field)Field.fieldDouble$.MODULE$, ClassTag$.MODULE$.Double()));
        // The feature dimension is read from the first row only; rows are assumed to agree on it.
        Row[] firstRows = dataset.select(this.$(this.featuresCol()), (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[0])).take(1);
        if (firstRows.length == 0) {
            throw new SparkException("AFTSurvivalRegression cannot be fit on an empty dataset.");
        }
        int numFeatures = ((Vector)firstRows[0].getAs(0)).size();
        Vector initialWeights = Vectors$.MODULE$.zeros(numFeatures + 2);
        Iterator states = optimizer.iterations((StochasticDiffFunction)new CachedDiffFunction((DiffFunction)costFun, DenseVector$.MODULE$.canCopyDenseVector(ClassTag$.MODULE$.Double())), (Object)initialWeights.toBreeze().toDenseVector$mcD$sp(ClassTag$.MODULE$.Double()));
        // Drain the iterator; only the last optimizer state is needed.
        FirstOrderMinimizer.State state = null;
        while (states.hasNext()) {
            state = (FirstOrderMinimizer.State)states.next();
        }
        if (state == null) {
            throw new SparkException(optimizer.getClass().getName() + " failed.");
        }
        return ((DenseVector)state.x()).toArray$mcD$sp(ClassTag$.MODULE$.Double()).clone();
    }

    /** Validates {@code schema} as for training and returns the transformed schema. */
    @Override
    public StructType transformSchema(StructType schema) {
        return this.validateAndTransformSchema(schema, true);
    }

    /** Returns a copy of this estimator produced by defaultCopy with {@code extra} params. */
    @Override
    public AFTSurvivalRegression copy(ParamMap extra) {
        return (AFTSurvivalRegression)this.defaultCopy(extra);
    }

    /**
     * Creates the estimator with the given uid, runs the Scala-trait initializers, and sets
     * the defaults fitIntercept = true, maxIter = 100 and tol = 1e-6.
     *
     * @param uid identifier of this estimator
     */
    public AFTSurvivalRegression(String uid) {
        this.uid = uid;
        HasFeaturesCol$class.$init$(this);
        HasLabelCol$class.$init$(this);
        HasPredictionCol$class.$init$(this);
        HasMaxIter$class.$init$(this);
        HasTol$class.$init$(this);
        HasFitIntercept$class.$init$(this);
        AFTSurvivalRegressionParams$class.$init$(this);
        this.setDefault((Seq<ParamPair<?>>)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.fitIntercept().$minus$greater(BoxesRunTime.boxToBoolean((boolean)true))}));
        this.setDefault((Seq<ParamPair<?>>)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.maxIter().$minus$greater(BoxesRunTime.boxToInteger((int)100))}));
        this.setDefault((Seq<ParamPair<?>>)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.tol().$minus$greater(BoxesRunTime.boxToDouble((double)1.0E-6))}));
    }

    /** Creates the estimator with a random uid prefixed "aftSurvReg". */
    public AFTSurvivalRegression() {
        this(Identifiable$.MODULE$.randomUID("aftSurvReg"));
    }
}

