/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.classification;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.Predictor;
import org.apache.spark.ml.ann.FeedForwardTopology;
import org.apache.spark.ml.ann.FeedForwardTopology$;
import org.apache.spark.ml.ann.FeedForwardTrainer;
import org.apache.spark.ml.ann.TopologyModel;
import org.apache.spark.ml.classification.LabelConverter$;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronParams;
import org.apache.spark.ml.classification.MultilayerPerceptronParams$class;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntArrayParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.LongParam;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.shared.HasMaxIter$class;
import org.apache.spark.ml.param.shared.HasSeed$class;
import org.apache.spark.ml.param.shared.HasTol$class;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxesRunTime;

@Experimental
@ScalaSignature(bytes="\u0006\u0001\u0005=a\u0001B\u0001\u0003\u00015\u0011a$T;mi&d\u0017-_3s!\u0016\u00148-\u001a9ue>t7\t\\1tg&4\u0017.\u001a:\u000b\u0005\r!\u0011AD2mCN\u001c\u0018NZ5dCRLwN\u001c\u0006\u0003\u000b\u0019\t!!\u001c7\u000b\u0005\u001dA\u0011!B:qCJ\\'BA\u0005\u000b\u0003\u0019\t\u0007/Y2iK*\t1\"A\u0002pe\u001e\u001c\u0001aE\u0002\u0001\u001d}\u0001Ra\u0004\t\u00135qi\u0011\u0001B\u0005\u0003#\u0011\u0011\u0011\u0002\u0015:fI&\u001cGo\u001c:\u0011\u0005MAR\"\u0001\u000b\u000b\u0005U1\u0012A\u00027j]\u0006dwM\u0003\u0002\u0018\r\u0005)Q\u000e\u001c7jE&\u0011\u0011\u0004\u0006\u0002\u0007-\u0016\u001cGo\u001c:\u0011\u0005m\u0001Q\"\u0001\u0002\u0011\u0005mi\u0012B\u0001\u0010\u0003\u0005\u001djU\u000f\u001c;jY\u0006LXM\u001d)fe\u000e,\u0007\u000f\u001e:p]\u000ec\u0017m]:jM&\u001c\u0017\r^5p]6{G-\u001a7\u0011\u0005m\u0001\u0013BA\u0011\u0003\u0005iiU\u000f\u001c;jY\u0006LXM\u001d)fe\u000e,\u0007\u000f\u001e:p]B\u000b'/Y7t\u0011!\u0019\u0003A!b\u0001\n\u0003\"\u0013aA;jIV\tQ\u0005\u0005\u0002'Y9\u0011qEK\u0007\u0002Q)\t\u0011&A\u0003tG\u0006d\u0017-\u0003\u0002,Q\u00051\u0001K]3eK\u001aL!!\f\u0018\u0003\rM#(/\u001b8h\u0015\tY\u0003\u0006K\u0002#aY\u0002\"!\r\u001b\u000e\u0003IR!a\r\u0004\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u00026e\t)1+\u001b8dK\u0006\nq'A\u00032]Ur\u0003\u0007\u0003\u0005:\u0001\t\u0005\t\u0015!\u0003&\u0003\u0011)\u0018\u000e\u001a\u0011)\u0007a\u0002d\u0007C\u0003=\u0001\u0011\u0005Q(\u0001\u0004=S:LGO\u0010\u000b\u00035yBQaI\u001eA\u0002\u0015B3A\u0010\u00197Q\rY\u0004G\u000e\u0005\u0006y\u0001!\tA\u0011\u000b\u00025!\u001a\u0011\t\r\u001c\t\u000b\u0015\u0003A\u0011\u0001$\u0002\u0013M,G\u000fT1zKJ\u001cHCA$I\u001b\u0005\u0001\u0001\"B%E\u0001\u0004Q\u0015!\u0002<bYV,\u0007cA\u0014L\u001b&\u0011A\n\u000b\u0002\u0006\u0003J\u0014\u0018-\u001f\t\u0003O9K!a\u0014\u0015\u0003\u0007%sG\u000fK\u0002EaYBQA\u0015\u0001\u0005\u0002M\u000bAb]3u\u00052|7m[*ju\u0016$\"a\u0012+\t\u000b%\u000b\u0006\u0019A')\u0007E\u0003d\u0007C\u0003X\u0001\u001
1\u0005\u0001,\u0001\u0006tKRl\u0015\r_%uKJ$\"aR-\t\u000b%3\u0006\u0019A')\u0007Y\u0003d\u0007C\u0003]\u0001\u0011\u0005Q,\u0001\u0004tKR$v\u000e\u001c\u000b\u0003\u000fzCQ!S.A\u0002}\u0003\"a\n1\n\u0005\u0005D#A\u0002#pk\ndW\rK\u0002\\aYBQ\u0001\u001a\u0001\u0005\u0002\u0015\fqa]3u'\u0016,G\r\u0006\u0002HM\")\u0011j\u0019a\u0001OB\u0011q\u0005[\u0005\u0003S\"\u0012A\u0001T8oO\"\u001a1\r\r\u001c\t\u000b1\u0004A\u0011I7\u0002\t\r|\u0007/\u001f\u000b\u000359DQa\\6A\u0002A\fQ!\u001a=ue\u0006\u0004\"!\u001d;\u000e\u0003IT!a\u001d\u0003\u0002\u000bA\f'/Y7\n\u0005U\u0014(\u0001\u0003)be\u0006lW*\u00199)\u0007-\u0004d\u0007C\u0003y\u0001\u0011E\u00130A\u0003ue\u0006Lg\u000e\u0006\u0002\u001du\")1p\u001ea\u0001y\u00069A-\u0019;bg\u0016$\bcA?\u0002\u00025\taP\u0003\u0002\u0000\r\u0005\u00191/\u001d7\n\u0007\u0005\raPA\u0005ECR\fgI]1nK\"\u001a\u0001!a\u0002\u0011\u0007E\nI!C\u0002\u0002\fI\u0012A\"\u0012=qKJLW.\u001a8uC2D3\u0001\u0001\u00197\u0001")
public class MultilayerPerceptronClassifier
extends Predictor<Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel>
implements MultilayerPerceptronParams {
    /*
     * Estimator that trains a feed-forward multilayer perceptron and produces a
     * MultilayerPerceptronClassificationModel.
     *
     * This class was decompiled (CFR) from Scala bytecode. The *_setter_*_$eq methods and the
     * *$class helpers it calls are the Scala trait encoding of MultilayerPerceptronParams and the
     * shared HasSeed / HasMaxIter / HasTol params. Their names must stay exactly as they are.
     */

    private final String uid;

    // The trait-backed params below are assigned through the *_setter_*_$eq methods (presumably
    // invoked from the *$class.$init$ calls in the constructor), i.e. outside a constructor body,
    // so they cannot be declared final in Java source: javac rejects such assignments.
    private IntArrayParam layers;
    private IntParam blockSize;
    private DoubleParam tol;
    private IntParam maxIter;
    private LongParam seed;

    /** Returns the param holding the layer sizes, from the input layer to the output layer. */
    @Override
    public final IntArrayParam layers() {
        return this.layers;
    }

    /** Returns the param holding the block size used as the trainer's stack size. */
    @Override
    public final IntParam blockSize() {
        return this.blockSize;
    }

    /** Scala trait setter for {@link #layers()}; not intended to be called directly. */
    @Override
    public final void org$apache$spark$ml$classification$MultilayerPerceptronParams$_setter_$layers_$eq(IntArrayParam x$1) {
        this.layers = x$1;
    }

    /** Scala trait setter for {@link #blockSize()}; not intended to be called directly. */
    @Override
    public final void org$apache$spark$ml$classification$MultilayerPerceptronParams$_setter_$blockSize_$eq(IntParam x$1) {
        this.blockSize = x$1;
    }

    /** Returns the current layer sizes (delegates to the Scala trait implementation). */
    @Override
    public final int[] getLayers() {
        return MultilayerPerceptronParams$class.getLayers(this);
    }

    /** Returns the current block size (delegates to the Scala trait implementation). */
    @Override
    public final int getBlockSize() {
        return MultilayerPerceptronParams$class.getBlockSize(this);
    }

    /** Returns the param holding the L-BFGS convergence tolerance. */
    @Override
    public final DoubleParam tol() {
        return this.tol;
    }

    /** Scala trait setter for {@link #tol()}; not intended to be called directly. */
    @Override
    public final void org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(DoubleParam x$1) {
        this.tol = x$1;
    }

    /** Returns the current convergence tolerance (delegates to the Scala trait implementation). */
    @Override
    public final double getTol() {
        return HasTol$class.getTol(this);
    }

    /** Returns the param holding the maximum number of L-BFGS iterations. */
    @Override
    public final IntParam maxIter() {
        return this.maxIter;
    }

    /** Scala trait setter for {@link #maxIter()}; not intended to be called directly. */
    @Override
    public final void org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(IntParam x$1) {
        this.maxIter = x$1;
    }

    /** Returns the current maximum iteration count (delegates to the Scala trait implementation). */
    @Override
    public final int getMaxIter() {
        return HasMaxIter$class.getMaxIter(this);
    }

    /** Returns the seed param. */
    @Override
    public final LongParam seed() {
        return this.seed;
    }

    /** Scala trait setter for {@link #seed()}; not intended to be called directly. */
    @Override
    public final void org$apache$spark$ml$param$shared$HasSeed$_setter_$seed_$eq(LongParam x$1) {
        this.seed = x$1;
    }

    /** Returns the current seed (delegates to the Scala trait implementation). */
    @Override
    public final long getSeed() {
        return HasSeed$class.getSeed(this);
    }

    /** Returns the unique identifier of this pipeline stage. */
    @Override
    public String uid() {
        return this.uid;
    }

    /**
     * Sets the layer sizes, from the input layer (first entry) to the output layer (last entry).
     *
     * @param value layer sizes
     * @return this classifier, for chaining
     */
    public MultilayerPerceptronClassifier setLayers(int[] value) {
        return (MultilayerPerceptronClassifier) this.set(this.layers(), value);
    }

    /**
     * Sets the block size, which {@link #train} passes to the trainer as its stack size.
     *
     * @param value block size
     * @return this classifier, for chaining
     */
    public MultilayerPerceptronClassifier setBlockSize(int value) {
        return (MultilayerPerceptronClassifier) this.set(this.blockSize(), BoxesRunTime.boxToInteger(value));
    }

    /**
     * Sets the number of iterations given to the L-BFGS optimizer.
     *
     * @param value maximum number of iterations
     * @return this classifier, for chaining
     */
    public MultilayerPerceptronClassifier setMaxIter(int value) {
        return (MultilayerPerceptronClassifier) this.set(this.maxIter(), BoxesRunTime.boxToInteger(value));
    }

    /**
     * Sets the convergence tolerance given to the L-BFGS optimizer.
     *
     * @param value convergence tolerance
     * @return this classifier, for chaining
     */
    public MultilayerPerceptronClassifier setTol(double value) {
        return (MultilayerPerceptronClassifier) this.set(this.tol(), BoxesRunTime.boxToDouble(value));
    }

    /**
     * Sets the seed param.
     *
     * <p>NOTE(review): {@link #train} does not currently read this param.
     *
     * @param value seed
     * @return this classifier, for chaining
     */
    public MultilayerPerceptronClassifier setSeed(long value) {
        return (MultilayerPerceptronClassifier) this.set(this.seed(), BoxesRunTime.boxToLong(value));
    }

    /**
     * Creates a copy of this classifier with {@code extra} params applied, via {@code defaultCopy}.
     *
     * @param extra additional params for the copy
     * @return the copied classifier
     */
    @Override
    public MultilayerPerceptronClassifier copy(ParamMap extra) {
        return (MultilayerPerceptronClassifier) this.defaultCopy(extra);
    }

    /**
     * Trains a multilayer perceptron on the labeled points extracted from {@code dataset}.
     *
     * <p>The last entry of {@code layers} is used as the number of labels. Each
     * {@link LabeledPoint} is encoded into a pair of vectors with
     * {@code LabelConverter.encodeLabeledPoint}. A {@link FeedForwardTrainer} sized
     * {@code layers[0]} inputs by {@code layers[last]} outputs is then fitted using the configured
     * tolerance, iteration count and block size.
     *
     * <p>NOTE(review): the {@code seed} param is not read here. Confirm whether the trainer offers
     * a way to seed its initial weights.
     *
     * @param dataset input data; rows are converted with {@code extractLabeledPoints}
     * @return a model carrying this estimator's uid, the layer sizes and the trained weights
     * @throws IllegalArgumentException if {@code layers} has fewer than two entries
     */
    @Override
    public MultilayerPerceptronClassificationModel train(DataFrame dataset) {
        int[] myLayers = this.$(this.layers());
        // A network needs at least an input and an output layer; fail early with a clear message
        // instead of an opaque indexing error further down.
        if (myLayers.length < 2) {
            throw new IllegalArgumentException(
                "layers must contain at least an input and an output layer size, but got "
                    + myLayers.length + " entries");
        }
        int inputSize = myLayers[0];
        // The output layer size doubles as the number of labels for label encoding.
        int labels = myLayers[myLayers.length - 1];

        RDD<LabeledPoint> lpData = this.extractLabeledPoints(dataset);
        RDD<Tuple2<Vector, Vector>> data = lpData.map(
            new LabeledPointEncoder(labels),
            ClassTag$.MODULE$.<Tuple2<Vector, Vector>>apply(Tuple2.class));

        FeedForwardTopology topology = FeedForwardTopology$.MODULE$.multiLayerPerceptron(myLayers, true);
        FeedForwardTrainer trainer = new FeedForwardTrainer(topology, inputSize, labels);
        trainer.LBFGSOptimizer()
            .setConvergenceTol(BoxesRunTime.unboxToDouble(this.$(this.tol())))
            .setNumIterations(BoxesRunTime.unboxToInt(this.$(this.maxIter())));
        trainer.setStackSize(BoxesRunTime.unboxToInt(this.$(this.blockSize())));

        TopologyModel mlpModel = trainer.train(data);
        return new MultilayerPerceptronClassificationModel(this.uid(), myLayers, mlpModel.weights());
    }

    /**
     * Creates a classifier with the given unique identifier.
     *
     * @param uid unique identifier of this pipeline stage
     */
    public MultilayerPerceptronClassifier(String uid) {
        this.uid = uid;
        // Scala trait initializers, in the order the Scala compiler emitted them. They are
        // expected to populate the param fields through the *_setter_* methods above, so keep
        // this order unchanged.
        HasSeed$class.$init$(this);
        HasMaxIter$class.$init$(this);
        HasTol$class.$init$(this);
        MultilayerPerceptronParams$class.$init$(this);
    }

    /** Creates a classifier whose uid is generated by {@code Identifiable.randomUID("mlpc")}. */
    public MultilayerPerceptronClassifier() {
        this(Identifiable$.MODULE$.randomUID("mlpc"));
    }

    /**
     * Serializable Scala function applied to every {@link LabeledPoint} in {@link #train}. It
     * delegates to {@code LabelConverter.encodeLabeledPoint} with the label count fixed at
     * construction time.
     *
     * <p>It is a static nested class so that the function shipped with the RDD {@code map} holds
     * only the label count, not a reference to the enclosing classifier.
     */
    private static final class LabeledPointEncoder
            extends AbstractFunction1<LabeledPoint, Tuple2<Vector, Vector>>
            implements Serializable {
        private static final long serialVersionUID = 0L;

        /** Number of labels, i.e. the size of the output layer. */
        private final int labels;

        LabeledPointEncoder(int labels) {
            this.labels = labels;
        }

        @Override
        public Tuple2<Vector, Vector> apply(LabeledPoint lp) {
            return LabelConverter$.MODULE$.encodeLabeledPoint(lp, this.labels);
        }
    }
}

