package io.trino.tempto.internal.convention.sql;

import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import freemarker.template.Configuration;
import freemarker.template.Template;
import freemarker.template.TemplateException;
import io.trino.tempto.Requirement;
import io.trino.tempto.assertions.QueryAssert;
import io.trino.tempto.context.ThreadLocalTestContextHolder;
import io.trino.tempto.fulfillment.table.MutableTablesState;
import io.trino.tempto.internal.convention.ConventionBasedTest;
import io.trino.tempto.internal.convention.ConventionTestsUtils;
import io.trino.tempto.internal.convention.ProcessUtils;
import io.trino.tempto.internal.convention.SqlQueryDescriptor;
import io.trino.tempto.internal.convention.SqlResultDescriptor;
import io.trino.tempto.query.QueryExecutor;
import io.trino.tempto.query.QueryResult;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.StringReader;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.io.FilenameUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/trino/tempto/internal/convention/sql/SqlQueryConventionBasedTest.class */
/**
 * Convention-based test that executes the SQL statements of a single query descriptor
 * and matches the result of the last executed statement against the expected
 * {@link SqlResultDescriptor}.
 *
 * <p>Optional "before" and "after" scripts are run around the query. The "after" script
 * is only reached when the query and the result assertion complete without throwing.
 */
public class SqlQueryConventionBasedTest extends ConventionBasedTest {
    private static final Logger LOGGER = LoggerFactory.getLogger(SqlQueryConventionBasedTest.class);

    /** Statements are separated by a semicolon followed by optional spaces and a line break. */
    private static final Splitter QUERY_SPLITTER = Splitter.onPattern("[;][ ]*\r?\n");

    private final Optional<Path> beforeScriptPath;
    private final Optional<Path> afterScriptPath;
    private final Path queryFile;
    private final String testNamePrefix;
    private final int testNumber;
    private final int queriesCount;
    private final SqlQueryDescriptor queryDescriptor;
    private final SqlResultDescriptor resultDescriptor;
    private final Requirement requirement;

    /**
     * Creates a test for one query section of a query file.
     *
     * @param beforeScriptPath script executed before the query, if present
     * @param afterScriptPath script executed after a successful assertion, if present
     * @param queryFile file the query was read from; its name is used in the test name and the dump file name
     * @param testNamePrefix prefix of the generated test name
     * @param testNumber number appended to the test name when the query has no name and {@code queriesCount > 1}
     * @param queriesCount value compared against 1 to decide whether {@code testNumber} is appended
     * @param queryDescriptor descriptor providing the SQL content, database name, optional name and test groups
     * @param resultDescriptor expected result the query result is matched against
     * @param requirement requirement returned from {@link #getRequirements}
     */
    public SqlQueryConventionBasedTest(Optional<Path> beforeScriptPath, Optional<Path> afterScriptPath, Path queryFile, String testNamePrefix, int testNumber, int queriesCount, SqlQueryDescriptor queryDescriptor, SqlResultDescriptor resultDescriptor, Requirement requirement) {
        this.beforeScriptPath = beforeScriptPath;
        this.afterScriptPath = afterScriptPath;
        this.queryFile = queryFile;
        this.testNamePrefix = testNamePrefix;
        this.testNumber = testNumber;
        this.queriesCount = queriesCount;
        this.queryDescriptor = queryDescriptor;
        this.resultDescriptor = resultDescriptor;
        this.requirement = requirement;
    }

    /**
     * Runs the optional before script, executes the query, asserts the result against the
     * expected result descriptor and finally runs the optional after script.
     */
    @Override // io.trino.tempto.internal.convention.ConventionBasedTest
    public void test() {
        LOGGER.debug("Executing sql test: {}#{}", queryFile.getFileName(), queryDescriptor.getName());
        beforeScriptPath.ifPresent(script -> ProcessUtils.execute(script.toString()));
        QueryAssert.assertThat(runTestQuery()).matches(resultDescriptor);
        // Not wrapped in finally: a failed query or assertion skips the after script.
        afterScriptPath.ifPresent(script -> ProcessUtils.execute(script.toString()));
    }

    /**
     * Executes every statement of the query content in order, after template resolution.
     *
     * @return the result of the last executed statement
     * @throws IllegalStateException if the content contains no non-empty statement
     */
    private QueryResult runTestQuery() {
        QueryExecutor queryExecutor = getQueryExecutor(queryDescriptor);
        List<String> queries = splitQueries(queryDescriptor.getContent());
        Preconditions.checkState(!queries.isEmpty(), "At least one query must be present");

        QueryResult lastResult = null;
        for (String query : queries) {
            lastResult = queryExecutor.executeQuery(resolveTemplates(query));
        }
        dumpResultsIfNeeded(lastResult);
        return lastResult;
    }

