package org.apache.mxnet.spark;

import org.apache.mxnet.Accuracy;
import org.apache.mxnet.Context;
import org.apache.mxnet.DataIter;
import org.apache.mxnet.FeedForward;
import org.apache.mxnet.FeedForward$;
import org.apache.mxnet.KVStore;
import org.apache.mxnet.KVStore$;
import org.apache.mxnet.KVStoreServer$;
import org.apache.mxnet.NDArray;
import org.apache.mxnet.Optimizer;
import org.apache.mxnet.Shape;
import org.apache.mxnet.Symbol;
import org.apache.mxnet.Xavier;
import org.apache.mxnet.Xavier$;
import org.apache.mxnet.spark.io.LabeledPointIter;
import org.apache.mxnet.spark.utils.Network$;
import org.apache.spark.SparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Function3;
import scala.Predef$;
import scala.Serializable;
import scala.collection.immutable.Map;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/* compiled from: MXNet.scala */
@ScalaSignature(bytes = "\u0006\u0001\tma\u0001B\u0001\u0003\u0001-\u0011Q!\u0014-OKRT!a\u0001\u0003\u0002\u000bM\u0004\u0018M]6\u000b\u0005\u00151\u0011!B7y]\u0016$(BA\u0004\t\u0003\u0019\t\u0007/Y2iK*\t\u0011\"A\u0002pe\u001e\u001c\u0001aE\u0002\u0001\u0019I\u0001\"!\u0004\t\u000e\u00039Q\u0011aD\u0001\u0006g\u000e\fG.Y\u0005\u0003#9\u0011a!\u00118z%\u00164\u0007CA\u0007\u0014\u0013\t!bB\u0001\u0007TKJL\u0017\r\\5{C\ndW\rC\u0003\u0017\u0001\u0011\u0005q#\u0001\u0004=S:LGO\u0010\u000b\u00021A\u0011\u0011\u0004A\u0007\u0002\u0005\u0019!1\u0004\u0001\u0001\u001d\u0005Yi\u0005LT3u\u0007>tGO]8mY&tw\r\u00165sK\u0006$7C\u0001\u000e\u001e!\tq2%D\u0001 \u0015\t\u0001\u0013%\u0001\u0003mC:<'\"\u0001\u0012\u0002\t)\fg/Y\u0005\u0003I}\u0011a\u0001\u00165sK\u0006$\u0007\u0002\u0003\u0014\u001b\u0005\u0003\u0005\u000b\u0011B\u0014\u0002\u0017M\u001c\u0007.\u001a3vY\u0016\u0014\u0018\n\u0015\t\u0003Q-r!!D\u0015\n\u0005)r\u0011A\u0002)sK\u0012,g-\u0003\u0002-[\t11\u000b\u001e:j]\u001eT!A\u000b\b\t\u0011=R\"\u0011!Q\u0001\nA\nQb]2iK\u0012,H.\u001a:Q_J$\bCA\u00072\u0013\t\u0011dBA\u0002J]RD\u0001\u0002\u000e\u000e\u0003\u0002\u0003\u0006I!N\u0001\rgB\f'o[\"p]R,\u0007\u0010\u001e\t\u0003maj\u0011a\u000e\u0006\u0003\u0007\u0019I!!O\u001c\u0003\u0019M\u0003\u0018M]6D_:$X\r\u001f;\t\u0011mR\"\u0011!Q\u0001\nq\n!\u0003\u001e:jO\u001e,'o\u00144D_6\u0004xN\\3oiB1Q\"P\u00141k}J!A\u0010\b\u0003\u0013\u0019+hn\u0019;j_:\u001c\u0004CA\u0007A\u0013\t\teB\u0001\u0003V]&$\b\"\u0002\f\u001b\t\u0003\u0019E#\u0002#G\u000f\"K\u0005CA#\u001b\u001b\u0005\u0001\u0001\"\u0002\u0014C\u0001\u00049\u0003\"B\u0018C\u0001\u0004\u0001\u0004\"\u0002\u001bC\u0001\u0004)\u0004\"B\u001eC\u0001\u0004a\u0004\"B&\u001b\t\u0003b\u0015a\u0001:v]R\tq\bC\u0004O\u0001\t\u0007I\u0011B(\u0002\r1|wmZ3s+\u0005\u0001\u0006CA)U\u001b\u0005\u0011&BA*\t\u0003\u0015\u0019HN\u001a\u001bk\u0013\t)&K\u0001\u0004M_\u001e<WM\u001d\u0005\u0007/\u0002\u0001\u000b\u0011\u0002)\u0002\u000f1|wmZ3sA!9\u0011\f\u0001b\u0001\n\u0013Q\u0016A\u0
0029be\u0006l7/F\u0001\\!\tIB,\u0003\u0002^\u0005\tYQ\n\u0017(fiB\u000b'/Y7t\u0011\u0019y\u0006\u0001)A\u00057\u00069\u0001/\u0019:b[N\u0004\u0003\"B1\u0001\t\u0003\u0011\u0017\u0001D:fi\n\u000bGo\u00195TSj,GCA#d\u0011\u0015!\u0007\r1\u00011\u0003%\u0011\u0017\r^2i'&TX\rC\u0003g\u0001\u0011\u0005q-A\u0006tKRtU/\\#q_\u000eDGCA#i\u0011\u0015IW\r1\u00011\u0003!qW/\\#q_\u000eD\u0007\"B6\u0001\t\u0003a\u0017\u0001D:fi\u0012KW.\u001a8tS>tGCA#n\u0011\u0015q'\u000e1\u0001p\u0003%!\u0017.\\3og&|g\u000e\u0005\u0002qc6\tA!\u0003\u0002s\t\t)1\u000b[1qK\")A\u000f\u0001C\u0001k\u0006Q1/\u001a;OKR<xN]6\u0015\u0005\u00153\b\"B<t\u0001\u0004A\u0018a\u00028fi^|'o\u001b\t\u0003afL!A\u001f\u0003\u0003\rMKXNY8m\u0011\u0015a\b\u0001\"\u0001~\u0003)\u0019X\r^\"p]R,\u0007\u0010\u001e\u000b\u0003\u000bzDaa`>A\u0002\u0005\u0005\u0011aA2uqB)Q\"a\u0001\u0002\b%\u0019\u0011Q\u0001\b\u0003\u000b\u0005\u0013(/Y=\u0011\u0007A\fI!C\u0002\u0002\f\u0011\u0011qaQ8oi\u0016DH\u000fC\u0004\u0002\u0010\u0001!\t!!\u0005\u0002\u0019M,GOT;n/>\u00148.\u001a:\u0015\u0007\u0015\u000b\u0019\u0002C\u0004\u0002\u0016\u00055\u0001\u0019\u0001\u0019\u0002\u00139,XnV8sW\u0016\u0014\bbBA\r\u0001\u0011\u0005\u00111D\u0001\rg\u0016$h*^7TKJ4XM\u001d\u000b\u0004\u000b\u0006u\u0001bBA\u0010\u0003/\u0001\r\u0001M\u0001\n]Vl7+\u001a:wKJDq!a\t\u0001\t\u0003\t)#A\u0006tKR$\u0015\r^1OC6,GcA#\u0002(!9\u0011\u0011FA\u0011\u0001\u00049\u0013\u0001\u00028b[\u0016Dq!!\f\u0001\t\u0003\ty#\u0001\u0007tKRd\u0015MY3m\u001d\u0006lW\rF\u0002F\u0003cAq!!\u000b\u0002,\u0001\u0007q\u0005C\u0004\u00026\u0001!\t!a\u000e\u0002\u0015M,G\u000fV5nK>,H\u000fF\u0002F\u0003sAq!a\u000f\u00024\u0001\u0007\u0001'A\u0004uS6,w.\u001e;\t\u000f\u0005}\u0002\u0001\"\u0001\u0002B\u0005y1/\u001a;Fq\u0016\u001cW\u000f^8s\u0015\u0006\u00148\u000fF\u0002F\u0003\u0007Bq!!\u0012\u0002>\u0001\u0007q%\u0001\u0003kCJ\u001c\bbBA%\u0001\u0011\u0005\u00111J\u0001\bg\u0016$(*\u0019<b)\r)\u0015Q\n\u0005\u0007E\u0005\u001d\u0003\u0019A\u0014\t\u000f\u0005E\u0003\u0001\"\u00
03\u0002T\u0005q1\u000f^1siB\u001b6+\u001a:wKJ\u001cHcB \u0002V\u0005]\u0013\u0011\f\u0005\u0007M\u0005=\u0003\u0019A\u0014\t\r=\ny\u00051\u00011\u0011\u001d\tY&a\u0014A\u0002U\n!a]2\t\u000f\u0005}\u0003\u0001\"\u0003\u0002b\u0005\u00012\u000f^1siB\u001b6k\u00195fIVdWM\u001d\u000b\b\u007f\u0005\r\u0014QMA4\u0011\u00191\u0013Q\fa\u0001O!1q&!\u0018A\u0002ABq!a\u0017\u0002^\u0001\u0007Q\u0007C\u0004\u0002l\u0001!I!!\u001c\u0002'M,GOR3fI\u001a{'o^1sI6{G-\u001a7\u0015\u0015\u0005=\u0014QOA@\u0003\u0007\u000bi\tE\u0002q\u0003cJ1!a\u001d\u0005\u0005-1U-\u001a3G_J<\u0018M\u001d3\t\u0011\u0005]\u0014\u0011\u000ea\u0001\u0003s\n\u0011b\u001c9uS6L'0\u001a:\u0011\u0007A\fY(C\u0002\u0002~\u0011\u0011\u0011b\u00149uS6L'0\u001a:\t\u000f\u0005\u0005\u0015\u0011\u000ea\u0001a\u0005Ya.^7Fq\u0006l\u0007\u000f\\3t\u0011!\t))!\u001bA\u0002\u0005\u001d\u0015AA6w!\r\u0001\u0018\u0011R\u0005\u0004\u0003\u0017#!aB&W'R|'/\u001a\u0005\t\u0003\u001f\u000bI\u00071\u0001\u0002\u0012\u0006\u0001\u0012N\u001c9vi&s\u0007+\u0019:uSRLwN\u001c\t\u0005\u0003'\u000bI*\u0004\u0002\u0002\u0016*\u0019\u0011q\u0013\u0002\u0002\u0005%|\u0017\u0002BAN\u0003+\u0013\u0001\u0003T1cK2,G\rU8j]RLE/\u001a:\t\u000f\u0005}\u0005\u0001\"\u0003\u0002\"\u0006a1/\u001a;va.36\u000b^8sKR1\u0011qQAR\u0003KCaAJAO\u0001\u00049\u0003BB\u0018\u0002\u001e\u0002\u0007\u0001\u0007C\u0004\u0002*\u0002!I!a+\u0002!I,7\r\\1j[J+7o\\;sG\u0016\u001cH#B 
\u0002.\u0006E\u0006\u0002CAX\u0003O\u0003\r!!%\u0002\u0011\u0011\fG/Y%uKJD\u0001\"!\"\u0002(\u0002\u0007\u0011q\u0011\u0005\b\u0003k\u0003A\u0011BA\\\u0003)!(/Y5o\u001b>$W\r\u001c\u000b\t\u0003s\u000by,a8\u0002bB\u0019\u0011$a/\n\u0007\u0005u&A\u0001\u0006N1:+G/T8eK2D\u0001\"!1\u00024\u0002\u0007\u00111Y\u0001\niJ\f\u0017N\u001c#bi\u0006\u0004b!!2\u0002L\u0006=WBAAd\u0015\r\tImN\u0001\u0004e\u0012$\u0017\u0002BAg\u0003\u000f\u00141A\u0015#E!\u0011\t\t.a7\u000e\u0005\u0005M'\u0002BAk\u0003/\f!B]3he\u0016\u001c8/[8o\u0015\r\tInN\u0001\u0006[2d\u0017NY\u0005\u0005\u0003;\f\u0019N\u0001\u0007MC\n,G.\u001a3Q_&tG\u000f\u0003\u0004'\u0003g\u0003\ra\n\u0005\u0007_\u0005M\u0006\u0019\u0001\u0019\t\u000f\u0005\u0015\b\u0001\"\u0001\u0002h\u0006\u0019a-\u001b;\u0015\t\u0005e\u0016\u0011\u001e\u0005\t\u0003W\f\u0019\u000f1\u0001\u0002D\u0006!A-\u0019;b\u0011-\ty\u000f\u0001a\u0001\u0002\u0004%I!!=\u0002\u001dA\u001c8+\u001a:wKJ$\u0006N]3bIV\tA\tC\u0006\u0002v\u0002\u0001\r\u00111A\u0005\n\u0005]\u0018A\u00059t'\u0016\u0014h/\u001a:UQJ,\u0017\rZ0%KF$2aPA}\u0011%\tY0a=\u0002\u0002\u0003\u0007A)A\u0002yIEBq!a@\u0001A\u0003&A)A\bqgN+'O^3s)\"\u0014X-\u00193!Q\u0011\tiPa\u0001\u0011\u00075\u0011)!C\u0002\u0003\b9\u0011\u0011\u0002\u001e:b]NLWM\u001c;\t\u0017\t-\u0001\u00011AA\u0002\u0013%\u0011\u0011_\u0001\u0012aN\u001c6\r[3ek2,'\u000f\u00165sK\u0006$\u0007b\u0003B\b\u0001\u0001\u0007\t\u0019!C\u0005\u0005#\tQ\u0003]:TG\",G-\u001e7feRC'/Z1e?\u0012*\u0017\u000fF\u0002@\u0005'A\u0011\"a?\u0003\u000e\u0005\u0005\t\u0019\u0001#\t\u000f\t]\u0001\u0001)Q\u0005\t\u0006\u0011\u0002o]*dQ\u0016$W\u000f\\3s)\"\u0014X-\u00193!Q\u0011\u0011)Ba\u0001")
/* loaded from: input_file:org/apache/mxnet/spark/MXNet.class */
/**
 * Driver-side entry point for distributed MXNet training on Spark.
 *
 * <p>Configure via the fluent {@code setXxx} methods (each mutates the shared {@link MXNetParams}
 * instance and returns {@code this}), then call {@link #fit(RDD)}. {@code fit} starts a
 * parameter-server scheduler and servers on background threads, runs training inside a
 * {@code mapPartitions} job over the (possibly repartitioned) input, and returns the first model
 * that job produces.
 *
 * <p>NOTE(review): this is the decompiled Java view of {@code MXNet.scala}. Members with
 * {@code $}-mangled names (e.g. {@code org$apache$mxnet$spark$MXNet$$params}) are scalac-generated
 * and are presumably invoked from the generated {@code MXNet$$anonfun$...} closure classes, so they
 * must keep their exact names and public visibility.
 */
