/*
 * Decompiled with CFR 0.152.
 */
package net.sansa_stack.ml.spark.kernel;

import net.sansa_stack.ml.spark.kernel.RDFFastTreeGraphKernelUtil$;
import org.apache.jena.graph.Triple;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Seq;
import scala.collection.immutable.List$;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.reflect.api.JavaUniverse;
import scala.reflect.api.Mirror;
import scala.reflect.api.Symbols;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.TypeTags;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.RichInt$;

/**
 * Singleton holder of helper routines used around the RDF fast-tree graph kernel:
 * converting an RDD of Jena triples into a DataFrame, deriving an instance/label
 * DataFrame, running a repeated logistic-regression train/validate loop, and
 * printing elapsed times.
 *
 * <p>This is decompiler output (see the file header). The {@code MODULE$} field,
 * the {@code $default$N} methods and the name-mangled members look like the
 * compiled form of a Scala {@code object} -- NOTE(review): confirm against the
 * original Scala source before editing this file by hand.
 */
public final class RDFFastTreeGraphKernelUtil$ {
    /** The single instance; assigned by the private constructor. */
    public static final RDFFastTreeGraphKernelUtil$ MODULE$;

    static {
        // Creating the instance is enough: the constructor stores it into MODULE$.
        new RDFFastTreeGraphKernelUtil$();
    }

    /**
     * Converts an RDD of triples into a three-column DataFrame of strings.
     *
     * <p>Each triple is mapped to a {@code Tuple3} holding the {@code toString()} of
     * its subject, predicate and object nodes; the resulting RDD is turned into a
     * Dataset through the session's implicits and its columns are named with the
     * three supplied column names, in that order.
     *
     * @param sparkSession session whose implicits supply the RDD-to-Dataset conversion and the encoder
     * @param triples triples to convert
     * @param subjectColName name given to the first (subject) column
     * @param predicateColName name given to the second (predicate) column
     * @param objectColName name given to the third (object) column
     * @return DataFrame with one row per input triple
     */
    public Dataset<Row> triplesToDF(SparkSession sparkSession, RDD<Triple> triples, String subjectColName, String predicateColName, String objectColName) {
        JavaUniverse $u = package$.MODULE$.universe();
        JavaUniverse.JavaMirror $m = package$.MODULE$.universe().runtimeMirror(this.getClass().getClassLoader());
        // Type creator describing scala.Tuple3[String, String, String]; it is only used
        // below to build the TypeTag handed to newProductEncoder.
        public final class Net_sansa_stack_ml_spark_kernel_RDFFastTreeGraphKernelUtil$$typecreator5$1
        extends TypeCreator {
            public <U extends Universe> Types.TypeApi apply(Mirror<U> $m$untyped) {
                Universe $u = $m$untyped.universe();
                Mirror<U> $m = $m$untyped;
                return $u.internal().reificationSupport().TypeRef($u.internal().reificationSupport().ThisType($m.staticPackage("scala").asModule().moduleClass()), (Symbols.SymbolApi)$m.staticClass("scala.Tuple3"), List$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Types.TypeApi[]{$m.staticClass("java.lang.String").asType().toTypeConstructor(), $m.staticClass("java.lang.String").asType().toTypeConstructor(), $m.staticClass("java.lang.String").asType().toTypeConstructor()})));
            }

