package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.connector.CatalogName;
import io.trino.connector.MockConnectorColumnHandle;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorTableHandle;
import io.trino.connector.MockConnectorTransactionHandle;
import io.trino.metadata.TableHandle;
import io.trino.spi.Plugin;
import io.trino.spi.connector.AggregationApplicationResult;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.type.BigintType;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.TestingSession;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.AtomicIntegerAssert;
import org.assertj.core.api.AtomicReferenceAssert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

@Test(singleThreaded = true)
/**
 * Rule tests for {@link PushDistinctLimitIntoTableScan}.
 *
 * <p>The tests expect the rule to offer a DISTINCT ... LIMIT sitting directly on a table scan to the
 * connector through {@code applyAggregation}, with no aggregate functions and a single grouping set
 * made of the distinct columns.
 *
 * <p>Runs single-threaded because every test installs its own connector behavior in the shared,
 * mutable {@code testApplyAggregation} field.
 */
public class TestPushDistinctLimitIntoTableScan extends BaseRuleTest {
    private static final CatalogName TEST_CATALOG = new CatalogName("test_push_dl_catalog");

    private PushDistinctLimitIntoTableScan rule;
    private TableHandle tableHandle;

    // Per-test connector callback; when null the mock connector declines every aggregation pushdown.
    private MockConnectorFactory.ApplyAggregation testApplyAggregation;

    public TestPushDistinctLimitIntoTableScan() {
        // No additional plugins are needed beyond the mock catalog registered in createLocalQueryRunner().
        super(new Plugin[0]);
    }

    /**
     * Creates a local query runner whose default catalog is a mock connector. That connector routes
     * {@code applyAggregation} calls through {@code testApplyAggregation}.
     */
    @Override
    protected Optional<LocalQueryRunner> createLocalQueryRunner() {
        LocalQueryRunner queryRunner = LocalQueryRunner.create(TestingSession.testSessionBuilder()
                .setCatalog(TEST_CATALOG.getCatalogName())
                .setSchema("tiny")
                .build());
        queryRunner.createCatalog(
                TEST_CATALOG.getCatalogName(),
                MockConnectorFactory.builder()
                        .withApplyAggregation((session, handle, aggregates, assignments, groupingSets) -> {
                            // Read the field once so the null check and the call see the same callback.
                            MockConnectorFactory.ApplyAggregation delegate = testApplyAggregation;
                            return delegate != null
                                    ? delegate.apply(session, handle, aggregates, assignments, groupingSets)
                                    : Optional.empty();
                        })
                        .build(),
                Map.of());
        return Optional.of(queryRunner);
    }

    /** Builds the rule under test and the mock table handle shared by all tests. */
    @BeforeClass
    public void init() {
        rule = new PushDistinctLimitIntoTableScan(tester().getMetadata());
        tableHandle = new TableHandle(
                TEST_CATALOG,
                new MockConnectorTableHandle(new SchemaTableName("mock_schema", "mock_nation")),
                MockConnectorTransactionHandle.INSTANCE,
                Optional.empty());
    }

    /** Clears any callback installed by a previous test so each test starts from "connector declines". */
    @BeforeMethod
    public void reset() {
        testApplyAggregation = null;
    }

    /** A DistinctLimit whose source is not a table scan must be left alone. */
    @Test
    public void testDoesNotFireIfNoTableScan() {
        tester().assertThat(rule)
                .on(planBuilder -> {
                    Symbol a = planBuilder.symbol("a", BigintType.BIGINT);
                    // Wrap the values node in a DistinctLimit so the rule's root is actually present and
                    // only the missing table scan prevents it from firing.
                    return planBuilder.distinctLimit(10L, List.of(a), planBuilder.values(a));
                })
                .doesNotFire();
    }

    /** When the connector declines the pushdown, the rule is consulted exactly once and does not fire. */
    @Test
    public void testNoEffect() {
        AtomicInteger applyCallCounter = new AtomicInteger();
        testApplyAggregation = (session, handle, aggregates, assignments, groupingSets) -> {
            applyCallCounter.incrementAndGet();
            return Optional.empty();
        };

        tester().assertThat(rule)
                .on(planBuilder -> {
                    Symbol regionKey = planBuilder.symbol("regionkey");
                    return planBuilder.distinctLimit(
                            10L,
                            List.of(regionKey),
                            planBuilder.tableScan(
                                    tableHandle,
                                    ImmutableList.of(regionKey),
                                    ImmutableMap.of(regionKey, new MockConnectorColumnHandle("regionkey", BigintType.BIGINT))));
                })
                .doesNotFire();

        Assertions.assertThat(applyCallCounter).as("applyCallCounter").hasValue(1);
    }

    /**
     * When the connector accepts the pushdown, the DistinctLimit becomes a Limit over a Project over
     * the connector's aggregated table. The connector is called once with no aggregate functions, the
     * symbol-to-column assignments, and a single grouping set of the distinct column.
     */
    @Test
    public void testPushDistinct() {
        AtomicInteger applyCallCounter = new AtomicInteger();
        AtomicReference<List<?>> applyAggregates = new AtomicReference<>();
        AtomicReference<Map<String, ?>> applyAssignments = new AtomicReference<>();
        AtomicReference<List<? extends List<?>>> applyGroupingSets = new AtomicReference<>();

        testApplyAggregation = (session, handle, aggregates, assignments, groupingSets) -> {
            applyCallCounter.incrementAndGet();
            // Snapshot the arguments so the assertions below cannot be affected by later mutation.
            applyAggregates.set(List.copyOf(aggregates));
            applyAssignments.set(Map.copyOf(assignments));
            applyGroupingSets.set(groupingSets.stream()
                    .map(List::copyOf)
                    .collect(Collectors.toUnmodifiableList()));
            return Optional.of(new AggregationApplicationResult<>(
                    new MockConnectorTableHandle(new SchemaTableName("mock_schema", "mock_nation_aggregated")),
                    List.of(),
                    List.of(),
                    Map.of(),
                    false));
        };

        MockConnectorColumnHandle regionKeyHandle = new MockConnectorColumnHandle("regionkey", BigintType.BIGINT);

        tester().assertThat(rule)
                .on(planBuilder -> {
                    Symbol regionKeySymbol = planBuilder.symbol("regionkey_symbol");
                    return planBuilder.distinctLimit(
                            43L,
                            List.of(regionKeySymbol),
                            planBuilder.tableScan(
                                    tableHandle,
                                    ImmutableList.of(regionKeySymbol),
                                    ImmutableMap.of(regionKeySymbol, regionKeyHandle)));
                })
                .matches(PlanMatchPattern.limit(
                        43L,
                        PlanMatchPattern.project(PlanMatchPattern.tableScan("mock_nation_aggregated"))));

        Assertions.assertThat(applyCallCounter).as("applyCallCounter").hasValue(1);
        Assertions.assertThat(applyAggregates).as("applyAggregates").hasValue(List.of());
        Assertions.assertThat(applyAssignments).as("applyAssignments").hasValue(Map.of("regionkey_symbol", regionKeyHandle));
        Assertions.assertThat(applyGroupingSets).as("applyGroupingSets").hasValue(List.of(List.of(regionKeyHandle)));
    }
}
