package io.trino.verifier;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.inject.Injector;
import com.google.inject.Key;
import com.google.inject.Module;
import com.google.inject.TypeLiteral;
import io.airlift.bootstrap.Bootstrap;
import io.airlift.bootstrap.LifeCycleManager;
import io.airlift.event.client.EventClient;
import io.airlift.log.Logger;
import io.trino.sql.parser.ParsingOptions;
import io.trino.sql.parser.SqlParser;
import io.trino.sql.tree.AddColumn;
import io.trino.sql.tree.Comment;
import io.trino.sql.tree.CreateMaterializedView;
import io.trino.sql.tree.CreateTable;
import io.trino.sql.tree.CreateTableAsSelect;
import io.trino.sql.tree.CreateView;
import io.trino.sql.tree.Delete;
import io.trino.sql.tree.DropColumn;
import io.trino.sql.tree.DropMaterializedView;
import io.trino.sql.tree.DropTable;
import io.trino.sql.tree.DropView;
import io.trino.sql.tree.Explain;
import io.trino.sql.tree.ExplainAnalyze;
import io.trino.sql.tree.Insert;
import io.trino.sql.tree.RefreshMaterializedView;
import io.trino.sql.tree.RenameColumn;
import io.trino.sql.tree.RenameTable;
import io.trino.sql.tree.RenameView;
import io.trino.sql.tree.ShowCatalogs;
import io.trino.sql.tree.ShowColumns;
import io.trino.sql.tree.ShowFunctions;
import io.trino.sql.tree.ShowSchemas;
import io.trino.sql.tree.ShowSession;
import io.trino.sql.tree.ShowTables;
import io.trino.sql.tree.Statement;
import io.trino.verifier.QueryRewriter;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.file.Paths;
import java.sql.Driver;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.jdbi.v3.core.ConnectionFactory;
import org.jdbi.v3.core.Jdbi;
import org.jdbi.v3.sqlobject.SqlObjectPlugin;
import picocli.CommandLine;

/**
 * Command-line entry point of the verifier.
 *
 * <p>The single positional argument names the configuration file. The command bootstraps the
 * verifier modules, loads the query pairs of every configured suite from the query database,
 * applies the configured catalog/schema/credential overrides, drops pairs whose statement types
 * are not allowed, optionally rewrites the remaining pairs for write shadowing, and finally runs
 * them through {@code Verifier}. The process exits with status 1 when {@code Verifier.run}
 * returns a positive value and with status 0 otherwise.
 *
 * <p>Subclasses may customize a run through {@link #getAdditionalModules()},
 * {@link #getQueryDatabase(Injector)} and {@link #filterQueries(List)}.
 */
@CommandLine.Command(name = "verifier", usageHelpAutoWidth = true, versionProvider = VerifyCommand.VersionProvider.class)
public class VerifyCommand implements Runnable {
    private static final Logger LOG = Logger.get(VerifyCommand.class);

    /** Bound to {@code -h}/{@code --help}; declared to picocli as the usage-help option. */
    @CommandLine.Option(names = {"-h", "--help"}, usageHelp = true, description = {"Show this help message and exit"})
    public boolean usageHelpRequested;

    /** Bound to {@code --version}; declared to picocli as the version-help option. */
    @CommandLine.Option(names = {"--version"}, versionHelp = true, description = {"Print version information and exit"})
    public boolean versionInfoRequested;

    /** Path of the configuration file (first positional argument). */
    @CommandLine.Parameters(index = "0", paramLabel = "<file>", description = {"Configuration file"})
    public String configFilename;

    /**
     * Supplies the {@code --version} text: the command name followed by the implementation
     * version recorded in the package manifest, or {@code "(version unknown)"} when absent.
     */
    public static class VersionProvider implements CommandLine.IVersionProvider {
        @CommandLine.Spec
        public CommandLine.Model.CommandSpec spec;

        @Override
        public String[] getVersion() {
            String version = MoreObjects.firstNonNull(getClass().getPackage().getImplementationVersion(), "(version unknown)");
            return new String[] {spec.name() + " " + version};
        }
    }

