/*
 * Decompiled with CFR 0.152.
 */
package net.sansa_stack.spark.rdd.op.rdf;

import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import net.sansa_stack.spark.rdd.op.rdf.JavaRddOps;
import org.aksw.commons.lambda.serializable.SerializableFunction;
import org.aksw.commons.util.function.TriConsumer;
import org.aksw.commons.util.stream.StreamFunction;
import org.aksw.jenax.stmt.core.SparqlStmt;
import org.aksw.jenax.stmt.core.SparqlStmtQuery;
import org.aksw.jenax.stmt.util.SparqlStmtUtils;
import org.apache.jena.atlas.iterator.Iter;
import org.apache.jena.graph.Node;
import org.apache.jena.graph.Triple;
import org.apache.jena.query.ARQ;
import org.apache.jena.query.Dataset;
import org.apache.jena.query.DatasetFactory;
import org.apache.jena.query.Query;
import org.apache.jena.sparql.ARQConstants;
import org.apache.jena.sparql.algebra.Algebra;
import org.apache.jena.sparql.algebra.Op;
import org.apache.jena.sparql.algebra.Transform;
import org.apache.jena.sparql.algebra.Transformer;
import org.apache.jena.sparql.algebra.optimize.TransformExtendCombine;
import org.apache.jena.sparql.algebra.optimize.TransformJoinStrategy;
import org.apache.jena.sparql.algebra.optimize.TransformPropertyFunction;
import org.apache.jena.sparql.core.DatasetGraph;
import org.apache.jena.sparql.core.DatasetGraphFactory;
import org.apache.jena.sparql.core.Quad;
import org.apache.jena.sparql.core.Transactional;
import org.apache.jena.sparql.core.Var;
import org.apache.jena.sparql.engine.ExecutionContext;
import org.apache.jena.sparql.engine.QueryIterator;
import org.apache.jena.sparql.engine.binding.Binding;
import org.apache.jena.sparql.engine.binding.BindingFactory;
import org.apache.jena.sparql.engine.main.OpExecutorFactory;
import org.apache.jena.sparql.engine.main.QC;
import org.apache.jena.sparql.exec.UpdateExec;
import org.apache.jena.sparql.expr.NodeValue;
import org.apache.jena.sparql.modify.TemplateLib;
import org.apache.jena.sparql.syntax.Template;
import org.apache.jena.sparql.util.Context;
import org.apache.jena.sparql.util.NodeFactoryExtra;
import org.apache.jena.system.Txn;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;

/**
 * Operations that turn an RDD of SPARQL {@link Binding}s into RDF data by evaluating
 * "tarql-style" statements (CONSTRUCT queries and SPARQL updates) once per input binding.
 *
 * <p>Each input binding is used as the initial binding of every statement. If any statement
 * mentions the variable {@link #ROWNUM}, the input bindings are first extended with a 1-based
 * row number bound to that variable.
 */
public class JavaRddOfBindingsOps {
    /** Variable that, when mentioned by a statement, is bound to the 1-based index of the input binding. */
    public static final Var ROWNUM = Var.alloc("ROWNUM");

    /**
     * Compiles a CONSTRUCT query into a function that evaluates the query pattern starting from a
     * given input binding and maps every resulting solution through the query's template.
     *
     * @param query a CONSTRUCT query
     * @param templateMapperFactory creates, from the query's construct template, the function that
     *     turns one solution binding into the produced tuples (e.g. {@link #templateMapperTriples})
     * @param <T> the produced tuple type (e.g. {@link Triple} or {@link Quad})
     * @return a function {@code (inputBinding, execCxt) -> stream of produced tuples}
     * @throws IllegalArgumentException if {@code query} is not a CONSTRUCT query
     */
    public static <T> BiFunction<Binding, ExecutionContext, Stream<T>> compileNodeTupleMapper(
            Query query,
            java.util.function.Function<Template, java.util.function.Function<Binding, Stream<T>>> templateMapperFactory) {
        Preconditions.checkArgument(query.isConstructType(), "Construct query expected");
        Template template = query.getConstructTemplate();

        // Compile and optimize the algebra once; the returned function only executes it.
        Op op = Algebra.compile(query);
        op = TransformPropertyFunction.transform(op, ARQ.getContext());
        op = Transformer.transform(new TransformJoinStrategy(), op);
        Op finalOp = tarqlOptimize(op);

        java.util.function.Function<Binding, Stream<T>> templateMapper = templateMapperFactory.apply(template);
        return (binding, execCxt) -> {
            QueryIterator solutions = QC.execute(finalOp, binding, execCxt);
            return Iter.asStream(solutions).flatMap(templateMapper);
        };
    }