            public Net_sansa_stack_ml_spark_kernel_RDFFastTreeGraphKernelUtil$$typecreator5$1() {
            }
        }
        // Map triple -> (subject, predicate, object) strings, wrap as a Dataset, then rename columns.
        return sparkSession.implicits().rddToDatasetHolder(triples.map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Tuple3<String, String, String> apply(Triple f) {
                return new Tuple3((Object)f.getSubject().toString(), (Object)f.getPredicate().toString(), (Object)f.getObject().toString());
            }
        }, ClassTag$.MODULE$.apply(Tuple3.class)), sparkSession.implicits().newProductEncoder(((TypeTags)$u).TypeTag().apply((Mirror)$m, (TypeCreator)new Net_sansa_stack_ml_spark_kernel_RDFFastTreeGraphKernelUtil$$typecreator5$1()))).toDF((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{subjectColName, predicateColName, objectColName}));
    }

    /** Returns {@code "subject"}; presumably the default for parameter 3 ({@code subjectColName}) of {@code triplesToDF}. */
    public String triplesToDF$default$3() {
        return "subject";
    }

    /** Returns {@code "predicate"}; presumably the default for parameter 4 ({@code predicateColName}) of {@code triplesToDF}. */
    public String triplesToDF$default$4() {
        return "predicate";
    }

    /** Returns {@code "object"}; presumably the default for parameter 5 ({@code objectColName}) of {@code triplesToDF}. */
    public String triplesToDF$default$5() {
        return "object";
    }

    /**
     * Builds a two-column DataFrame ({@code "instance"}, {@code "label"}) from a triple DataFrame.
     *
     * <p>Steps: select the subject and object columns and de-duplicate; fit a
     * {@code StringIndexer} on the object column to produce a {@code "label"} column;
     * drop the object column; group by the subject column and keep the maximum
     * label per subject; finally rename the two columns to {@code "instance"} and
     * {@code "label"}.
     *
     * @param filteredTripleDF input DataFrame containing the subject and object columns
     * @param subjectColName name of the column used as the instance
     * @param objectColName name of the column that is string-indexed into the label
     * @return DataFrame with one row per distinct subject value
     */
    public Dataset<Row> getInstanceAndLabelDF(Dataset<Row> filteredTripleDF, String subjectColName, String objectColName) {
        Dataset df = filteredTripleDF.select(subjectColName, (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{objectColName})).distinct();
        StringIndexerModel indexer = new StringIndexer().setInputCol(objectColName).setOutputCol("label").fit(df);
        // A subject that occurs with several distinct objects ends up with the largest of their label indices.
        Dataset indexedDF = indexer.transform(df).drop(objectColName).groupBy(subjectColName, (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[0])).agg(functions$.MODULE$.max("label").as("label"), (Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[0])).toDF((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"instance", "label"}));
        return indexedDF;
    }

    /** Returns {@code "subject"}; presumably the default for parameter 2 ({@code subjectColName}) of {@code getInstanceAndLabelDF}. */
    public String getInstanceAndLabelDF$default$2() {
        return "subject";
    }

    /** Returns {@code "object"}; presumably the default for parameter 3 ({@code objectColName}) of {@code getInstanceAndLabelDF}. */
    public String getInstanceAndLabelDF$default$3() {
        return "object";
    }

    /**
     * Runs the train/validate routine once per seed and prints the mean accuracy and timings.
     *
     * <p>The data is cached and counted first (the count is printed as a
     * {@code ("data count", n)} tuple). Then, for every seed in the range built by
     * {@code 1 to maxIteration}, the train/validate helper is invoked and its accuracy
     * is added to a running sum. Finally the sum divided by {@code maxIteration} is
     * printed, followed by the two elapsed-time lines.
     *
     * <p>NOTE(review): the average is {@code sum / maxIteration} with no guard, so
     * {@code maxIteration == 0} prints {@code NaN} (0.0 / 0.0).
     *
     * @param data labeled points used for both training and validation splits
     * @param numClasses number of classes passed to the logistic-regression trainer
     * @param maxIteration upper bound of the seed range, and the divisor of the average
     */
    public void predictLogisticRegressionMLLIB(RDD<LabeledPoint> data, int numClasses, int maxIteration) {
        long t0 = System.nanoTime();
        data.cache();
        Predef$.MODULE$.println((Object)new Tuple2((Object)"data count", (Object)BoxesRunTime.boxToLong((long)data.count())));
        long t1 = System.nanoTime();
        // Mutable holder so the closure below can accumulate into it.
        DoubleRef sumOfAccuracy = DoubleRef.create((double)0.0);
        RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(1), maxIteration).foreach$mVc$sp((Function1)new Serializable(data, numClasses, sumOfAccuracy){
            public static final long serialVersionUID = 0L;
            private final RDD data$1;
            private final int numClasses$1;
            private final DoubleRef sumOfAccuracy$1;

            public final void apply(int seed) {
                this.apply$mcVI$sp(seed);
            }

            // Loop body: train/validate with this seed and add the resulting accuracy to the sum.
            public void apply$mcVI$sp(int seed) {
                Tuple2 tuple2 = RDFFastTreeGraphKernelUtil$.MODULE$.net$sansa_stack$ml$spark$kernel$RDFFastTreeGraphKernelUtil$$trainAndValidate$1(this.data$1, seed, this.numClasses$1);
                if (tuple2 != null) {
                    Tuple2 tuple22;
                    // The (model, accuracy) pair is unpacked twice; only accuracy2 is used afterwards.
                    LogisticRegressionModel model = (LogisticRegressionModel)tuple2._1();
                    double accuracy = tuple2._2$mcD$sp();
                    Tuple2 tuple23 = tuple22 = new Tuple2((Object)model, (Object)BoxesRunTime.boxToDouble((double)accuracy));
                    LogisticRegressionModel model2 = (LogisticRegressionModel)tuple23._1();
                    double accuracy2 = tuple23._2$mcD$sp();
                    this.sumOfAccuracy$1.elem += accuracy2;
                    return;
                }
                throw new MatchError((Object)tuple2);
            }
            {
                this.data$1 = data$1;
                this.numClasses$1 = numClasses$1;
                this.sumOfAccuracy$1 = sumOfAccuracy$1;
            }
        });
        long t2 = System.nanoTime();
        Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"Average Accuracy: ").append((Object)BoxesRunTime.boxToDouble((double)(sumOfAccuracy.elem / (double)maxIteration))).toString());
        // t0..t1 covers cache() plus count(); t1..t2 covers the whole seed loop.
        this.printTime("Feature Computation/Read", t0, t1);
        this.printTime("Model learning/testing", t1, t2);
    }

    /** Returns {@code 2}; presumably the default for parameter 2 ({@code numClasses}) of {@code predictLogisticRegressionMLLIB}. */
    public int predictLogisticRegressionMLLIB$default$2() {
        return 2;
    }

    /** Returns {@code 5}; presumably the default for parameter 3 ({@code maxIteration}) of {@code predictLogisticRegressionMLLIB}. */
    public int predictLogisticRegressionMLLIB$default$3() {
        return 5;
    }

    /**
     * Prints {@code "<title>: <elapsed> s"}, where elapsed is {@code (t1 - t0) / 1.0E9}.
     *
     * <p>The callers in this class pass {@code System.nanoTime()} readings, so the
     * printed value is in seconds for them.
     *
     * @param title label printed before the elapsed value
     * @param t0 start reading
     * @param t1 end reading
     */
    public void printTime(String title, long t0, long t1) {
        Predef$.MODULE$.println((Object)new StringBuilder().append((Object)title).append((Object)": ").append((Object)BoxesRunTime.boxToDouble((double)((double)(t1 - t0) / 1.0E9))).append((Object)" s").toString());
    }

    /**
     * Trains a logistic-regression model on a random 90% split and scores it on the remaining 10%.
     *
     * <p>The data is split with {@code randomSplit} using weights 0.9 / 0.1 and the
     * given seed. The first split is cached and used to train a
     * {@code LogisticRegressionWithLBFGS} model; the second split is mapped to
     * (label, prediction) pairs that feed {@code MulticlassMetrics}.
     *
     * <p>NOTE(review): the cached training split is never unpersisted in this
     * method, and the method is invoked once per seed by
     * {@code predictLogisticRegressionMLLIB}.
     *
     * @param data labeled points to split
     * @param seed seed for the random split
     * @param numClasses$1 number of classes configured on the trainer
     * @return pair of (trained model, boxed accuracy on the validation split)
     */
    public final Tuple2 net$sansa_stack$ml$spark$kernel$RDFFastTreeGraphKernelUtil$$trainAndValidate$1(RDD data, long seed, int numClasses$1) {
        RDD[] splits = data.randomSplit(new double[]{0.9, 0.1}, seed);
        RDD training2 = splits[0].cache();
        RDD validation = splits[1];
        LogisticRegressionModel model = new LogisticRegressionWithLBFGS().setNumClasses(numClasses$1).run(training2);
        RDD predictions = validation.map((Function1)new Serializable(model){
            public static final long serialVersionUID = 0L;
            private final LogisticRegressionModel model$1;

            // Pair order is (actual label, predicted label).
            public final Tuple2<Object, Object> apply(LabeledPoint point) {
                double prediction = this.model$1.predict(point.features());
                return new Tuple2.mcDD.sp(point.label(), prediction);
            }
            {
                this.model$1 = model$1;
            }
        }, ClassTag$.MODULE$.apply(Tuple2.class));
        MulticlassMetrics metrics = new MulticlassMetrics(predictions);
        double accuracy = metrics.accuracy();
        return new Tuple2((Object)model, (Object)BoxesRunTime.boxToDouble((double)accuracy));
    }

    /** Private so the only instance is the one created by the static initializer; publishes itself as {@code MODULE$}. */
    private RDFFastTreeGraphKernelUtil$() {
        MODULE$ = this;
    }
}