public class MXNet implements Serializable {
    /** Class logger, exposed through a scalac-mangled public accessor. */
    private final Logger org$apache$mxnet$spark$MXNet$$logger = LoggerFactory.getLogger(MXNet.class);
    /** Mutable training configuration populated by the {@code setXxx} methods. */
    private final MXNetParams org$apache$mxnet$spark$MXNet$$params = new MXNetParams();
    /** Background thread driving the parameter-server servers; not serialized (transient). */
    private transient MXNetControllingThread psServerThread;
    /** Background thread driving the parameter-server scheduler; not serialized (transient). */
    private transient MXNetControllingThread psSchedulerThread;

    /* compiled from: MXNet.scala */
    /* loaded from: input_file:org/apache/mxnet/spark/MXNet$MXNetControllingThread.class */
    /**
     * Thread that, when run, invokes {@code triggerOfComponent} once with the scheduler IP, the
     * boxed scheduler port and the Spark context it was constructed with.
     */
    public class MXNetControllingThread extends Thread {
        private final String schedulerIP;
        private final int schedulerPort;
        private final SparkContext sparkContext;
        private final Function3<String, Object, SparkContext, BoxedUnit> triggerOfComponent;
        public final /* synthetic */ MXNet $outer;

        /** Runs the configured trigger; the port is boxed because the Scala function is generic. */
        @Override // java.lang.Thread, java.lang.Runnable
        public void run() {
            this.triggerOfComponent.apply(this.schedulerIP, BoxesRunTime.boxToInteger(this.schedulerPort), this.sparkContext);
        }