    /**
     * Fixes the execution context argument of a two-argument function.
     *
     * @param execCxt the execution context passed on every call
     * @param fn the function to adapt
     * @return a one-argument function {@code in -> fn.apply(in, execCxt)}
     */
    public static <I, O> java.util.function.Function<I, O> bindToExecCxt(ExecutionContext execCxt, BiFunction<I, ExecutionContext, O> fn) {
        return bindSecondArgument(execCxt, fn);
    }

    /**
     * Fixes the second argument of a two-argument function.
     *
     * @param arg2 the value passed as second argument on every call
     * @param fn the function to adapt
     * @return a one-argument function {@code in -> fn.apply(in, arg2)}
     */
    public static <A2, I, O> java.util.function.Function<I, O> bindSecondArgument(A2 arg2, BiFunction<I, A2, O> fn) {
        return in -> fn.apply(in, arg2);
    }

    /**
     * Creates a function that instantiates the triples of {@code template} for a single binding.
     *
     * @param template a construct template
     * @return a function mapping one binding to the stream of instantiated triples
     */
    public static java.util.function.Function<Binding, Stream<Triple>> templateMapperTriples(Template template) {
        List<Triple> triples = template.getTriples();
        return binding -> Iter.asStream(TemplateLib.calcTriples(triples, Collections.singleton(binding).iterator()));
    }

    /**
     * Creates a function that instantiates the quads of {@code template} for a single binding.
     *
     * @param template a construct template
     * @return a function mapping one binding to the stream of instantiated quads
     */
    public static java.util.function.Function<Binding, Stream<Quad>> templateMapperQuads(Template template) {
        List<Quad> quads = template.getQuads();
        return binding -> Iter.asStream(TemplateLib.calcQuads(quads, Collections.singleton(binding).iterator()));
    }

    /**
     * Like {@link #compileTarqlMapperGeneral(Collection, boolean)}, but streams all quads of the
     * resulting dataset.
     *
     * @param stmts the statements to apply in order
     * @param constructMode passed on as the {@code accumulationMode} flag
     * @return a function mapping one binding to the stream of produced quads
     */
    public static java.util.function.Function<Binding, Stream<Quad>> compileTarqlMapper(List<SparqlStmt> stmts, boolean constructMode) {
        java.util.function.Function<Binding, DatasetGraph> base = compileTarqlMapperGeneral(stmts, constructMode);
        return binding -> Iter.asStream(base.apply(binding).find());
    }

    /**
     * Compiles a sequence of statements into a function that evaluates all of them for one binding.
     *
     * <p>For each binding, a fresh, initially empty input dataset and a fresh output dataset are
     * created. Statements run in order:
     * <ul>
     *   <li>updates are executed against the input dataset, with the binding substituted in;</li>
     *   <li>CONSTRUCT queries are evaluated over the input dataset and their quads are added to
     *       the output dataset.</li>
     * </ul>
     * In accumulation mode, the output produced so far is copied into the input dataset after each
     * statement, so that later statements can query it, and the input dataset is returned.
     * Otherwise the output dataset is returned.
     *
     * @param stmts the statements to apply in order
     * @param accumulationMode whether results of earlier statements feed into later ones
     * @return a function mapping one binding to the resulting dataset
     * @throws UnsupportedOperationException if a statement is a SELECT or ASK query
     * @throws IllegalStateException if a statement is a query of an unknown type
     */
    public static java.util.function.Function<Binding, DatasetGraph> compileTarqlMapperGeneral(Collection<SparqlStmt> stmts, boolean accumulationMode) {
        List<TriConsumer<Binding, ExecutionContext, DatasetGraph>> actions = stmts.stream()
                .map(JavaRddOfBindingsOps::compileAction)
                .collect(Collectors.toList());

        // The current time is fixed when the mapper is compiled,
        // so every binding processed by this mapper sees the same value.
        Context context = ARQ.getContext().copy();
        OpExecutorFactory opExecutorFactory = QC.getFactory(context);
        context.set(ARQConstants.sysCurrentTime, NodeFactoryExtra.nowAsDateTime());

        return binding -> {
            DatasetGraph output = DatasetGraphFactory.createGeneral();
            DatasetGraph inputDs = DatasetGraphFactory.createGeneral();
            ExecutionContext execCxt = new ExecutionContext(context, inputDs.getDefaultGraph(), inputDs, opExecutorFactory);
            Txn.executeWrite(output, () -> Txn.executeWrite(inputDs, () -> {
                for (TriConsumer<Binding, ExecutionContext, DatasetGraph> action : actions) {
                    action.accept(binding, execCxt, output);
                    if (accumulationMode) {
                        inputDs.addAll(output);
                    }
                }
            }));
            return accumulationMode ? inputDs : output;
        };
    }

