package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.connector.CatalogName;
import io.trino.metadata.TableHandle;
import io.trino.plugin.tpch.TpchColumnHandle;
import io.trino.plugin.tpch.TpchTableHandle;
import io.trino.plugin.tpch.TpchTransactionHandle;
import io.trino.spi.Plugin;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;
import io.trino.sql.planner.FunctionCallBuilder;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.iterative.rule.test.RuleTester;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.SymbolReference;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.testng.annotations.Test;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.class */
public class TestPruneCountAggregationOverScalar extends BaseRuleTest {
    public TestPruneCountAggregationOverScalar() {
        super(new Plugin[0]);
    }

    @Test
    public void testDoesNotFireOnNonNestedAggregate() {
        tester().assertThat(new PruneCountAggregationOverScalar(tester().getMetadata())).on(planBuilder -> {
            return planBuilder.aggregation(aggregationBuilder -> {
                aggregationBuilder.globalGrouping().addAggregation(planBuilder.symbol("count_1", BigintType.BIGINT), new FunctionCallBuilder(tester().getMetadata()).setName(QualifiedName.of("count")).build(), ImmutableList.of()).source(planBuilder.tableScan((List<Symbol>) ImmutableList.of(), (Map<Symbol, ColumnHandle>) ImmutableMap.of()));
            });
        }).doesNotFire();
    }

    @Test
    public void testFiresOnNestedCountAggregate() {
        tester().assertThat(new PruneCountAggregationOverScalar(tester().getMetadata())).on(planBuilder -> {
            return planBuilder.aggregation(aggregationBuilder -> {
                aggregationBuilder.addAggregation(planBuilder.symbol("count_1", BigintType.BIGINT), new FunctionCallBuilder(tester().getMetadata()).setName(QualifiedName.of("count")).build(), ImmutableList.of()).globalGrouping().step(AggregationNode.Step.SINGLE).source(planBuilder.aggregation(aggregationBuilder -> {
                    aggregationBuilder.source(planBuilder.tableScan((List<Symbol>) ImmutableList.of(), (Map<Symbol, ColumnHandle>) ImmutableMap.of())).globalGrouping().step(AggregationNode.Step.SINGLE);
                }));
            });
        }).matches(PlanMatchPattern.values((Map<String, Integer>) ImmutableMap.of("count_1", 0)));
    }

    @Test
    public void testFiresOnCountAggregateOverValues() {
        tester().assertThat(new PruneCountAggregationOverScalar(tester().getMetadata())).on(planBuilder -> {
            return planBuilder.aggregation(aggregationBuilder -> {
                aggregationBuilder.addAggregation(planBuilder.symbol("count_1", BigintType.BIGINT), new FunctionCallBuilder(tester().getMetadata()).setName(QualifiedName.of("count")).build(), ImmutableList.of()).step(AggregationNode.Step.SINGLE).globalGrouping().source(planBuilder.values((List<Symbol>) ImmutableList.of(planBuilder.symbol("orderkey")), (List<List<Expression>>) ImmutableList.of(PlanBuilder.expressions("1"))));
            });
        }).matches(PlanMatchPattern.values((Map<String, Integer>) ImmutableMap.of("count_1", 0)));
    }

    @Test
    public void testFiresOnCountAggregateOverEnforceSingleRow() {
        tester().assertThat(new PruneCountAggregationOverScalar(tester().getMetadata())).on(planBuilder -> {
            return planBuilder.aggregation(aggregationBuilder -> {
                aggregationBuilder.addAggregation(planBuilder.symbol("count_1", BigintType.BIGINT), new FunctionCallBuilder(tester().getMetadata()).setName(QualifiedName.of("count")).build(), ImmutableList.of()).step(AggregationNode.Step.SINGLE).globalGrouping().source(planBuilder.enforceSingleRow(planBuilder.tableScan((List<Symbol>) ImmutableList.of(), (Map<Symbol, ColumnHandle>) ImmutableMap.of())));
            });
        }).matches(PlanMatchPattern.values((Map<String, Integer>) ImmutableMap.of("count_1", 0)));
    }

    @Test
    public void testDoesNotFireOnNestedCountAggregateWithNonEmptyGroupBy() {
        tester().assertThat(new PruneCountAggregationOverScalar(tester().getMetadata())).on(planBuilder -> {
            return planBuilder.aggregation(aggregationBuilder -> {
                aggregationBuilder.addAggregation(planBuilder.symbol("count_1", BigintType.BIGINT), new FunctionCallBuilder(tester().getMetadata()).setName(QualifiedName.of("count")).build(), ImmutableList.of()).step(AggregationNode.Step.SINGLE).globalGrouping().source(planBuilder.aggregation(aggregationBuilder -> {
                    aggregationBuilder.source(planBuilder.tableScan((List<Symbol>) ImmutableList.of(), (Map<Symbol, ColumnHandle>) ImmutableMap.of())).groupingSets(AggregationNode.singleGroupingSet(ImmutableList.of(planBuilder.symbol("orderkey"))));
                }));
            });
        }).doesNotFire();
    }

    @Test
    public void testDoesNotFireOnNestedNonCountAggregate() {
        tester().assertThat(new PruneCountAggregationOverScalar(tester().getMetadata())).on(planBuilder -> {
            Symbol symbol = planBuilder.symbol("total_price", DoubleType.DOUBLE);
            AggregationNode aggregation = planBuilder.aggregation(aggregationBuilder -> {
                aggregationBuilder.addAggregation(symbol, new FunctionCallBuilder(tester().getMetadata()).setName(QualifiedName.of("sum")).addArgument(DoubleType.DOUBLE, new SymbolReference("totalprice")).build(), ImmutableList.of(DoubleType.DOUBLE)).globalGrouping().source(planBuilder.project(Assignments.of(symbol, symbol.toSymbolReference()), planBuilder.tableScan(new TableHandle(new CatalogName(RuleTester.CATALOG_ID), new TpchTableHandle("orders", 0.01d), TpchTransactionHandle.INSTANCE, Optional.empty()), ImmutableList.of(symbol), ImmutableMap.of(symbol, new TpchColumnHandle(symbol.getName(), DoubleType.DOUBLE)))));
            });
            return planBuilder.aggregation(aggregationBuilder2 -> {
                aggregationBuilder2.addAggregation(planBuilder.symbol("sum_outer", DoubleType.DOUBLE), new FunctionCallBuilder(tester().getMetadata()).setName(QualifiedName.of("sum")).addArgument(DoubleType.DOUBLE, new SymbolReference("sum_inner")).build(), ImmutableList.of(DoubleType.DOUBLE)).globalGrouping().source(aggregation);
            });
        }).doesNotFire();
    }
}
