package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.connector.MockConnectorColumnHandle;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorTableHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.statistics.ColumnStatistics;
import io.trino.spi.statistics.Estimate;
import io.trino.spi.statistics.TableStatistics;
import io.trino.spi.type.VarcharType;
import io.trino.sql.planner.TestTableScanNodePartitioning;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.TestingSession;
import java.util.Optional;
import org.testng.annotations.Test;

/**
 * Distributed-plan tests asserting the partition count set on remote exchanges
 * (presumably chosen by the {@code DeterminePartitionCount} optimizer — confirm).
 *
 * <p>The mock catalog exposes four tables, each with two varchar columns
 * ({@code TestTableScanNodePartitioning.COLUMN_A} and {@code COLUMN_B}). The
 * {@code table_with_stats_*} tables report a 200-row estimate and, for each column,
 * column statistics of (0.0, 10000.0, 100MB, no range). The {@code table_without_stats_*}
 * tables report {@link TableStatistics#empty()}.
 */
public class TestDeterminePartitionCount extends BasePlanTest {
    private static final String COLUMN_A = TestTableScanNodePartitioning.COLUMN_A;
    private static final String COLUMN_B = TestTableScanNodePartitioning.COLUMN_B;

    /** Table names the mock connector resolves; any other name yields a {@code null} handle. */
    private static final ImmutableList<String> KNOWN_TABLES = ImmutableList.of(
            "table_with_stats_a",
            "table_with_stats_b",
            "table_without_stats_a",
            "table_without_stats_b");

    /** Subset of {@link #KNOWN_TABLES} for which the mock connector returns non-empty statistics. */
    private static final ImmutableList<String> TABLES_WITH_STATS = ImmutableList.of(
            "table_with_stats_a",
            "table_with_stats_b");

    /** {@code min_input_size_per_task} value shared by every test in this class. */
    private static final String MIN_INPUT_SIZE_PER_TASK = "20MB";

    /**
     * An explicitly-typed empty partition count for {@code PlanMatchPattern.exchange}. It avoids
     * casts at call sites. How the matcher interprets an empty value is defined by PlanMatchPattern.
     */
    private static final Optional<Integer> EMPTY_PARTITION_COUNT = Optional.empty();

    @Override // io.trino.sql.planner.assertions.BasePlanTest
    protected LocalQueryRunner createLocalQueryRunner() {
        MockConnectorFactory connectorFactory = MockConnectorFactory.builder()
                .withGetTableHandle((connectorSession, schemaTableName) -> {
                    if (KNOWN_TABLES.contains(schemaTableName.getTableName())) {
                        return new MockConnectorTableHandle(schemaTableName);
                    }
                    return null;
                })
                .withGetColumns(schemaTableName -> ImmutableList.of(
                        new ColumnMetadata(COLUMN_A, VarcharType.VARCHAR),
                        new ColumnMetadata(COLUMN_B, VarcharType.VARCHAR)))
                .withGetTableStatistics(schemaTableName -> {
                    if (!TABLES_WITH_STATS.contains(schemaTableName.getTableName())) {
                        return TableStatistics.empty();
                    }
                    return new TableStatistics(
                            Estimate.of(200.0d),
                            ImmutableMap.of(
                                    new MockConnectorColumnHandle(COLUMN_A, VarcharType.VARCHAR), columnStatistics(),
                                    new MockConnectorColumnHandle(COLUMN_B, VarcharType.VARCHAR), columnStatistics()));
                })
                .withName("mock")
                .build();

        LocalQueryRunner queryRunner = LocalQueryRunner.create(TestingSession.testSessionBuilder()
                .setCatalog("mock")
                .setSchema("default")
                .build());
        queryRunner.createCatalog("mock", connectorFactory, ImmutableMap.of());
        return queryRunner;
    }

    /**
     * Statistics reported for every column of the {@code table_with_stats_*} tables:
     * {@code (0.0, 10000.0, 100MB, no range)} in {@code ColumnStatistics} constructor order.
     */
    private static ColumnStatistics columnStatistics() {
        return new ColumnStatistics(
                Estimate.of(0.0d),
                Estimate.of(10000.0d),
                Estimate.of(DataSize.of(100L, DataSize.Unit.MEGABYTE).toBytes()),
                Optional.empty());
    }