    /**
     * Runs the verification and terminates the JVM via {@link System#exit(int)}.
     *
     * <p>The injector's {@link LifeCycleManager} is stopped before exiting, and also when the run
     * fails with an exception. {@code System.exit} never returns normally, so a stop placed after
     * it would never execute.
     *
     * @throws RuntimeException if bootstrapping or verification fails (checked failures are wrapped)
     */
    @Override
    public void run() {
        if (configFilename != null) {
            // The configuration file location is handed to Bootstrap through the "config" system
            // property (presumably read by airlift's configuration loader).
            System.setProperty("config", configFilename);
        }

        Injector injector = createInjector();
        int failures;
        try {
            failures = verify(injector);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (MalformedURLException e) {
            throw new RuntimeException(e);
        } finally {
            injector.getInstance(LifeCycleManager.class).stop();
        }
        System.exit(failures > 0 ? 1 : 0);
    }

    /** Bootstraps the verifier module plus any subclass-provided modules into an injector. */
    private Injector createInjector() {
        ImmutableList<Module> modules = ImmutableList.<Module>builder()
                .add(new PrestoVerifierModule())
                .addAll(getAdditionalModules())
                .build();
        try {
            return new Bootstrap(modules).initialize();
        } catch (Exception e) {
            Throwables.throwIfUnchecked(e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Loads, filters and (optionally) shadow-rewrites the configured query pairs, then runs them.
     *
     * @return the value of {@code Verifier.run}; {@link #run()} treats a positive value as failure
     */
    private int verify(Injector injector) throws InterruptedException, MalformedURLException {
        VerifierConfig config = injector.getInstance(VerifierConfig.class);
        injector.injectMembers(this);

        Set<String> supportedEventClients = injector.getInstance(Key.get(new TypeLiteral<Set<String>>() {}, SupportedEventClients.class));
        for (String eventClient : config.getEventClients()) {
            Preconditions.checkArgument(supportedEventClients.contains(eventClient), "Unsupported event client: %s", eventClient);
        }
        Set<EventClient> eventClients = injector.getInstance(Key.get(new TypeLiteral<Set<EventClient>>() {}));

        VerifierDao dao = Jdbi.create(getQueryDatabase(injector))
                .installPlugin(new SqlObjectPlugin())
                .onDemand(VerifierDao.class);
        ImmutableList.Builder<QueryPair> loaded = ImmutableList.builder();
        for (String suite : config.getSuites()) {
            loaded.addAll(dao.getQueriesBySuite(suite, config.getMaxQueries()));
        }

        SqlParser parser = new SqlParser();
        List<QueryPair> queries = filterQueries(filterQueryTypes(parser, config, applyOverrides(config, loaded.build())));

        // Register the extra drivers before shadow rewriting: QueryRewriter.shadowQuery can throw
        // SQLException, i.e. it presumably connects to the test/control gateways over JDBC, and
        // those connections may need one of these drivers.
        loadAdditionalJdbcDrivers(config);

        if (config.getShadowWrites()) {
            Set<QueryType> allowedQueryTypes = Sets.union(config.getTestQueryTypes(), config.getControlQueryTypes());
            Preconditions.checkArgument(
                    !Sets.intersection(allowedQueryTypes, ImmutableSet.of(QueryType.CREATE, QueryType.MODIFY)).isEmpty(),
                    "CREATE or MODIFY queries must be allowed in test or control to use write shadowing");
            queries = rewriteQueries(parser, config, queries);
        }

        return new Verifier(System.out, config, eventClients).run(queries);
    }

    /**
     * Registers the configured test and control JDBC drivers from the additional driver path,
     * if one is configured. Does nothing when no additional driver path is set.
     */
    private static void loadAdditionalJdbcDrivers(VerifierConfig config) throws MalformedURLException {
        if (config.getAdditionalJdbcDriverPath() == null) {
            return;
        }
        URL[] urls = getUrls(config.getAdditionalJdbcDriverPath()).toArray(new URL[0]);
        if (config.getTestJdbcDriverName() != null) {
            loadJdbcDriver(urls, config.getTestJdbcDriverName());
        }
        if (config.getControlJdbcDriverName() != null) {
            loadJdbcDriver(urls, config.getControlJdbcDriverName());
        }
    }

    /**
     * Instantiates {@code jdbcClassName} from a class loader over {@code urls} and registers it,
     * wrapped in {@code ForwardingDriver}, with {@link DriverManager}.
     *
     * @throws RuntimeException if the class cannot be loaded/instantiated or registration fails
     */
    private static void loadJdbcDriver(URL[] urls, String jdbcClassName) {
        // The loader is intentionally NOT closed: after URLClassLoader.close() no further classes can
        // be loaded through it, yet the registered driver may still load classes from its jar lazily
        // when connections are opened. It therefore lives as long as the registration (the process).
        URLClassLoader classLoader = new URLClassLoader(urls);
        try {
            Driver driver = (Driver) Class.forName(jdbcClassName, true, classLoader).getConstructor().newInstance();
            // DriverManager skips drivers whose class is not visible to the caller's class loader;
            // the ForwardingDriver wrapper presumably exists to work around that.
            DriverManager.registerDriver(new ForwardingDriver(driver));
        } catch (ReflectiveOperationException | SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Resolves the additional-driver path into class path URLs.
     *
     * <p>A path that is not a directory is returned as a single URL. For a directory, every direct
     * (non-recursive) child whose name ends in {@code .jar} and which is not itself a directory is
     * returned. An unlistable directory yields an empty list.
     */
    private static List<URL> getUrls(String path) throws MalformedURLException {
        File file = new File(path);
        if (!file.isDirectory()) {
            return ImmutableList.of(Paths.get(path).toUri().toURL());
        }
        ImmutableList.Builder<URL> urls = ImmutableList.builder();
        File[] jars = file.listFiles((directory, name) -> name.endsWith(".jar"));
        if (jars != null) {
            for (File jar : jars) {
                if (!jar.isDirectory()) {
                    urls.add(Paths.get(jar.getAbsolutePath()).toUri().toURL());
                }
            }
        }
        return urls.build();
    }

    /**
     * Returns the connection factory used to read query suites. Each call of the returned factory
     * opens a new {@link DriverManager} connection to {@code VerifierConfig.getQueryDatabase()}.
     * Subclasses may override this to use a different source.
     */
    protected ConnectionFactory getQueryDatabase(Injector injector) {
        VerifierConfig config = injector.getInstance(VerifierConfig.class);
        return () -> DriverManager.getConnection(config.getQueryDatabase());
    }

    /**
     * Hook for subclasses to further filter the query pairs after statement-type filtering.
     * The default returns {@code list} unchanged.
     */
    protected List<QueryPair> filterQueries(List<QueryPair> list) {
        return list;
    }

    /**
     * Rewrites every pair for write shadowing, in parallel on {@code getThreadCount()} threads.
     *
     * <p>Pairs whose test or control side cannot be rewritten are skipped (logged unless quiet).
     * The result is collected in completion order, so it need not match the input order.
     * Progress is logged at most about once per minute unless quiet.
     *
     * @throws RuntimeException if waiting for the rewrites is interrupted or a rewrite task fails
     *         unexpectedly; outstanding rewrites are cancelled in that case
     */
    @VisibleForTesting
    static List<QueryPair> rewriteQueries(SqlParser parser, VerifierConfig config, List<QueryPair> queries) {
        QueryRewriter testRewriter = new QueryRewriter(parser, config.getTestGateway(), config.getShadowTestTablePrefix(), Optional.ofNullable(config.getTestCatalogOverride()), Optional.ofNullable(config.getTestSchemaOverride()), Optional.ofNullable(config.getTestUsernameOverride()), Optional.ofNullable(config.getTestPasswordOverride()), config.getDoublePrecision(), config.getTestTimeout());
        QueryRewriter controlRewriter = new QueryRewriter(parser, config.getControlGateway(), config.getShadowControlTablePrefix(), Optional.ofNullable(config.getControlCatalogOverride()), Optional.ofNullable(config.getControlSchemaOverride()), Optional.ofNullable(config.getControlUsernameOverride()), Optional.ofNullable(config.getControlPasswordOverride()), config.getDoublePrecision(), config.getControlTimeout());
        boolean quiet = config.isQuiet();

        LOG.info("Rewriting %s queries using %s threads", queries.size(), config.getThreadCount());
        ExecutorService executor = Executors.newFixedThreadPool(config.getThreadCount());
        try {
            ExecutorCompletionService<Optional<QueryPair>> completionService = new ExecutorCompletionService<>(executor);
            for (QueryPair pair : queries) {
                completionService.submit(() -> shadowQueryPair(testRewriter, controlRewriter, quiet, pair));
            }
            executor.shutdown();

            List<QueryPair> rewritten = new ArrayList<>();
            Stopwatch progressTimer = Stopwatch.createStarted();
            for (int completed = 1; completed <= queries.size(); completed++) {
                completionService.take().get().ifPresent(rewritten::add);
                if (!quiet && progressTimer.elapsed(TimeUnit.MINUTES) > 0) {
                    progressTimer.reset().start();
                    // Floating-point division: integer division would report 0% until the very end.
                    double percentDone = ((double) completed / queries.size()) * 100.0d;
                    LOG.info("Rewrite progress: %s valid, %s skipped, %.2f%% done", rewritten.size(), completed - rewritten.size(), percentDone);
                }
            }
            LOG.info("Rewrote %s queries into %s queries", queries.size(), rewritten.size());
            return rewritten;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Query rewriting failed", e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Query rewriting failed", e);
        } finally {
            // Cancels outstanding rewrites when bailing out early; a no-op after normal completion.
            executor.shutdownNow();
        }
    }

    /**
     * Shadow-rewrites one pair (test side first, then control side), or returns empty when
     * either side fails to rewrite.
     */
    private static Optional<QueryPair> shadowQueryPair(QueryRewriter testRewriter, QueryRewriter controlRewriter, boolean quiet, QueryPair pair) {
        try {
            return Optional.of(new QueryPair(pair.getSuite(), pair.getName(), testRewriter.shadowQuery(pair.getTest()), controlRewriter.shadowQuery(pair.getControl())));
        } catch (QueryRewriter.QueryRewriteException | SQLException e) {
            if (!quiet) {
                LOG.warn(e, "Failed to rewrite %s for shadowing. Skipping.", pair.getName());
            }
            return Optional.empty();
        }
    }

    /** Keeps the pairs whose control and test statements are all of allowed query types. */
    private static List<QueryPair> filterQueryTypes(SqlParser parser, VerifierConfig config, List<QueryPair> queries) {
        return queries.stream()
                .filter(pair -> queryTypeAllowed(parser, config.getControlQueryTypes(), pair.getControl())
                        && queryTypeAllowed(parser, config.getTestQueryTypes(), pair.getTest()))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Returns whether every pre-query, the main query and every post-query of {@code query} maps to
     * a type in {@code allowedTypes}. Returns false if any statement cannot be classified.
     */
    private static boolean queryTypeAllowed(SqlParser parser, Set<QueryType> allowedTypes, Query query) {
        Set<QueryType> types = EnumSet.noneOf(QueryType.class);
        try {
            for (String sql : query.getPreQueries()) {
                types.add(statementToQueryType(parser, sql));
            }
            types.add(statementToQueryType(parser, query.getQuery()));
            for (String sql : query.getPostQueries()) {
                types.add(statementToQueryType(parser, sql));
            }
        } catch (UnsupportedOperationException e) {
            // A statement failed to parse or is of an unclassified kind.
            return false;
        }
        return allowedTypes.containsAll(types);
    }

    /**
     * Parses {@code sql} (decimal literals treated as DOUBLE) and classifies it.
     *
     * @throws UnsupportedOperationException if parsing fails or the statement kind is unclassified;
     *         the original failure is attached as the cause
     */
    public static QueryType statementToQueryType(SqlParser sqlParser, String str) {
        try {
            return statementToQueryType(sqlParser.createStatement(str, new ParsingOptions(ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE)));
        } catch (RuntimeException e) {
            throw new UnsupportedOperationException(e);
        }
    }

    /** Maps a parsed statement to READ, CREATE or MODIFY; checks run in a fixed order. */
    private static QueryType statementToQueryType(Statement statement) {
        if (statement instanceof AddColumn) {
            return QueryType.MODIFY;
        }
        if (statement instanceof CreateTable || statement instanceof CreateTableAsSelect) {
            return QueryType.CREATE;
        }
        if (statement instanceof CreateView) {
            // A replacing view definition counts as a modification.
            return ((CreateView) statement).isReplace() ? QueryType.MODIFY : QueryType.CREATE;
        }
        if (statement instanceof CreateMaterializedView) {
            return ((CreateMaterializedView) statement).isReplace() ? QueryType.MODIFY : QueryType.CREATE;
        }
        if (statement instanceof RefreshMaterializedView
                || statement instanceof DropMaterializedView
                || statement instanceof Delete
                || statement instanceof DropTable
                || statement instanceof DropView) {
            return QueryType.MODIFY;
        }
        if (statement instanceof Explain) {
            return QueryType.READ;
        }
        if (statement instanceof ExplainAnalyze) {
            // EXPLAIN ANALYZE is classified by the statement it wraps.
            return statementToQueryType(((ExplainAnalyze) statement).getStatement());
        }
        if (statement instanceof Insert) {
            return QueryType.MODIFY;
        }
        if (statement instanceof io.trino.sql.tree.Query) {
            return QueryType.READ;
        }
        if (statement instanceof RenameColumn
                || statement instanceof DropColumn
                || statement instanceof RenameTable
                || statement instanceof RenameView
                || statement instanceof Comment) {
            return QueryType.MODIFY;
        }
        if (statement instanceof ShowCatalogs
                || statement instanceof ShowColumns
                || statement instanceof ShowFunctions
                || statement instanceof ShowSchemas
                || statement instanceof ShowSession
                || statement instanceof ShowTables) {
            return QueryType.READ;
        }
        throw new UnsupportedOperationException("Unsupported statement type: " + statement.getClass().getSimpleName());
    }

    /** Hook for subclasses to contribute extra Guice modules; the default contributes none. */
    protected Iterable<Module> getAdditionalModules() {
        return ImmutableList.of();
    }

    /**
     * Applies the configured test/control catalog, schema, username and password overrides.
     * A null override keeps the query's own value.
     */
    private static List<QueryPair> applyOverrides(VerifierConfig config, List<QueryPair> queries) {
        return queries.stream()
                .map(pair -> new QueryPair(
                        pair.getSuite(),
                        pair.getName(),
                        withOverrides(pair.getTest(), config.getTestCatalogOverride(), config.getTestSchemaOverride(), config.getTestUsernameOverride(), config.getTestPasswordOverride()),
                        withOverrides(pair.getControl(), config.getControlCatalogOverride(), config.getControlSchemaOverride(), config.getControlUsernameOverride(), config.getControlPasswordOverride())))
                .collect(ImmutableList.toImmutableList());
    }

    /** Copies {@code query}, replacing connection properties whose override is non-null. */
    private static Query withOverrides(Query query, String catalog, String schema, String username, String password) {
        return new Query(
                overrideOrDefault(catalog, query.getCatalog()),
                overrideOrDefault(schema, query.getSchema()),
                query.getPreQueries(),
                query.getQuery(),
                query.getPostQueries(),
                overrideOrDefault(username, query.getUsername()),
                overrideOrDefault(password, query.getPassword()),
                query.getSessionProperties());
    }

    /** Returns {@code override} when non-null, otherwise {@code value} (which may be null). */
    private static String overrideOrDefault(String override, String value) {
        return override != null ? override : value;
    }
}