    /**
     * Compiles one statement into an action {@code (binding, execCxt, outputDataset)}.
     * Only updates and CONSTRUCT queries are supported.
     */
    private static TriConsumer<Binding, ExecutionContext, DatasetGraph> compileAction(SparqlStmt stmt) {
        if (!stmt.isQuery()) {
            // Updates modify the execution context's dataset, not the output dataset.
            return (binding, execCxt, outDs) -> UpdateExec.dataset(execCxt.getDataset())
                    .substitution(binding)
                    .update(stmt.getUpdateRequest())
                    .execute();
        }
        Query query = stmt.getQuery();
        if (query.isConstructType()) {
            BiFunction<Binding, ExecutionContext, Stream<Quad>> mapper =
                    compileNodeTupleMapper(query, JavaRddOfBindingsOps::templateMapperQuads);
            return (binding, execCxt, outDs) -> {
                // Close the stream so the underlying query iterator is released.
                try (Stream<Quad> quads = mapper.apply(binding, execCxt)) {
                    quads.forEach(outDs::add);
                }
            };
        }
        if (query.isSelectType() || query.isAskType()) {
            throw new UnsupportedOperationException("Only CONSTRUCT queries and updates are supported: " + query);
        }
        throw new IllegalStateException("Unknown query type: " + query);
    }

    /**
     * Returns whether any of the statements may produce quads.
     *
     * @param stmts the statements to check
     * @return {@code true} if {@link #mayProduceQuads(SparqlStmt)} holds for at least one statement
     */
    public static boolean mayProduceQuads(Collection<SparqlStmt> stmts) {
        return stmts.stream().anyMatch(JavaRddOfBindingsOps::mayProduceQuads);
    }

    /**
     * Returns whether a statement may produce quads, i.e. data outside the default graph.
     *
     * <p>Only a CONSTRUCT query with a triples-only template is known not to produce quads. Unparsed
     * statements, updates and non-CONSTRUCT queries are conservatively reported as possibly
     * producing quads.
     *
     * @param stmt the statement to check
     * @return {@code false} only for CONSTRUCT queries whose template contains no quads
     */
    public static boolean mayProduceQuads(SparqlStmt stmt) {
        if (!stmt.isParsed() || !stmt.isQuery()) {
            return true;
        }
        Query query = stmt.getQuery();
        return !query.isConstructType() || query.isConstructQuad();
    }

    /**
     * Evaluates a single CONSTRUCT query for every binding and wraps each result dataset.
     *
     * @param rdd the input bindings
     * @param query a CONSTRUCT query
     * @return one dataset per input binding
     */
    public static JavaRDD<Dataset> tarqlDatasets(JavaRDD<Binding> rdd, Query query) {
        return tarqlDatasets(rdd, Collections.<SparqlStmt>singleton(new SparqlStmtQuery(query)), false,
                dsg -> Stream.of(DatasetFactory.wrap(dsg)));
    }

