package io.trino.plugin.redshift;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Streams;
import dev.failsafe.Failsafe;
import dev.failsafe.RetryPolicy;
import dev.failsafe.RetryPolicyBuilder;
import io.airlift.log.Logger;
import io.airlift.log.Logging;
import io.airlift.testing.Closeables;
import io.trino.Session;
import io.trino.plugin.tpch.TpchPlugin;
import io.trino.spi.security.Identity;
import io.trino.testing.DistributedQueryRunner;
import io.trino.testing.QueryAssertions;
import io.trino.testing.QueryRunner;
import io.trino.testing.TestingSession;
import io.trino.tpch.TpchTable;
import java.time.Duration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.assertj.core.api.Assertions;
import org.jdbi.v3.core.HandleCallback;
import org.jdbi.v3.core.HandleConsumer;
import org.jdbi.v3.core.Jdbi;

/* loaded from: input_file:io/trino/plugin/redshift/RedshiftQueryRunner.class */
/**
 * Test utility that builds a {@link DistributedQueryRunner} with a {@code tpch} catalog and a
 * {@code redshift} catalog, and makes sure the requested TPC-H tables exist in the Redshift
 * test schema.
 *
 * <p>Connection details are read from the {@code test.redshift.*} system properties when the
 * class is initialized; a missing property fails class initialization with a
 * {@link NullPointerException} naming the property.
 */
public final class RedshiftQueryRunner {
    private static final String TEST_DATABASE = "testdb";
    private static final String TEST_CATALOG = "redshift";
    static final String TEST_SCHEMA = "test_schema";
    private static final String CONNECTOR_NAME = "redshift";
    private static final String TPCH_CATALOG = "tpch";
    private static final String TPCH_SCHEMA = "tiny";
    private static final String GRANTED_USER = "alice";
    private static final String NON_GRANTED_USER = "bob";
    private static final Logger log = Logger.get(RedshiftQueryRunner.class);
    private static final String JDBC_ENDPOINT = requireSystemProperty("test.redshift.jdbc.endpoint");
    static final String JDBC_USER = requireSystemProperty("test.redshift.jdbc.user");
    static final String JDBC_PASSWORD = requireSystemProperty("test.redshift.jdbc.password");
    private static final String S3_TPCH_TABLES_ROOT = requireSystemProperty("test.redshift.s3.tpch.tables.root");
    private static final String IAM_ROLE = requireSystemProperty("test.redshift.iam.role");
    // The database name is appended directly, so the endpoint property presumably ends with '/'
    // (e.g. "host:5439/") -- confirm against the test environment configuration.
    static final String JDBC_URL = "jdbc:redshift://" + JDBC_ENDPOINT + TEST_DATABASE;

    private RedshiftQueryRunner() {
    }

    /**
     * Creates a query runner using the default session, which runs as the granted test user
     * against the Redshift catalog and test schema.
     *
     * @param extraProperties extra properties passed to the query runner builder
     * @param connectorProperties properties for the Redshift catalog; connection URL, user and
     *        password are filled in from the system properties when absent
     * @param tables TPC-H tables that must be available in the Redshift test schema
     * @return the configured query runner; the caller is responsible for closing it
     * @throws Exception if the query runner cannot be built
     */
    public static DistributedQueryRunner createRedshiftQueryRunner(Map<String, String> extraProperties, Map<String, String> connectorProperties, Iterable<TpchTable<?>> tables) throws Exception {
        return createRedshiftQueryRunner(createSession(), extraProperties, connectorProperties, tables);
    }

