package io.trino.plugin.thrift.integration;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.io.Closer;
import io.airlift.drift.server.DriftServer;
import io.trino.Session;
import io.trino.connector.CatalogName;
import io.trino.cost.ScalarStatsCalculator;
import io.trino.metadata.TableHandle;
import io.trino.plugin.thrift.ThriftColumnHandle;
import io.trino.plugin.thrift.ThriftPlugin;
import io.trino.plugin.thrift.ThriftTableHandle;
import io.trino.plugin.thrift.ThriftTransactionHandle;
import io.trino.spi.Plugin;
import io.trino.spi.connector.ConnectorFactory;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.VarcharType;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.PruneTableScanColumns;
import io.trino.sql.planner.iterative.rule.PushProjectionIntoTableScan;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.RuleAssert;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.tree.SymbolReference;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.TestingSession;
import org.testng.annotations.AfterClass;
import org.testng.annotations.Test;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
 * Rule tests checking how projection pushdown and column pruning rewrite scans of
 * {@code tiny.nation} in a catalog backed by the Thrift connector.
 *
 * <p>{@link #createLocalQueryRunner()} starts one Thrift server via {@code ThriftQueryRunner} and
 * registers it as the {@code test} catalog; {@link #cleanup()} shuts the started server(s) down.
 */
public class TestThriftProjectionPushdown extends BaseRuleTest {
    private static final String TINY_SCHEMA = "tiny";
    private static final String CATALOG = "test";
    private static final ThriftTableHandle NATION_THRIFT_TABLE = new ThriftTableHandle(new SchemaTableName(TINY_SCHEMA, "nation"));
    private static final TableHandle NATION_TABLE = new TableHandle(new CatalogName(CATALOG), NATION_THRIFT_TABLE, ThriftTransactionHandle.INSTANCE);
    private static final Session SESSION = TestingSession.testSessionBuilder().setCatalog(CATALOG).setSchema(TINY_SCHEMA).build();

    /** Servers started by {@link #createLocalQueryRunner()}; {@code null} before startup and after cleanup. */
    private List<DriftServer> servers;

    public TestThriftProjectionPushdown() {
        // No extra plugins: the Thrift catalog is registered explicitly in createLocalQueryRunner().
        super(new Plugin[0]);
    }

    /**
     * Starts a Thrift server and returns a local query runner with a {@code test} catalog that
     * talks to it.
     *
     * <p>If any step fails, servers that were already started are shut down before the original
     * failure is rethrown; a failure during that shutdown is attached as a suppressed exception.
     *
     * @return the configured local query runner
     */
    protected Optional<LocalQueryRunner> createLocalQueryRunner() {
        try {
            // NOTE(review): the meaning of the boolean argument is defined in ThriftQueryRunner — confirm there.
            servers = ThriftQueryRunner.startThriftServers(1, false);
            String addresses = servers.stream()
                    .map(server -> "localhost:" + ThriftQueryRunner.driftServerPort(server))
                    .collect(Collectors.joining(","));
            // The differing key prefixes ("trino.thrift.client." vs "trino-thrift.") are intentional:
            // they are the exact property names used by the original configuration.
            ImmutableMap<String, String> connectorProperties = ImmutableMap.<String, String>builder()
                    .put("trino.thrift.client.addresses", addresses)
                    .put("trino.thrift.client.connect-timeout", "30s")
                    .put("trino-thrift.lookup-requests-concurrency", "2")
                    .buildOrThrow();

            LocalQueryRunner queryRunner = LocalQueryRunner.create(SESSION);
            ConnectorFactory thriftConnectorFactory = Iterables.getOnlyElement(new ThriftPlugin().getConnectorFactories());
            queryRunner.createCatalog(CATALOG, thriftConnectorFactory, connectorProperties);
            return Optional.of(queryRunner);
        } catch (Throwable failure) {
            try {
                cleanup();
            } catch (Throwable cleanupFailure) {
                failure.addSuppressed(cleanupFailure);
            }
            throw failure;
        }
    }

    /**
     * Shuts down every server started by {@link #createLocalQueryRunner()}; does nothing if none
     * were started.
     *
     * <p>All shutdowns are registered with a Guava {@link Closer}, so every server gets a shutdown
     * attempt even if an earlier one fails. {@code servers} is cleared regardless of outcome so a
     * repeated call does not shut the same servers down twice.
     *
     * @throws UncheckedIOException if closing the {@link Closer} reports an {@link IOException}
     */
    @AfterClass(alwaysRun = true)
    public void cleanup() {
        if (servers == null) {
            return;
        }
        try (Closer closer = Closer.create()) {
            for (DriftServer server : servers) {
                closer.register(server::shutdown);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } finally {
            servers = null;
        }
    }

    /** Creates the projection-pushdown rule under test, wired to the tester's planner context. */
    private PushProjectionIntoTableScan createPushProjectionIntoTableScan() {
        return new PushProjectionIntoTableScan(
                tester().getPlannerContext(),
                tester().getTypeAnalyzer(),
                new ScalarStatsCalculator(tester().getPlannerContext(), tester().getTypeAnalyzer()));
    }

    /** Returns a {@code tiny.nation} handle with no constraint whose column set is exactly {@code column}. */
    private static ThriftTableHandle nationWithColumns(ThriftColumnHandle column) {
        return new ThriftTableHandle(TINY_SCHEMA, "nation", TupleDomain.all(), Optional.of(ImmutableSet.of(column)));
    }

    @Test
    public void testDoesNotFire() {
        String columnName = "orderstatus";
        ThriftColumnHandle orderStatusColumn = new ThriftColumnHandle(columnName, VarcharType.VARCHAR, "", false);
        // The scanned table handle already has its column set populated; the rule must leave it alone.
        TableHandle alreadyProjected = new TableHandle(new CatalogName(CATALOG), nationWithColumns(orderStatusColumn), ThriftTransactionHandle.INSTANCE);

        tester().assertThat(createPushProjectionIntoTableScan())
                .on(planBuilder -> {
                    Symbol orderStatus = planBuilder.symbol(columnName, VarcharType.VARCHAR);
                    return planBuilder.project(
                            Assignments.of(planBuilder.symbol("expr_2", VarcharType.VARCHAR), orderStatus.toSymbolReference()),
                            planBuilder.tableScan(alreadyProjected, ImmutableList.of(orderStatus), ImmutableMap.of(orderStatus, orderStatusColumn)));
                })
                .doesNotFire();
    }

    @Test
    public void testProjectionPushdown() {
        String columnName = "orderstatus";
        ThriftColumnHandle orderStatusColumn = new ThriftColumnHandle(columnName, VarcharType.VARCHAR, "", false);
        // Expected scan handle after the single referenced column is pushed into the table handle.
        ThriftTableHandle projectedNation = nationWithColumns(orderStatusColumn);

        tester().assertThat(createPushProjectionIntoTableScan())
                .on(planBuilder -> {
                    Symbol orderStatus = planBuilder.symbol(columnName, VarcharType.VARCHAR);
                    return planBuilder.project(
                            Assignments.of(planBuilder.symbol("expr_2", VarcharType.VARCHAR), orderStatus.toSymbolReference()),
                            planBuilder.tableScan(NATION_TABLE, ImmutableList.of(orderStatus), ImmutableMap.of(orderStatus, orderStatusColumn)));
                })
                .withSession(SESSION)
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of("expr_2", PlanMatchPattern.expression(new SymbolReference(columnName))),
                        PlanMatchPattern.tableScan(
                                projectedNation::equals,
                                TupleDomain.all(),
                                ImmutableMap.of(columnName, orderStatusColumn::equals))));
    }

    @Test
    public void testPruneColumns() {
        ThriftColumnHandle nationKeyColumn = new ThriftColumnHandle("nationKey", VarcharType.VARCHAR, "", false);
        ThriftColumnHandle nameColumn = new ThriftColumnHandle("name", VarcharType.VARCHAR, "", false);
        // Only nationKey is referenced by the projection, so the expected scan keeps just that column.
        ThriftTableHandle prunedNation = nationWithColumns(nationKeyColumn);
        String nationKeyName = nationKeyColumn.getColumnName();

        tester().assertThat(new PruneTableScanColumns(tester().getMetadata()))
                .on(planBuilder -> {
                    Symbol nationKey = planBuilder.symbol(nationKeyName, VarcharType.VARCHAR);
                    Symbol name = planBuilder.symbol(nameColumn.getColumnName(), VarcharType.VARCHAR);
                    return planBuilder.project(
                            Assignments.of(planBuilder.symbol("expr", VarcharType.VARCHAR), nationKey.toSymbolReference()),
                            planBuilder.tableScan(
                                    NATION_TABLE,
                                    ImmutableList.of(nationKey, name),
                                    ImmutableMap.of(nationKey, nationKeyColumn, name, nameColumn)));
                })
                .withSession(SESSION)
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of("expr", PlanMatchPattern.expression(new SymbolReference(nationKeyName))),
                        PlanMatchPattern.tableScan(
                                prunedNation::equals,
                                TupleDomain.all(),
                                ImmutableMap.of(nationKeyName, nationKeyColumn::equals))));
    }
}