    /**
     * Writes the result to the results dump directory when one is configured.
     *
     * @throws UncheckedIOException if writing the dump fails
     */
    private void dumpResultsIfNeeded(QueryResult queryResult) {
        ConventionTestsUtils.getConventionTestResultsDumpPath().ifPresent(dumpDirectory -> {
            try {
                dumpResults(queryResult, dumpDirectory);
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        });
    }

    /**
     * Writes the result into {@code <dumpDirectory>/<query file base name>.result}: a header
     * line with the column type names followed by one line per row.
     *
     * @param queryResult result to dump
     * @param dumpDirectory target directory; created (including parents) when it does not exist
     * @throws IOException if the directory or file cannot be written
     * @throws IllegalStateException if {@code dumpDirectory} exists but is not a directory
     */
    private void dumpResults(QueryResult queryResult, Path dumpDirectory) throws IOException {
        if (!Files.exists(dumpDirectory)) {
            Files.createDirectories(dumpDirectory);
        }
        Preconditions.checkState(Files.isDirectory(dumpDirectory), "%s have to point to the directory", dumpDirectory);

        // getBaseName tolerates file names without an extension, unlike substring(0, lastIndexOf(".")).
        String baseName = FilenameUtils.getBaseName(queryFile.getFileName().toString());
        Path resultFile = dumpDirectory.resolve(baseName + ".result");

        try (BufferedWriter writer = Files.newBufferedWriter(resultFile)) {
            String columnTypes = queryResult.getColumnTypes().stream()
                    .map(type -> type.getName())
                    .collect(Collectors.joining("|"));
            writer.write("-- delimiter: |; types: " + columnTypes);
            writer.newLine();
            for (List<?> row : queryResult.rows()) {
                writer.write(new QueryAssert.Row(row).toString());
                writer.newLine();
            }
        }
    }

    /**
     * Processes the query as a FreeMarker template. The model exposes {@code mutableTables}:
     * a map from database name to the value of
     * {@code MutableTablesState.getNameInDatabaseMap(databaseName)}.
     *
     * @param query raw query text, possibly containing template expressions
     * @return the query with templates resolved
     * @throws RuntimeException wrapping any template or I/O failure
     */
    private String resolveTemplates(String query) {
        try {
            Template template = new Template("name", new StringReader(query), new Configuration());

            MutableTablesState mutableTablesState = MutableTablesState.mutableTablesState();
            Map<String, Object> model = new HashMap<>();
            model.put("mutableTables", mutableTablesState.getDatabaseNames().stream()
                    .collect(Collectors.toMap(
                            databaseName -> databaseName,
                            databaseName -> mutableTablesState.getNameInDatabaseMap(databaseName))));

            StringWriter output = new StringWriter();
            template.process(model, output);
            output.flush();
            return output.toString();
        } catch (TemplateException | IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Splits the content into individual statements using {@link #QUERY_SPLITTER} and drops
     * empty fragments.
     */
    private List<String> splitQueries(String content) {
        return QUERY_SPLITTER.splitToList(content).stream()
                .filter(query -> !query.isEmpty())
                .collect(Collectors.toList());
    }

    /**
     * Builds the test name as {@code <prefix>.<query file base name>}, followed by
     * {@code .<query name without whitespace>} when the query is named, or by
     * {@code _<testNumber>} when it is unnamed and {@code queriesCount > 1}.
     */
    @Override // io.trino.tempto.testmarkers.WithName
    public String getTestName() {
        StringBuilder testName = new StringBuilder()
                .append(testNamePrefix)
                .append(".")
                .append(FilenameUtils.getBaseName(queryFile.getFileName().toString()));

        if (queryDescriptor.getName().isPresent()) {
            testName.append(".").append(queryDescriptor.getName().get().replaceAll("\\s", ""));
        } else if (queriesCount > 1) {
            testName.append("_").append(testNumber);
        }
        return testName.toString();
    }

    /** Returns the requirement supplied at construction; {@code configuration} is not consulted. */
    @Override // io.trino.tempto.RequirementsProvider
    public Requirement getRequirements(io.trino.tempto.configuration.Configuration configuration) {
        return requirement;
    }

    /** Returns the test groups declared by the query descriptor. */
    @Override // io.trino.tempto.testmarkers.WithTestGroups
    public Set<String> getTestGroups() {
        return queryDescriptor.getTestGroups();
    }

    /**
     * Looks up the {@link QueryExecutor} registered in the current test context under the
     * descriptor's database name.
     *
     * @throws RuntimeException naming the database when the lookup fails, with the original failure as cause
     */
    private QueryExecutor getQueryExecutor(SqlQueryDescriptor sqlQueryDescriptor) {
        String databaseName = sqlQueryDescriptor.getDatabaseName();
        try {
            return ThreadLocalTestContextHolder.testContext().getDependency(QueryExecutor.class, databaseName);
        } catch (RuntimeException e) {
            throw new RuntimeException("Cannot get query executor for database '" + databaseName + "'", e);
        }
    }
}