    /**
     * Evaluates the statements for every binding (see {@link #compileTarqlMapperGeneral}) and
     * converts each resulting dataset with {@code finisher}.
     *
     * @param rdd the input bindings; extended with {@link #ROWNUM} if any statement mentions it
     * @param stmts the statements to apply in order
     * @param accumulationMode whether results of earlier statements feed into later ones
     * @param finisher converts the dataset produced for one binding into output items
     * @return the concatenation of all finisher outputs
     */
    public static <T> JavaRDD<T> tarqlDatasets(JavaRDD<Binding> rdd, Collection<SparqlStmt> stmts, boolean accumulationMode, SerializableFunction<DatasetGraph, Stream<T>> finisher) {
        JavaRDD<Binding> input = withRowNumIfNeeded(rdd, stmts);
        return JavaRddOps.mapPartitions(input, (StreamFunction<Binding, T> & Serializable) upstream -> {
            // Compiled inside the partition function, presumably so the (non-serializable)
            // compiled mapper never has to be shipped.
            java.util.function.Function<Binding, DatasetGraph> mapper = compileTarqlMapperGeneral(stmts, accumulationMode);
            return upstream.map(mapper).flatMap(dg -> finisher.apply(dg));
        });
    }

    /**
     * Evaluates the statements for every binding and returns all produced triples.
     * Graph names of quads are dropped on the general (non fast-track) path.
     *
     * @param rdd the input bindings
     * @param stmts the statements to apply in order
     * @param accumulationMode whether results of earlier statements feed into later ones
     * @param execCxtSupplier supplies the execution context for the fast-track path; it is
     *     captured in the Spark closure (NOTE(review): presumably must be serializable)
     * @return the produced triples
     */
    public static JavaRDD<Triple> tarqlTriples(JavaRDD<Binding> rdd, Collection<SparqlStmt> stmts, boolean accumulationMode, Supplier<ExecutionContext> execCxtSupplier) {
        boolean fastTrack = canUseFastTrack(stmts, accumulationMode);
        JavaRDD<Binding> input = withRowNumIfNeeded(rdd, stmts);
        if (fastTrack) {
            return JavaRddOps.mapPartitions(input, (StreamFunction<Binding, Triple> & Serializable) bindings -> {
                List<Query> queries = stmts.stream().map(SparqlStmt::getQuery).collect(Collectors.toList());
                return tripleMapper(queries, execCxtSupplier).apply(bindings);
            });
        }
        return tarqlDatasets(input, stmts, accumulationMode, dg -> Iter.asStream(dg.find()).map(Quad::asTriple));
    }

    /**
     * Evaluates a single CONSTRUCT query for every binding and returns all produced quads.
     *
     * @param rdd the input bindings
     * @param query a CONSTRUCT query
     * @param execCxtSupplier supplies the execution context used for evaluation
     * @return the produced quads
     */
    public static JavaRDD<Quad> tarqlQuads(JavaRDD<Binding> rdd, Query query, Supplier<ExecutionContext> execCxtSupplier) {
        return tarqlQuads(rdd, Collections.<SparqlStmt>singleton(new SparqlStmtQuery(query)), false, execCxtSupplier);
    }

    /**
     * Evaluates the statements for every binding and returns all produced quads.
     *
     * @param rdd the input bindings
     * @param stmts the statements to apply in order
     * @param accumulationMode whether results of earlier statements feed into later ones
     * @param execCxtSupplier supplies the execution context for the fast-track path; it is
     *     captured in the Spark closure (NOTE(review): presumably must be serializable)
     * @return the produced quads
     */
    public static JavaRDD<Quad> tarqlQuads(JavaRDD<Binding> rdd, Collection<SparqlStmt> stmts, boolean accumulationMode, Supplier<ExecutionContext> execCxtSupplier) {
        boolean fastTrack = canUseFastTrack(stmts, accumulationMode);
        JavaRDD<Binding> input = withRowNumIfNeeded(rdd, stmts);
        if (fastTrack) {
            return JavaRddOps.mapPartitions(input, (StreamFunction<Binding, Quad> & Serializable) bindings -> {
                List<Query> queries = stmts.stream().map(SparqlStmt::getQuery).collect(Collectors.toList());
                return quadMapper(queries, execCxtSupplier).apply(bindings);
            });
        }
        return tarqlDatasets(input, stmts, accumulationMode, dg -> Iter.asStream(dg.find()));
    }

    /**
     * Whether statements can be streamed directly without building an intermediate dataset.
     * This holds when all statements are queries and no earlier result can feed into a later
     * statement (no accumulation, or fewer than two statements).
     */
    private static boolean canUseFastTrack(Collection<SparqlStmt> stmts, boolean accumulationMode) {
        boolean allQueries = stmts.stream().allMatch(SparqlStmt::isQuery);
        return allQueries && (!accumulationMode || stmts.size() < 2);
    }

