/*
 * Decompiled with CFR 0.152.
 */
package io.squashql.query.database;

import io.squashql.SparkDatastore;
import io.squashql.SparkUtil;
import io.squashql.query.database.AQueryEngine;
import io.squashql.query.database.DatabaseQuery;
import io.squashql.query.database.DefaultQueryRewriter;
import io.squashql.query.database.QueryRewriter;
import io.squashql.store.Datastore;
import io.squashql.table.ColumnarTable;
import io.squashql.table.RowTable;
import io.squashql.table.Table;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.sql.Dataset;
import org.eclipse.collections.api.tuple.Pair;

/**
 * Query engine that runs SQL through the {@code spark} handle of a {@link SparkDatastore}
 * and wraps the outcome in a {@link Table}.
 */
public class SparkQueryEngine extends AQueryEngine<SparkDatastore> {
    /** Aggregation function names reported by {@link #supportedAggregationFunctions()}. */
    public static final List<String> SUPPORTED_AGGREGATION_FUNCTIONS = List.of(
            "any",
            "avg",
            "corr",
            "count",
            "covar_pop",
            "covar_samp",
            "min",
            "max",
            "stddev_pop",
            "stddev_samp",
            "sum",
            "var_pop",
            "var_samp",
            "variance");

    /**
     * Builds an engine on top of the given datastore, with
     * {@code DefaultQueryRewriter.INSTANCE} as its query rewriter.
     *
     * @param datastore the Spark-backed datastore handed to the parent engine
     */
    public SparkQueryEngine(SparkDatastore datastore) {
        super((Datastore) datastore, (QueryRewriter) DefaultQueryRewriter.INSTANCE);
    }

    /**
     * Runs {@code sql} and returns the outcome as a {@link ColumnarTable}.
     *
     * <p>The conversion itself is delegated to {@code transformToColumnFormat}; the table is
     * then assembled from both halves of the pair it returns plus a set copy of
     * {@code query.measures}.
     *
     * @param query the database query whose measures are attached to the resulting table
     * @param sql   the SQL text to execute
     * @return the columnar table built from the execution result
     */
    protected Table retrieveAggregates(DatabaseQuery query, String sql) {
        Dataset dataset = this.runSql(sql);
        // Same evaluation order as passing these inline: schema fields first, then the iterator.
        var schemaFields = Arrays.stream(dataset.schema().fields()).toList();
        Iterator rowIterator = (Iterator) dataset.toLocalIterator();
        Pair columnar = SparkQueryEngine.transformToColumnFormat(
                query,
                schemaFields,
                (field, label) -> label,
                (field, label) -> SparkUtil.datatypeToClass(field.dataType()),
                rowIterator,
                (index, row) -> row.get(index.intValue()),
                (QueryRewriter) this.queryRewriter);
        return new ColumnarTable(
                (List) columnar.getOne(),
                new HashSet(query.measures),
                (List) columnar.getTwo());
    }

    /**
     * Runs {@code sql} as-is and returns the outcome as a {@link RowTable}.
     *
     * <p>The conversion is delegated to {@code transformToRowFormat}, using each schema
     * field's name and the class that {@code SparkUtil.datatypeToClass} maps its data type to.
     *
     * @param sql the SQL text to execute
     * @return the row table built from both halves of the pair returned by the conversion
     */
    public Table executeRawSql(String sql) {
        Dataset dataset = this.runSql(sql);
        var schemaFields = Arrays.stream(dataset.schema().fields()).toList();
        Iterator rowIterator = (Iterator) dataset.toLocalIterator();
        Pair rowWise = SparkQueryEngine.transformToRowFormat(
                schemaFields,
                field -> field.name(),
                field -> SparkUtil.datatypeToClass(field.dataType()),
                rowIterator,
                (index, row) -> row.get(index.intValue()));
        return new RowTable((List) rowWise.getOne(), (List) rowWise.getTwo());
    }

    /**
     * @return {@link #SUPPORTED_AGGREGATION_FUNCTIONS}
     */
    public List<String> supportedAggregationFunctions() {
        return SUPPORTED_AGGREGATION_FUNCTIONS;
    }

    /** Submits {@code sql} to the datastore's {@code spark} handle and returns the dataset. */
    private Dataset runSql(String sql) {
        return ((SparkDatastore) this.datastore).spark.sql(sql);
    }
}