    /**
     * Creates a query runner for the given session, installs the TPC-H and Redshift plugins,
     * creates the test schema and test users in Redshift, provisions the requested tables and
     * grants the granted user access to them.
     *
     * @param session session used to build the runner and to provision tables
     * @param extraProperties extra properties passed to the query runner builder
     * @param connectorProperties properties for the Redshift catalog; connection URL, user and
     *        password are filled in from the system properties when absent
     * @param tables TPC-H tables that must be available in the Redshift test schema
     * @return the configured query runner; the caller is responsible for closing it
     * @throws Exception if the query runner cannot be built
     */
    public static DistributedQueryRunner createRedshiftQueryRunner(Session session, Map<String, String> extraProperties, Map<String, String> connectorProperties, Iterable<TpchTable<?>> tables) throws Exception {
        DistributedQueryRunner.Builder<?> builder = DistributedQueryRunner.builder(session);
        extraProperties.forEach(builder::addExtraProperty);
        DistributedQueryRunner runner = builder.build();
        try {
            runner.installPlugin(new TpchPlugin());
            runner.createCatalog(TPCH_CATALOG, TPCH_CATALOG, Map.of());

            // Copy so that the caller's map is never mutated; caller-supplied values win.
            Map<String, String> properties = new HashMap<>(connectorProperties);
            properties.putIfAbsent("connection-url", JDBC_URL);
            properties.putIfAbsent("connection-user", JDBC_USER);
            properties.putIfAbsent("connection-password", JDBC_PASSWORD);
            runner.installPlugin(new RedshiftPlugin());
            runner.createCatalog(TEST_CATALOG, CONNECTOR_NAME, properties);

            executeInRedshift("CREATE SCHEMA IF NOT EXISTS " + TEST_SCHEMA);
            createUserIfNotExists(NON_GRANTED_USER, JDBC_PASSWORD);
            createUserIfNotExists(GRANTED_USER, JDBC_PASSWORD);
            executeInRedshiftWithRetry(String.format("GRANT ALL PRIVILEGES ON DATABASE %s TO %s", TEST_DATABASE, GRANTED_USER));
            executeInRedshiftWithRetry(String.format("GRANT ALL PRIVILEGES ON SCHEMA %s TO %s", TEST_SCHEMA, GRANTED_USER));

            provisionTables(session, runner, tables);

            // Granted after provisioning so that freshly created tables are covered as well.
            executeInRedshiftWithRetry(String.format("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s TO %s", TEST_SCHEMA, GRANTED_USER));
            return runner;
        } catch (Throwable failure) {
            // Do not leak the runner when setup fails; close errors are attached as suppressed.
            Closeables.closeAllSuppress(failure, runner);
            throw failure;
        }
    }

    /** Returns a session that runs as the granted test user. */
    private static Session createSession() {
        return createSession(GRANTED_USER);
    }

    /** Returns a session for {@code user} pointing at the Redshift catalog and test schema. */
    private static Session createSession(String user) {
        return TestingSession.testSessionBuilder()
                .setCatalog(TEST_CATALOG)
                .setSchema(TEST_SCHEMA)
                .setIdentity(Identity.ofUser(user))
                .build();
    }

    /**
     * Creates a Redshift user, ignoring the failure reported when the user already exists.
     * Any other failure is rethrown unchanged.
     *
     * @param user name of the user to create
     * @param password password assigned to the user
     */
    private static void createUserIfNotExists(String user, String password) {
        try {
            executeInRedshift("CREATE USER " + user + " PASSWORD '" + password + "'");
        } catch (Exception e) {
            // Literal substring match: the user name is not interpreted as a regex, and a
            // missing message cannot hide the original failure behind a NullPointerException.
            String message = e.getMessage();
            boolean alreadyExists = message != null && message.contains("user \"" + user + "\" already exists");
            if (!alreadyExists) {
                throw e;
            }
        }
    }

    /**
     * Executes a statement in Redshift, retrying up to three times with a ten second delay
     * when the failure message mentions a concurrent transaction.
     *
     * @param sql statement to execute
     */
    private static void executeInRedshiftWithRetry(String sql) {
        RetryPolicy<Object> retryPolicy = RetryPolicy.builder()
                .handleIf(failure -> {
                    String message = failure.getMessage();
                    return message != null && message.matches(".* concurrent transaction .*");
                })
                .withDelay(Duration.ofSeconds(10))
                .withMaxRetries(3)
                .build();
        Failsafe.with(retryPolicy).run(() -> executeInRedshift(sql));
    }

    /**
     * Executes a single statement in Redshift over a new JDBC connection.
     *
     * @param sql statement to execute
     * @param parameters positional arguments bound to the statement
     */
    public static void executeInRedshift(String sql, Object... parameters) {
        executeInRedshift(handle -> handle.execute(sql, parameters));
    }

    /**
     * Runs {@code consumer} with a Jdbi handle connected to the Redshift test database.
     *
     * @param consumer work to perform with the handle
     * @param <E> exception type the consumer may throw
     * @throws E if the consumer throws
     */
    public static <E extends Exception> void executeInRedshift(HandleConsumer<E> consumer) throws E {
        executeWithRedshift(consumer.asCallback());
    }

    /**
     * Runs {@code callback} with a Jdbi handle connected to the Redshift test database and
     * returns its result.
     *
     * @param callback work to perform with the handle
     * @param <T> result type
     * @param <E> exception type the callback may throw
     * @return the value produced by the callback
     * @throws E if the callback throws
     */
    public static <T, E extends Exception> T executeWithRedshift(HandleCallback<T, E> callback) throws E {
        return Jdbi.create(JDBC_URL, JDBC_USER, JDBC_PASSWORD).withHandle(callback);
    }

