/*
 * Decompiled with CFR 0.152.
 */
package io.squashql;

import ch.qos.logback.classic.Level;
import ch.qos.logback.classic.Logger;
import com.google.common.base.Suppliers;
import io.squashql.SparkUtil;
import io.squashql.store.Datastore;
import io.squashql.store.Store;
import io.squashql.store.TypedField;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalog.Catalog;
import org.apache.spark.sql.catalog.Column;
import org.apache.spark.sql.catalog.Table;
import org.apache.spark.sql.types.DataType;
import org.slf4j.LoggerFactory;

/**
 * {@link Datastore} implementation backed by a {@link SparkSession}.
 *
 * <p>The stores are discovered from the tables that the Spark catalog lists for the
 * {@code "default"} database. Discovery is deferred until {@link #storesByName()} (or
 * {@link #stores}) is first read and the result is then cached by {@link Suppliers#memoize}.
 *
 * <p>Loading this class sets the level of the root logger to {@code INFO}.
 */
public class SparkDatastore implements Datastore {
    /** Name of the Spark database whose tables are exposed as stores. */
    private static final String DEFAULT_DATABASE = "default";

    static {
        // The cast assumes logback is the logging implementation bound to SLF4J; with a
        // different binding it fails with a ClassCastException at class-load time.
        Logger root = (Logger) LoggerFactory.getLogger("ROOT");
        root.setLevel(Level.INFO);
    }

    /** Memoized supplier of the stores, keyed by table name. */
    public final Supplier<Map<String, Store>> stores;

    /** Spark session used for every catalog lookup and table read. */
    public final SparkSession spark;

    /**
     * Creates a datastore on a session obtained with {@code getOrCreate()}, configured with
     * {@code spark.master=local} and the driver bound to {@code 127.0.0.1}.
     */
    public SparkDatastore() {
        this(SparkSession.builder()
                .appName("Java Spark SQL Example")
                .config("spark.master", "local")
                .config("spark.driver.bindAddress", "127.0.0.1")
                .getOrCreate());
    }

    /**
     * Creates a datastore on top of the given session. The catalog is not queried here; the
     * stores are built the first time the {@link #stores} supplier is read.
     *
     * @param sparkSession the session to read the catalog and the tables from
     */
    public SparkDatastore(SparkSession sparkSession) {
        this.spark = sparkSession;
        this.stores = Suppliers.memoize(() -> {
            Map<String, Store> storesByTable = new HashMap<>();
            getTableNames(this.spark).forEach(
                    table -> storesByTable.put(table, new Store(table, getFields(this.spark, table))));
            return storesByTable;
        });
    }

    /**
     * Returns the stores keyed by table name.
     *
     * @return the memoized map produced by {@link #stores}
     */
    public Map<String, Store> storesByName() {
        return this.stores.get();
    }

    /**
     * Returns the Spark table registered under the given name.
     *
     * @param storeName the table name, passed as is to {@link SparkSession#table(String)}
     * @return the table as a dataset of rows
     */
    public Dataset<Row> get(String storeName) {
        return this.spark.table(storeName);
    }

    /**
     * Lists the names of the tables the catalog reports for the {@code "default"} database.
     *
     * @param spark the session whose catalog is queried
     * @return the distinct table names
     * @throws RuntimeException wrapping the {@link AnalysisException} raised by the catalog
     */
    public static Collection<String> getTableNames(SparkSession spark) {
        try {
            Dataset<Table> tables = spark.catalog().listTables(DEFAULT_DATABASE);
            Collection<String> tableNames = new HashSet<>();
            Iterator<Table> tableIterator = tables.toLocalIterator();
            while (tableIterator.hasNext()) {
                tableNames.add(tableIterator.next().name());
            }
            return tableNames;
        } catch (AnalysisException e) {
            throw new RuntimeException("Cannot list the tables of database '" + DEFAULT_DATABASE + "'", e);
        }
    }

    /**
     * Describes the columns of a table as {@link TypedField}s.
     *
     * <p>Each field is built from the table name, the column name and the class that
     * {@link SparkUtil#datatypeToClass} returns for the column's DDL type string.
     *
     * @param spark the session whose catalog is queried
     * @param tableName the name of the table to describe
     * @return one field per column, in the order the catalog returns them
     * @throws RuntimeException wrapping the {@link AnalysisException} raised by the catalog
     */
    public static List<TypedField> getFields(SparkSession spark, String tableName) {
        try {
            Catalog catalog = spark.catalog();
            Table table = catalog.getTable(tableName);
            // Temporary tables are looked up without a database name; all other tables are
            // looked up in the "default" database.
            Dataset<Column> columns = table.isTemporary()
                    ? catalog.listColumns(tableName)
                    : catalog.listColumns(DEFAULT_DATABASE, tableName);
            List<TypedField> fields = new ArrayList<>();
            Iterator<Column> columnIterator = columns.toLocalIterator();
            while (columnIterator.hasNext()) {
                Column column = columnIterator.next();
                Class<?> fieldType = SparkUtil.datatypeToClass(DataType.fromDDL(column.dataType()));
                fields.add(new TypedField(tableName, column.name(), fieldType));
            }
            return fields;
        } catch (AnalysisException e) {
            throw new RuntimeException("Cannot list the columns of table '" + tableName + "'", e);
        }
    }
}