        /** scalac-generated accessor for the enclosing {@link MXNet} instance. */
        public /* synthetic */ MXNet org$apache$mxnet$spark$MXNet$MXNetControllingThread$$$outer() {
            return this.$outer;
        }

        /**
         * @param mXNet enclosing instance; must not be null
         * @param str scheduler IP passed to the trigger
         * @param i scheduler port passed to the trigger
         * @param sparkContext Spark context passed to the trigger
         * @param function3 action executed by {@link #run()}
         */
        public MXNetControllingThread(MXNet mXNet, String str, int i, SparkContext sparkContext, Function3<String, Object, SparkContext, BoxedUnit> function3) {
            this.schedulerIP = str;
            this.schedulerPort = i;
            this.sparkContext = sparkContext;
            this.triggerOfComponent = function3;
            // scalac-generated outer-instance check: `throw null` raises NullPointerException.
            if (mXNet == null) {
                throw null;
            }
            this.$outer = mXNet;
        }
    }

    /** @return the class logger (scalac-mangled accessor) */
    public Logger org$apache$mxnet$spark$MXNet$$logger() {
        return this.org$apache$mxnet$spark$MXNet$$logger;
    }

    /** @return the mutable training configuration (scalac-mangled accessor) */
    public MXNetParams org$apache$mxnet$spark$MXNet$$params() {
        return this.org$apache$mxnet$spark$MXNet$$params;
    }

