package io.trino.benchto.driver.macro.query;

import com.google.common.base.Preconditions;
import io.trino.benchto.driver.Benchmark;
import io.trino.benchto.driver.BenchmarkExecutionException;
import io.trino.benchto.driver.Query;
import io.trino.benchto.driver.loader.QueryLoader;
import io.trino.benchto.driver.loader.SqlStatementGenerator;
import io.trino.benchto.driver.macro.MacroExecutionDriver;
import io.trino.benchto.driver.utils.QueryUtils;
import io.trino.jdbc.TrinoConnection;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.sql.DataSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;

/**
 * Macro driver that runs the SQL statements generated from a {@code .sql} query file.
 *
 * <p>Statements are executed either on a connection supplied by the caller or on a
 * connection obtained from a {@link DataSource} bean looked up by name in the Spring
 * application context. {@code SET SESSION} statements issued against a connection that
 * wraps a {@link TrinoConnection} are not sent as SQL; they are applied through
 * {@link TrinoConnection#setSessionProperty(String, String)} instead.
 */
@Component
/* loaded from: input_file:BOOT-INF/lib/benchto-driver-0.26.jar:io/trino/benchto/driver/macro/query/QueryMacroExecutionDriver.class */
public class QueryMacroExecutionDriver implements MacroExecutionDriver {
    /** Lower-case prefix identifying a session-property statement. */
    private static final String SET_SESSION = "set session";
    private static final Logger LOGGER = LoggerFactory.getLogger(QueryMacroExecutionDriver.class);
    /** Matches {@code key=value} where the value may optionally be wrapped in single quotes. */
    private static final Pattern KEY_VALUE_PATTERN = Pattern.compile("([^=]+)='??([^']+)'??");

    @Autowired
    private ApplicationContext applicationContext;

    @Autowired
    private QueryLoader queryLoader;

    @Autowired
    private SqlStatementGenerator sqlStatementGenerator;

    /**
     * Reports whether this driver handles the given macro.
     *
     * @param str macro name
     * @return {@code true} when the name ends with {@code .sql}
     */
    @Override
    public boolean canExecuteBenchmarkMacro(String str) {
        return str.endsWith(".sql");
    }

    /**
     * Loads the query file named by {@code str}, generates its SQL statements using the
     * benchmark's non-reserved variables and executes them.
     *
     * <p>The caller-supplied connection is used only when it is present and the query does
     * not declare its own {@code datasource} property. Otherwise a connection is opened on
     * the query's {@code datasource} (falling back to the benchmark's data source) and is
     * closed again when execution finishes, whether it succeeded or failed.
     *
     * @param str       query file to load
     * @param optional  benchmark the macro belongs to; must be present
     * @param optional2 connection to reuse, if any; it is never closed by this method
     * @throws IllegalArgumentException    if {@code optional} is empty
     * @throws BenchmarkExecutionException if any statement fails with an {@link SQLException}
     */
    @Override
    public void runBenchmarkMacro(String str, Optional<Benchmark> optional, Optional<Connection> optional2) {
        Preconditions.checkArgument(optional.isPresent(), "Benchmark is required to run query based macro");
        Benchmark benchmark = optional.get();
        Query query = this.queryLoader.loadFromFile(str);
        List<String> sqlStatements = this.sqlStatementGenerator.generateQuerySqlStatement(query, benchmark.getNonReservedKeywordVariables());
        boolean reuseProvidedConnection = optional2.isPresent() && !query.getProperty("datasource").isPresent();
        try {
            if (reuseProvidedConnection) {
                runSqlStatements(optional2.get(), sqlStatements);
            } else {
                String dataSourceName = query.getProperty("datasource", benchmark.getDataSource());
                // try-with-resources: the connection we opened must be released even when a statement fails.
                try (Connection ownedConnection = getConnectionFor(dataSourceName)) {
                    runSqlStatements(ownedConnection, sqlStatements);
                }
            }
        } catch (SQLException e) {
            throw new BenchmarkExecutionException("Could not execute macro SQL queries for benchmark: " + benchmark, e);
        }
    }