    /** Adds {@link #ROWNUM} to the bindings only if some statement mentions it. */
    private static JavaRDD<Binding> withRowNumIfNeeded(JavaRDD<Binding> rdd, Collection<SparqlStmt> stmts) {
        return mentionesRowNum(stmts) ? enrichRddWithRowNum(rdd) : rdd;
    }

    /**
     * Creates a stream function that maps each binding to the triples produced by all queries.
     * The execution context is obtained once from {@code execCxtSupplier} and shared.
     *
     * @param queries CONSTRUCT queries
     * @param execCxtSupplier supplies the execution context
     * @return a function mapping a stream of bindings to the stream of produced triples
     */
    public static StreamFunction<Binding, Triple> tripleMapper(Collection<Query> queries, Supplier<ExecutionContext> execCxtSupplier) {
        return compileStreamMapper(queries, execCxtSupplier, JavaRddOfBindingsOps::templateMapperTriples);
    }

    /**
     * Creates a stream function that maps each binding to the quads produced by all queries.
     * The execution context is obtained once from {@code execCxtSupplier} and shared.
     *
     * @param queries CONSTRUCT queries
     * @param execCxtSupplier supplies the execution context
     * @return a function mapping a stream of bindings to the stream of produced quads
     */
    public static StreamFunction<Binding, Quad> quadMapper(Collection<Query> queries, Supplier<ExecutionContext> execCxtSupplier) {
        return compileStreamMapper(queries, execCxtSupplier, JavaRddOfBindingsOps::templateMapperQuads);
    }

    /** Shared implementation of {@link #tripleMapper} and {@link #quadMapper}. */
    private static <T> StreamFunction<Binding, T> compileStreamMapper(
            Collection<Query> queries,
            Supplier<ExecutionContext> execCxtSupplier,
            java.util.function.Function<Template, java.util.function.Function<Binding, Stream<T>>> templateMapperFactory) {
        ExecutionContext execCxt = execCxtSupplier.get();
        List<java.util.function.Function<Binding, Stream<T>>> mappers = queries.stream()
                .map(query -> bindToExecCxt(execCxt, compileNodeTupleMapper(query, templateMapperFactory)))
                .collect(Collectors.toList());
        return (StreamFunction<Binding, T> & Serializable) upstream ->
                upstream.flatMap(binding -> mappers.stream().flatMap(mapper -> mapper.apply(binding)));
    }

    /**
     * Applies the tarql-specific algebra optimization ({@link TransformExtendCombine}).
     *
     * @param op the algebra expression
     * @return the transformed expression
     */
    public static Op tarqlOptimize(Op op) {
        return Transformer.transform(new TransformExtendCombine(), op);
    }

    /**
     * Returns whether the statement mentions {@link #ROWNUM}.
     *
     * @param sparqlStmt the statement to inspect
     * @return {@code true} if {@link #ROWNUM} is among the statement's mentioned nodes
     */
    public static boolean mentionesRowNum(SparqlStmt sparqlStmt) {
        return SparqlStmtUtils.mentionedNodes(sparqlStmt).contains(ROWNUM);
    }

    /**
     * Returns whether any of the statements mentions {@link #ROWNUM}.
     *
     * @param sparqlStmts the statements to inspect
     * @return {@code true} if at least one statement mentions {@link #ROWNUM}
     */
    public static boolean mentionesRowNum(Collection<SparqlStmt> sparqlStmts) {
        return sparqlStmts.stream().anyMatch(JavaRddOfBindingsOps::mentionesRowNum);
    }

    /**
     * Extends every binding with {@link #ROWNUM} bound to its 1-based position in the RDD
     * (as an integer literal).
     *
     * @param rdd the input bindings
     * @return the extended bindings
     */
    public static JavaRDD<Binding> enrichRddWithRowNum(JavaRDD<Binding> rdd) {
        return rdd.zipWithIndex().map(bindingAndIndex -> BindingFactory.binding(
                bindingAndIndex._1(),
                ROWNUM,
                NodeValue.makeInteger(bindingAndIndex._2() + 1L).asNode()));
    }
}