    private MXNetControllingThread psServerThread() {
        return this.psServerThread;
    }

    private void psServerThread_$eq(MXNetControllingThread mXNetControllingThread) {
        this.psServerThread = mXNetControllingThread;
    }

    private MXNetControllingThread psSchedulerThread() {
        return this.psSchedulerThread;
    }

    private void psSchedulerThread_$eq(MXNetControllingThread mXNetControllingThread) {
        this.psSchedulerThread = mXNetControllingThread;
    }

    /** Sets the mini-batch size; returns {@code this}. */
    public MXNet setBatchSize(int i) {
        org$apache$mxnet$spark$MXNet$$params().batchSize_$eq(i);
        return this;
    }

    /** Sets the number of training epochs; returns {@code this}. */
    public MXNet setNumEpoch(int i) {
        org$apache$mxnet$spark$MXNet$$params().numEpoch_$eq(i);
        return this;
    }

    /** Sets the input dimension shape; returns {@code this}. */
    public MXNet setDimension(Shape shape) {
        org$apache$mxnet$spark$MXNet$$params().dimension_$eq(shape);
        return this;
    }

    /** Sets the network symbol to train; returns {@code this}. */
    public MXNet setNetwork(Symbol symbol) {
        org$apache$mxnet$spark$MXNet$$params().setNetwork(symbol);
        return this;
    }