    /**
     * Builds a session derived from the runner's default session with the partition-count
     * related system properties used by these tests. {@code min_input_size_per_task} is always
     * {@value #MIN_INPUT_SIZE_PER_TASK}.
     *
     * @param maxHashPartitionCount value for {@code max_hash_partition_count}
     * @param minHashPartitionCount value for {@code min_hash_partition_count}
     * @param minInputRowsPerTask value for {@code min_input_rows_per_task}
     * @return the configured session
     */
    private Session partitionCountSession(int maxHashPartitionCount, int minHashPartitionCount, int minInputRowsPerTask) {
        return Session.builder(getQueryRunner().getDefaultSession())
                .setSystemProperty("max_hash_partition_count", String.valueOf(maxHashPartitionCount))
                .setSystemProperty("min_hash_partition_count", String.valueOf(minHashPartitionCount))
                .setSystemProperty("min_input_size_per_task", MIN_INPUT_SIZE_PER_TASK)
                .setSystemProperty("min_input_rows_per_task", String.valueOf(minInputRowsPerTask))
                .build();
    }

    /** Wraps {@code count} as the typed partition count expected on an exchange pattern. */
    private static Optional<Integer> partitionCount(int count) {
        return Optional.of(count);
    }

    /**
     * A single-table aggregation over a table with statistics expects the remote exchange to
     * carry 5 partitions (max 10, min 4, 20MB / 400 rows per task).
     */
    @Test
    public void testPlanWhenTableStatisticsArePresent() {
        assertDistributedPlan(
                "SELECT count(column_a) FROM table_with_stats_a\n",
                partitionCountSession(10, 4, 400),
                PlanMatchPattern.output(
                        PlanMatchPattern.node(AggregationNode.class,
                                PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL,
                                        PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, partitionCount(5),
                                                PlanMatchPattern.node(AggregationNode.class,
                                                        PlanMatchPattern.node(TableScanNode.class)))))));
    }

    /** A join over tables without statistics expects remote exchanges with an empty partition count. */
    @Test
    public void testPlanWhenTableStatisticsAreAbsent() {
        assertDistributedPlan(
                "SELECT * FROM table_without_stats_a as a JOIN table_without_stats_b as b ON a.column_a = b.column_a\n",
                partitionCountSession(10, 4, 400),
                PlanMatchPattern.output(
                        PlanMatchPattern.join(JoinNode.Type.INNER, builder -> {
                            builder.equiCriteria(COLUMN_A, "column_a_0")
                                    .right(PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL,
                                            PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, EMPTY_PARTITION_COUNT,
                                                    PlanMatchPattern.project(
                                                            PlanMatchPattern.tableScan("table_without_stats_b", ImmutableMap.of("column_a_0", COLUMN_A, "column_b_1", COLUMN_B))))))
                                    .left(PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, EMPTY_PARTITION_COUNT,
                                            PlanMatchPattern.project(
                                                    PlanMatchPattern.node(FilterNode.class,
                                                            PlanMatchPattern.tableScan("table_without_stats_a", ImmutableMap.of(COLUMN_A, COLUMN_A, COLUMN_B, COLUMN_B))))));
                        })));
    }

    /** A cross join of two tables with statistics expects the build-side remote exchange to have an empty partition count. */
    @Test
    public void testPlanWhenCrossJoinIsPresent() {
        assertDistributedPlan(
                "SELECT * FROM table_with_stats_a CROSS JOIN table_with_stats_b\n",
                partitionCountSession(10, 4, 400),
                PlanMatchPattern.output(
                        PlanMatchPattern.join(JoinNode.Type.INNER, builder -> {
                            builder.right(PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL,
                                            PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, EMPTY_PARTITION_COUNT,
                                                    PlanMatchPattern.tableScan("table_with_stats_b", ImmutableMap.of("column_a_0", COLUMN_A, "column_b_1", COLUMN_B)))))
                                    .left(PlanMatchPattern.tableScan("table_with_stats_a", ImmutableMap.of(COLUMN_A, COLUMN_A, COLUMN_B, COLUMN_B)));
                        })));
    }

    /** A cross join against a scalar aggregate expects both remote exchanges on the aggregate side to use 15 partitions (max 20). */
    @Test
    public void testPlanWhenCrossJoinIsScalar() {
        assertDistributedPlan(
                "SELECT * FROM table_with_stats_a CROSS JOIN (select max(column_a) from table_with_stats_b) t(a)\n",
                partitionCountSession(20, 4, 400),
                PlanMatchPattern.output(
                        PlanMatchPattern.join(JoinNode.Type.INNER, builder -> {
                            builder.right(PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL,
                                            PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, partitionCount(15),
                                                    PlanMatchPattern.node(AggregationNode.class,
                                                            PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL,
                                                                    PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, partitionCount(15),
                                                                            PlanMatchPattern.node(AggregationNode.class,
                                                                                    PlanMatchPattern.node(TableScanNode.class))))))))
                                    .left(PlanMatchPattern.node(TableScanNode.class));
                        })));
    }

    /** A join on {@code column_b} of the tables with statistics expects remote exchanges with an empty partition count. */
    @Test
    public void testPlanWhenJoinNodeStatsAreAbsent() {
        assertDistributedPlan(
                "SELECT * FROM table_with_stats_a as a JOIN table_with_stats_b as b ON a.column_b = b.column_b\n",
                partitionCountSession(10, 4, 400),
                PlanMatchPattern.output(
                        PlanMatchPattern.join(JoinNode.Type.INNER, builder -> {
                            builder.equiCriteria(COLUMN_B, "column_b_1")
                                    .right(PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL,
                                            PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, EMPTY_PARTITION_COUNT,
                                                    PlanMatchPattern.project(
                                                            PlanMatchPattern.tableScan("table_with_stats_b", ImmutableMap.of("column_a_0", COLUMN_A, "column_b_1", COLUMN_B))))))
                                    .left(PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, EMPTY_PARTITION_COUNT,
                                            PlanMatchPattern.project(
                                                    PlanMatchPattern.node(FilterNode.class,
                                                            PlanMatchPattern.tableScan("table_with_stats_a", ImmutableMap.of(COLUMN_A, COLUMN_A, COLUMN_B, COLUMN_B))))));
                        })));
    }

    /** A join on {@code column_a} with max 50 partitions expects both remote exchanges to use 10 partitions. */
    @Test
    public void testPlanWhenJoinNodeOutputIsBiggerThanRowsScanned() {
        assertDistributedPlan(
                "SELECT a.column_a FROM table_with_stats_a as a JOIN table_with_stats_b as b ON a.column_a = b.column_a\n",
                partitionCountSession(50, 4, 400),
                PlanMatchPattern.output(
                        PlanMatchPattern.join(JoinNode.Type.INNER, builder -> {
                            builder.equiCriteria(COLUMN_A, "column_a_0")
                                    .right(PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL,
                                            PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, partitionCount(10),
                                                    PlanMatchPattern.project(
                                                            PlanMatchPattern.tableScan("table_with_stats_b", ImmutableMap.of("column_a_0", COLUMN_A))))))
                                    .left(PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, partitionCount(10),
                                            PlanMatchPattern.project(
                                                    PlanMatchPattern.node(FilterNode.class,
                                                            PlanMatchPattern.tableScan("table_with_stats_a", ImmutableMap.of(COLUMN_A, COLUMN_A))))));
                        })));
    }

    /** With max 5 and min 2 partitions, the join's remote exchanges are expected to have an empty partition count. */
    @Test
    public void testEstimatedPartitionCountShouldNotBeGreaterThanMaxLimit() {
        assertDistributedPlan(
                "SELECT * FROM table_with_stats_a as a JOIN table_with_stats_b as b ON a.column_a = b.column_a\n",
                partitionCountSession(5, 2, 400),
                PlanMatchPattern.output(
                        PlanMatchPattern.join(JoinNode.Type.INNER, builder -> {
                            builder.equiCriteria(COLUMN_A, "column_a_0")
                                    .right(PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL,
                                            PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, EMPTY_PARTITION_COUNT,
                                                    PlanMatchPattern.project(
                                                            PlanMatchPattern.tableScan("table_with_stats_b", ImmutableMap.of("column_a_0", COLUMN_A, "column_b_1", COLUMN_B))))))
                                    .left(PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, EMPTY_PARTITION_COUNT,
                                            PlanMatchPattern.project(
                                                    PlanMatchPattern.node(FilterNode.class,
                                                            PlanMatchPattern.tableScan("table_with_stats_a", ImmutableMap.of(COLUMN_A, COLUMN_A, COLUMN_B, COLUMN_B))))));
                        })));
    }

    /** With min 15 and max 20 partitions, the join's remote exchanges are expected to use exactly 15 partitions. */
    @Test
    public void testEstimatedPartitionCountShouldNotBeLessThanMinLimit() {
        assertDistributedPlan(
                "SELECT a.column_a FROM table_with_stats_a as a JOIN table_with_stats_b as b ON a.column_a = b.column_a\n",
                partitionCountSession(20, 15, 400),
                PlanMatchPattern.output(
                        PlanMatchPattern.join(JoinNode.Type.INNER, builder -> {
                            builder.equiCriteria(COLUMN_A, "column_a_0")
                                    .right(PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL,
                                            PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, partitionCount(15),
                                                    PlanMatchPattern.project(
                                                            PlanMatchPattern.tableScan("table_with_stats_b", ImmutableMap.of("column_a_0", COLUMN_A))))))
                                    .left(PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, partitionCount(15),
                                            PlanMatchPattern.project(
                                                    PlanMatchPattern.node(FilterNode.class,
                                                            PlanMatchPattern.tableScan("table_with_stats_a", ImmutableMap.of(COLUMN_A, COLUMN_A))))));
                        })));
    }

    /** A UNION ALL of a join and a plain scan expects all remote exchanges, including the union's, to use 20 partitions. */
    @Test
    public void testPlanWhenUnionNodeOutputIsBiggerThanJoinOutput() {
        assertDistributedPlan(
                "SELECT a.column_b\n" +
                        "FROM table_with_stats_a as a\n" +
                        "JOIN table_with_stats_b as b\n" +
                        "ON a.column_a = b.column_a\n" +
                        "UNION ALL\n" +
                        "SELECT column_b\n" +
                        "FROM table_with_stats_b\n",
                partitionCountSession(50, 4, 400),
                PlanMatchPattern.output(
                        PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, partitionCount(20),
                                PlanMatchPattern.join(JoinNode.Type.INNER, builder -> {
                                    builder.equiCriteria(COLUMN_A, "column_a_1")
                                            .right(PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL,
                                                    PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, partitionCount(20),
                                                            PlanMatchPattern.project(
                                                                    PlanMatchPattern.tableScan("table_with_stats_b", ImmutableMap.of("column_a_1", COLUMN_A))))))
                                            .left(PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, partitionCount(20),
                                                    PlanMatchPattern.project(
                                                            PlanMatchPattern.node(FilterNode.class,
                                                                    PlanMatchPattern.tableScan("table_with_stats_a", ImmutableMap.of(COLUMN_A, COLUMN_A, "column_b_0", COLUMN_B))))));
                                }),
                                PlanMatchPattern.tableScan("table_with_stats_b", ImmutableMap.of("column_b_4", COLUMN_B)))));
    }

    /**
     * Same aggregation as {@link #testPlanWhenTableStatisticsArePresent()}, but with
     * {@code min_input_rows_per_task} lowered to 20 and max 100 partitions. The remote exchange
     * is expected to use 10 partitions instead of 5.
     */
    @Test
    public void testPlanWhenEstimatedPartitionCountBasedOnRowsIsMoreThanOutputSize() {
        assertDistributedPlan(
                "SELECT count(column_a) FROM table_with_stats_a\n",
                partitionCountSession(100, 4, 20),
                PlanMatchPattern.output(
                        PlanMatchPattern.node(AggregationNode.class,
                                PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL,
                                        PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, partitionCount(10),
                                                PlanMatchPattern.node(AggregationNode.class,
                                                        PlanMatchPattern.node(TableScanNode.class)))))));
    }
}