    /**
     * Executes the statements in order on the given connection. When a statement produces
     * a result set, it is handed to {@link QueryUtils#fetchRows}.
     *
     * @param connection connection to run on; not closed here
     * @param list       SQL statements; each one is trimmed before execution
     * @throws SQLException if a statement cannot be created or executed
     */
    private void runSqlStatements(Connection connection, List<String> list) throws SQLException {
        for (String rawStatement : list) {
            String sql = rawStatement.trim();
            LOGGER.info("Executing macro query: {}", sql);
            if (isSetSession(sql) && connection.isWrapperFor(TrinoConnection.class)) {
                setSessionForTrino(connection, sql);
                continue;
            }
            try (Statement statement = connection.createStatement()) {
                if (statement.execute(sql)) {
                    // Close the result set explicitly, including when fetching rows fails.
                    try (ResultSet resultSet = statement.getResultSet()) {
                        QueryUtils.fetchRows(sql, resultSet);
                    }
                }
            }
        }
    }

    /**
     * Case-insensitive, locale-independent check for the {@code set session} prefix.
     * A default-locale {@code toLowerCase()} would miss {@code SET SESSION} under locales
     * such as Turkish, where {@code I} lower-cases to a dotless i.
     */
    private static boolean isSetSession(String sql) {
        return sql.regionMatches(true, 0, SET_SESSION, 0, SET_SESSION.length());
    }

    /**
     * Applies a {@code SET SESSION key=value} statement through the Trino connection API.
     *
     * @param connection connection expected to wrap a {@link TrinoConnection}
     * @param str        the full {@code SET SESSION} statement
     * @throws UnsupportedOperationException if unwrapping the connection fails; the
     *                                       originating {@link SQLException} is attached as the cause
     */
    private void setSessionForTrino(Connection connection, String str) {
        try {
            TrinoConnection trinoConnection = connection.unwrap(TrinoConnection.class);
            String[] keyValue = extractKeyValue(str);
            trinoConnection.setSessionProperty(keyValue[0].trim(), keyValue[1].trim());
        } catch (SQLException e) {
            // Log with the exception and chain it so the root cause is not lost.
            LOGGER.error(e.getMessage(), e);
            throw new UnsupportedOperationException(String.format("SET SESSION for non PrestoConnection [%s] is not supported", connection.getClass()), e);
        }
    }

    /**
     * Splits a {@code SET SESSION key=value} statement into its key and value.
     *
     * <p>The first {@code "set session".length()} characters are dropped without being
     * inspected; the remainder is trimmed and must match {@code key=value}, where the value
     * may optionally be enclosed in single quotes.
     *
     * @param str the full statement, starting with the {@code SET SESSION} prefix
     * @return a two-element array: trimmed key, trimmed value (without surrounding quotes)
     * @throws IllegalStateException           if the remainder does not match {@code key=value}
     * @throws StringIndexOutOfBoundsException if {@code str} is shorter than the prefix
     */
    public static String[] extractKeyValue(String str) {
        String assignment = str.substring(SET_SESSION.length()).trim();
        Matcher matcher = KEY_VALUE_PATTERN.matcher(assignment);
        Preconditions.checkState(matcher.matches(), "Unexpected SET SESSION format [%s]", str);
        String key = matcher.group(1).trim();
        String value = matcher.group(2).trim();
        return new String[]{key, value};
    }

    /**
     * Opens a connection on the {@link DataSource} bean registered under the given name.
     *
     * @param str bean name of the data source
     * @return a new connection; the caller is responsible for closing it
     * @throws SQLException if the data source cannot provide a connection
     */
    private Connection getConnectionFor(String str) throws SQLException {
        DataSource dataSource = this.applicationContext.getBean(str, DataSource.class);
        return dataSource.getConnection();
    }
}