    /** Sets the device contexts used by each worker's model; returns {@code this}. */
    public MXNet setContext(Context[] contextArr) {
        org$apache$mxnet$spark$MXNet$$params().context_$eq(contextArr);
        return this;
    }

    /**
     * Sets the number of workers; {@link #fit(RDD)} repartitions the input to this many
     * partitions when they differ. Returns {@code this}.
     */
    public MXNet setNumWorker(int i) {
        org$apache$mxnet$spark$MXNet$$params().numWorker_$eq(i);
        return this;
    }

    /** Sets the number of parameter servers; returns {@code this}. */
    public MXNet setNumServer(int i) {
        org$apache$mxnet$spark$MXNet$$params().numServer_$eq(i);
        return this;
    }

    /** Sets the name of the data input; returns {@code this}. */
    public MXNet setDataName(String str) {
        org$apache$mxnet$spark$MXNet$$params().dataName_$eq(str);
        return this;
    }

    /** Sets the name of the label input; returns {@code this}. */
    public MXNet setLabelName(String str) {
        org$apache$mxnet$spark$MXNet$$params().labelName_$eq(str);
        return this;
    }

    /**
     * Sets the timeout passed to the parameter-server processes (units defined by
     * {@code ParameterServer}); returns {@code this}.
     */
    public MXNet setTimeout(int i) {
        org$apache$mxnet$spark$MXNet$$params().timeout_$eq(i);
        return this;
    }

    /**
     * Sets the jars to ship to executors from a list separated by {@code ','} or {@code ':'}.
     * NOTE(review): splitting on {@code ':'} also splits Windows drive letters ({@code C:\...}).
     * Returns {@code this}.
     */
    public MXNet setExecutorJars(String str) {
        org$apache$mxnet$spark$MXNet$$params().jars_$eq(str.split(",|:"));
        return this;
    }

    /** Sets the java binary used to launch parameter-server processes; returns {@code this}. */
    public MXNet setJava(String str) {
        org$apache$mxnet$spark$MXNet$$params().javabin_$eq(str);
        return this;
    }