    /**
     * Loads every requested table that is not yet listed in the session's catalog and schema,
     * then verifies all requested tables against the TPC-H catalog.
     */
    private static synchronized void provisionTables(Session session, QueryRunner queryRunner, Iterable<TpchTable<?>> tables) {
        String catalog = session.getCatalog().orElseThrow();
        String schema = session.getSchema().orElseThrow();
        Set<String> existingTables = queryRunner.listTables(session, catalog, schema).stream()
                .map(table -> table.getObjectName())
                .collect(Collectors.toUnmodifiableSet());

        Streams.stream(tables)
                .map(table -> table.getTableName().toLowerCase(Locale.ENGLISH))
                .filter(tableName -> !existingTables.contains(tableName))
                .forEach(tableName -> copyFromS3(queryRunner, session, tableName));

        for (TpchTable<?> table : tables) {
            verifyLoadedDataHasSameSchema(session, queryRunner, table);
        }
    }

    /**
     * Creates an empty table with the TPC-H column layout through Trino and fills it with a
     * Redshift {@code COPY} from the Parquet file stored under the configured S3 root.
     *
     * @param queryRunner runner used to create the empty table
     * @param session session whose catalog and schema receive the table
     * @param name lower-case TPC-H table name
     */
    public static void copyFromS3(QueryRunner queryRunner, Session session, String name) {
        String s3Path = String.format("%s/%s/%s.parquet", S3_TPCH_TABLES_ROOT, TPCH_CATALOG, name);
        log.info("Creating table %s in Redshift copying from %s", name, s3Path);

        String createTable = String.format("CREATE TABLE %s.%s.%s AS ", session.getCatalog().orElseThrow(), session.getSchema().orElseThrow(), name)
                + String.format("SELECT * FROM %s.%s.%s WITH NO DATA", TPCH_CATALOG, TPCH_SCHEMA, name);
        queryRunner.execute(session, createTable);

        executeInRedshiftWithRetry("COPY " + TEST_SCHEMA + "." + name + " FROM '" + s3Path + "' IAM_ROLE '" + IAM_ROLE + "' FORMAT PARQUET");
    }

    /** Copies a table from the TPC-H catalog through Trino instead of loading it from S3. */
    private static void copyFromTpchCatalog(QueryRunner queryRunner, Session session, String name) {
        QueryAssertions.copyTable(queryRunner, TPCH_CATALOG, TPCH_SCHEMA, name, session);
    }

    /**
     * Checks that the table in the session's catalog and schema has the same row count and
     * the same {@code DESCRIBE} output as its counterpart in the TPC-H catalog.
     *
     * @throws RuntimeException if the row counts differ or the comparison queries fail
     */
    private static void verifyLoadedDataHasSameSchema(Session session, QueryRunner queryRunner, TpchTable<?> tpchTable) {
        String tableName = tpchTable.getTableName();
        try {
            String expectedTable = String.format("%s.%s.%s", TPCH_CATALOG, TPCH_SCHEMA, tableName);
            String actualTable = String.format("%s.%s.%s", session.getCatalog().orElseThrow(), session.getSchema().orElseThrow(), tableName);

            long expectedCount = (long) queryRunner.execute("SELECT count(*) FROM " + expectedTable).getOnlyValue();
            long actualCount = (long) queryRunner.execute("SELECT count(*) FROM " + actualTable).getOnlyValue();
            if (expectedCount != actualCount) {
                throw new RuntimeException(String.format("Table %s is not loaded correctly. Expected %s rows got %s", tableName, expectedCount, actualCount));
            }

            log.info("Checking column types on table %s", tableName);
            // Fully qualified on both sides so the check does not depend on the query runner's
            // default session pointing at the same catalog and schema as the given session.
            Assertions.assertThat(queryRunner.execute("DESCRIBE " + actualTable))
                    .containsExactlyElementsOf(queryRunner.execute("DESCRIBE " + expectedTable));
        } catch (Exception e) {
            throw new RuntimeException("Failed to assert columns for TPC-H table " + tableName, e);
        }
    }

    /**
     * Returns the value of a system property.
     *
     * @throws NullPointerException if the property is not set; the message names the property
     */
    private static String requireSystemProperty(String property) {
        return Objects.requireNonNull(System.getProperty(property), property + " is not set");
    }

    /**
     * Starts a query runner with the HTTP server on port 8080 and logs the coordinator URL.
     * No TPC-H tables are provisioned.
     */
    public static void main(String[] args) throws Exception {
        Logging.initialize();
        DistributedQueryRunner queryRunner = createRedshiftQueryRunner(
                ImmutableMap.of("http-server.http.port", "8080"),
                ImmutableMap.of(),
                ImmutableList.of());
        log.info("======== SERVER STARTED ========");
        log.info("\n====\n%s\n====", queryRunner.getCoordinator().getBaseUrl());
    }
}