    /** Starts the background thread that launches the parameter servers. */
    private void startPSServers(String str, int i, SparkContext sparkContext) {
        psServerThread_$eq(new MXNetControllingThread(this, str, i, sparkContext, new MXNet$$anonfun$startPSServers$1(this)));
        psServerThread().start();
    }

    /** Starts the background thread that launches the parameter-server scheduler. */
    private void startPSScheduler(String str, int i, SparkContext sparkContext) {
        psSchedulerThread_$eq(new MXNetControllingThread(this, str, i, sparkContext, new MXNet$$anonfun$startPSScheduler$1(this)));
        psSchedulerThread().start();
    }

    /**
     * Builds a {@link FeedForward} model from the configured network and trains it on one
     * partition's data, using {@link Accuracy} as the metric and {@code kVStore} for parameter sync.
     *
     * @param optimizer optimizer for the model
     * @param i number of training examples, used to derive the per-epoch iteration count
     * @param kVStore key-value store shared with the other workers
     * @param labeledPointIter training data iterator
     * @return the trained model
     */
    public FeedForward org$apache$mxnet$spark$MXNet$$setFeedForwardModel(Optimizer optimizer, int i, KVStore kVStore, LabeledPointIter labeledPointIter) {
        org$apache$mxnet$spark$MXNet$$logger().debug("Define model");
        Context[] context = org$apache$mxnet$spark$MXNet$$params().context();
        // NOTE(review): the epoch-size argument below is (numExamples / batchSize) / numWorkers.
        // Integer division yields 0 when numExamples < batchSize * numWorkers. Confirm how
        // FeedForward treats an epoch size of 0.
        FeedForward feedForward = new FeedForward(org$apache$mxnet$spark$MXNet$$params().getNetwork(), context, org$apache$mxnet$spark$MXNet$$params().numEpoch(), (i / org$apache$mxnet$spark$MXNet$$params().batchSize()) / kVStore.numWorkers(), optimizer, new Xavier(Xavier$.MODULE$.$lessinit$greater$default$1(), "in", 2.34f), FeedForward$.MODULE$.$lessinit$greater$default$7(), (Map<String, NDArray>) null, (Map<String, NDArray>) null, FeedForward$.MODULE$.$lessinit$greater$default$10(), 0);
        org$apache$mxnet$spark$MXNet$$logger().info("Start training ...");
        feedForward.fit(labeledPointIter, (DataIter) null, new Accuracy(), kVStore);
        return feedForward;
    }

    /**
     * Initializes this process as a parameter-server "worker" and creates its KVStore.
     *
     * @param str scheduler IP
     * @param i scheduler port
     * @return a {@code "dist_async"} KVStore with the exit barrier disabled (re-enabled in
     *     reclaimResources before disposal)
     */
    public KVStore org$apache$mxnet$spark$MXNet$$setupKVStore(String str, int i) {
        KVStoreServer$.MODULE$.init(ParameterServer$.MODULE$.buildEnv("worker", str, i, org$apache$mxnet$spark$MXNet$$params().numServer(), org$apache$mxnet$spark$MXNet$$params().numWorker()));
        KVStore create = KVStore$.MODULE$.create("dist_async");
        create.setBarrierBeforeExit(false);
        return create;
    }

    /** Disposes the data iterator, re-enables the exit barrier and disposes the KVStore. */
    public void org$apache$mxnet$spark$MXNet$$reclaimResources(LabeledPointIter labeledPointIter, KVStore kVStore) {
        labeledPointIter.dispose();
        kVStore.setBarrierBeforeExit(true);
        kVStore.dispose();
    }

    /**
     * Runs training over every partition of {@code rdd} and returns the first resulting model.
     *
     * <p>The per-partition models are cached, so {@code first()} reuses the results of the
     * {@code foreachPartition} pass and does not re-run training. The cache is released afterwards
     * because the returned model is already a driver-side copy.
     */
    private MXNetModel trainModel(RDD<LabeledPoint> rdd, String str, int i) {
        RDD cache = rdd.mapPartitions(new MXNet$$anonfun$1(this, str, i), rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(MXNetModel.class)).cache();
        try {
            cache.foreachPartition(new MXNet$$anonfun$trainModel$1(this));
            return (MXNetModel) cache.first();
        } finally {
            // Previously the cached models stayed pinned in executor storage after training.
            cache.unpersist(false);
        }
    }

    /**
     * Trains a model on {@code rdd} with a parameter-server cluster launched for this call.
     *
     * <p>Steps:
     * <ol>
     *   <li>Register any configured executor jars.</li>
     *   <li>Repartition to {@code numWorker} partitions if needed.</li>
     *   <li>Start the scheduler and server threads on this host's IP and a free port.</li>
     *   <li>Train.</li>
     *   <li>Wait for both control threads to finish.</li>
     * </ol>
     *
     * @param rdd labeled training data
     * @return the first model produced by the training job
     * @throws IllegalStateException if interrupted while waiting for the control threads; the
     *     interrupt status is restored and the {@link InterruptedException} is the cause
     */
    public MXNetModel fit(RDD<LabeledPoint> rdd) {
        SparkContext context = rdd.context();
        if (org$apache$mxnet$spark$MXNet$$params().jars() != null) {
            // NOTE(review): presumably adds each jar to the SparkContext (see MXNet$$anonfun$fit$1).
            Predef$.MODULE$.refArrayOps(org$apache$mxnet$spark$MXNet$$params().jars()).foreach(new MXNet$$anonfun$fit$1(this, context));
        }
        RDD<LabeledPoint> trainingData;
        if (org$apache$mxnet$spark$MXNet$$params().numWorker() != rdd.partitions().length) {
            org$apache$mxnet$spark$MXNet$$logger().info("repartitioning training set to {} partitions", BoxesRunTime.boxToInteger(org$apache$mxnet$spark$MXNet$$params().numWorker()));
            int numWorker = org$apache$mxnet$spark$MXNet$$params().numWorker();
            trainingData = rdd.repartition(numWorker, rdd.repartition$default$2(numWorker));
        } else {
            trainingData = rdd;
        }
        String ipAddress = Network$.MODULE$.ipAddress();
        int availablePort = Network$.MODULE$.availablePort();
        startPSScheduler(ipAddress, availablePort, context);
        startPSServers(ipAddress, availablePort, context);
        MXNetModel trainModel = trainModel(trainingData, ipAddress, availablePort);
        org$apache$mxnet$spark$MXNet$$logger().info("Waiting for scheduler ...");
        try {
            psSchedulerThread().join();
            psServerThread().join();
        } catch (InterruptedException e) {
            // Restore the interrupt flag and keep the cause. The existing signature has no
            // throws clause, so the checked exception is wrapped.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for parameter-server threads", e);
        }
        return trainModel;
    }

    /**
     * Runs a Spark job with {@code numServer} single-element partitions ({@code 1..numServer}).
     * On each partition it runs the generated closure, which presumably launches one parameter
     * server. NOTE(review): this is presumably invoked via MXNet$$anonfun$startPSServers$1.
     */
    public final void org$apache$mxnet$spark$MXNet$$startPSServersInner$1(String str, int i, SparkContext sparkContext) {
        sparkContext.parallelize(RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(1), org$apache$mxnet$spark$MXNet$$params().numServer()), org$apache$mxnet$spark$MXNet$$params().numServer(), ClassTag$.MODULE$.Int()).foreachPartition(new MXNet$$anonfun$org$apache$mxnet$spark$MXNet$$startPSServersInner$1$1(this, str, i));
    }

    /**
     * Launches the parameter-server scheduler process on {@code str:i} and requires that it exits
     * with code 0. NOTE(review): this is presumably invoked via MXNet$$anonfun$startPSScheduler$1.
     */
    public final void org$apache$mxnet$spark$MXNet$$startPSSchedulerInner$1(String str, int i, SparkContext sparkContext) {
        org$apache$mxnet$spark$MXNet$$logger().info("Starting scheduler on {}:{}", str, BoxesRunTime.boxToInteger(i));
        int startProcess = new ParameterServer(org$apache$mxnet$spark$MXNet$$params().runtimeClasspath(), "scheduler", str, i, org$apache$mxnet$spark$MXNet$$params().numServer(), org$apache$mxnet$spark$MXNet$$params().numWorker(), org$apache$mxnet$spark$MXNet$$params().timeout(), org$apache$mxnet$spark$MXNet$$params().javabin(), ParameterServer$.MODULE$.$lessinit$greater$default$9()).startProcess();
        Predef$.MODULE$.require(startProcess == 0, new MXNet$$anonfun$org$apache$mxnet$spark$MXNet$$startPSSchedulerInner$1$1(this, startProcess));
    }
}
