package io.trino.cost;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.spi.type.BigintType;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.AggregationNode;
import java.util.function.Consumer;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/cost/TestAggregationStatsRule.class */
/**
 * Tests for the statistics the stats calculator derives for {@link AggregationNode} plans.
 *
 * <p>Every scenario aggregates over a {@code values} source that produces the BIGINT symbols
 * {@code x}, {@code y} and {@code z}, injects hand-written source statistics, and checks the
 * estimate produced for the aggregation.
 */
public class TestAggregationStatsRule
        extends BaseStatsCalculatorTest
{
    @Test
    public void testAggregationWhenAllStatisticsAreKnown()
    {
        // Expectations used when z carries a distinct-values count: the output row count and the
        // statistics of both grouping keys are expected to be fully known.
        Consumer<PlanNodeStatsAssertion> fullyEstimated = check -> check
                .outputRowsCount(15.0)
                .symbolStats("z", z -> z
                        .lowValue(10.0)
                        .highValue(15.0)
                        .distinctValuesCount(4.0)
                        .nullsFraction(0.2))
                .symbolStats("y", y -> y
                        .lowValue(0.0)
                        .highValue(3.0)
                        .distinctValuesCount(3.0)
                        .nullsFraction(0.0));

        // The same expectations hold whether or not z's nulls fraction is set explicitly.
        SymbolStatsEstimate zWithNullsFraction = SymbolStatsEstimate.builder()
                .setLowValue(10.0)
                .setHighValue(15.0)
                .setDistinctValuesCount(4.0)
                .setNullsFraction(0.1)
                .build();
        testAggregation(zWithNullsFraction).check(fullyEstimated);

        SymbolStatsEstimate zNullsFractionUnset = SymbolStatsEstimate.builder()
                .setLowValue(10.0)
                .setHighValue(15.0)
                .setDistinctValuesCount(4.0)
                .build();
        testAggregation(zNullsFractionUnset).check(fullyEstimated);

        // Without a distinct-values count for z, the row count and z's NDV / nulls fraction are
        // expected to be unknown, while z's range and all of y's statistics remain known.
        Consumer<PlanNodeStatsAssertion> partiallyEstimated = check -> check
                .outputRowsCountUnknown()
                .symbolStats("z", z -> z
                        .lowValue(10.0)
                        .highValue(15.0)
                        .distinctValuesCountUnknown()
                        .nullsFractionUnknown())
                .symbolStats("y", y -> y
                        .lowValue(0.0)
                        .highValue(3.0)
                        .distinctValuesCount(3.0)
                        .nullsFraction(0.0));

        SymbolStatsEstimate zRangeAndNullsOnly = SymbolStatsEstimate.builder()
                .setLowValue(10.0)
                .setHighValue(15.0)
                .setNullsFraction(0.1)
                .build();
        testAggregation(zRangeAndNullsOnly).check(partiallyEstimated);

        SymbolStatsEstimate zRangeOnly = SymbolStatsEstimate.builder()
                .setLowValue(10.0)
                .setHighValue(15.0)
                .build();
        testAggregation(zRangeOnly).check(partiallyEstimated);
    }

    /**
     * Builds an aggregation grouped on {@code (y, z)} that computes {@code sum(x)},
     * {@code count()} and {@code count(x)}, runs it against fixed statistics for {@code x}
     * and {@code y} plus the supplied statistics for {@code z}, and checks that no statistics
     * are produced for the aggregate outputs or for {@code x}.
     *
     * @param symbolStatsEstimate source statistics to use for the grouping key {@code z}
     * @return the assertion, so callers can add scenario-specific checks
     */
    private StatsCalculatorAssertion testAggregation(SymbolStatsEstimate symbolStatsEstimate)
    {
        PlanNodeStatsEstimate sourceStats = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(100.0)
                .addSymbolStatistics(new Symbol("x"), SymbolStatsEstimate.builder()
                        .setLowValue(1.0)
                        .setHighValue(10.0)
                        .setDistinctValuesCount(5.0)
                        .setNullsFraction(0.3)
                        .build())
                .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder()
                        .setLowValue(0.0)
                        .setHighValue(3.0)
                        .setDistinctValuesCount(3.0)
                        .setNullsFraction(0.0)
                        .build())
                .addSymbolStatistics(new Symbol("z"), symbolStatsEstimate)
                .build();

        return tester()
                .assertStatsFor(pb -> pb.aggregation(ab -> ab
                        .addAggregation(
                                pb.symbol("sum", BigintType.BIGINT),
                                PlanBuilder.expression("sum(x)"),
                                ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(
                                pb.symbol("count", BigintType.BIGINT),
                                PlanBuilder.expression("count()"),
                                ImmutableList.of())
                        .addAggregation(
                                pb.symbol("count_on_x", BigintType.BIGINT),
                                PlanBuilder.expression("count(x)"),
                                ImmutableList.of(BigintType.BIGINT))
                        .singleGroupingSet(pb.symbol("y", BigintType.BIGINT), pb.symbol("z", BigintType.BIGINT))
                        .source(pb.values(
                                pb.symbol("x", BigintType.BIGINT),
                                pb.symbol("y", BigintType.BIGINT),
                                pb.symbol("z", BigintType.BIGINT)))))
                .withSourceStats(sourceStats)
                // Aggregate outputs carry no statistics; x is not an output of the aggregation
                // (it is neither a grouping key nor an aggregate), so it must be unknown too.
                .check(check -> check
                        .symbolStats("sum", stats -> stats
                                .lowValueUnknown()
                                .highValueUnknown()
                                .distinctValuesCountUnknown()
                                .nullsFractionUnknown())
                        .symbolStats("count", stats -> stats
                                .lowValueUnknown()
                                .highValueUnknown()
                                .distinctValuesCountUnknown()
                                .nullsFractionUnknown())
                        .symbolStats("count_on_x", stats -> stats
                                .lowValueUnknown()
                                .highValueUnknown()
                                .distinctValuesCountUnknown()
                                .nullsFractionUnknown())
                        .symbolStats("x", stats -> stats
                                .lowValueUnknown()
                                .highValueUnknown()
                                .distinctValuesCountUnknown()
                                .nullsFractionUnknown()));
    }

    @Test
    public void testAggregationStatsCappedToInputRows()
    {
        // y and z each have 50 distinct values (50 * 50 = 2500 combinations), but the source only
        // has 100 rows, so the expected estimate is the input row count.
        PlanNodeStatsEstimate sourceStats = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(100.0)
                .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder().setDistinctValuesCount(50.0).build())
                .addSymbolStatistics(new Symbol("z"), SymbolStatsEstimate.builder().setDistinctValuesCount(50.0).build())
                .build();

        tester()
                .assertStatsFor(pb -> pb.aggregation(ab -> ab
                        .addAggregation(
                                pb.symbol("count_on_x", BigintType.BIGINT),
                                PlanBuilder.expression("count(x)"),
                                ImmutableList.of(BigintType.BIGINT))
                        .singleGroupingSet(pb.symbol("y", BigintType.BIGINT), pb.symbol("z", BigintType.BIGINT))
                        .source(pb.values(
                                pb.symbol("x", BigintType.BIGINT),
                                pb.symbol("y", BigintType.BIGINT),
                                pb.symbol("z", BigintType.BIGINT)))))
                .withSourceStats(sourceStats)
                .check(check -> check.outputRowsCount(100.0));
    }

    @Test
    public void testAggregationWithGlobalGrouping()
    {
        // A global aggregation is expected to yield exactly one row even with unknown source stats.
        tester()
                .assertStatsFor(pb -> pb.aggregation(ab -> ab
                        .addAggregation(
                                pb.symbol("count_on_x", BigintType.BIGINT),
                                PlanBuilder.expression("count(x)"),
                                ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(
                                pb.symbol("sum", BigintType.BIGINT),
                                PlanBuilder.expression("sum(x)"),
                                ImmutableList.of(BigintType.BIGINT))
                        .globalGrouping()
                        .source(pb.values(
                                pb.symbol("x", BigintType.BIGINT),
                                pb.symbol("y", BigintType.BIGINT),
                                pb.symbol("z", BigintType.BIGINT)))))
                .withSourceStats(PlanNodeStatsEstimate.unknown())
                .check(check -> check.outputRowsCount(1.0));
    }

    @Test
    public void testAggregationWithMoreGroupingSets()
    {
        // Several grouping sets over (y, z) are expected to leave the output row count unknown.
        // NOTE(review): the descriptor arguments (3, {0}) presumably mean "3 grouping sets, set 0 is
        // global" — confirm against AggregationNode.GroupingSetDescriptor.
        PlanNodeStatsEstimate sourceStats = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(100.0)
                .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder().setDistinctValuesCount(50.0).build())
                .addSymbolStatistics(new Symbol("z"), SymbolStatsEstimate.builder().setDistinctValuesCount(50.0).build())
                .build();

        tester()
                .assertStatsFor(pb -> pb.aggregation(ab -> ab
                        .addAggregation(
                                pb.symbol("count_on_x", BigintType.BIGINT),
                                PlanBuilder.expression("count(x)"),
                                ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(
                                pb.symbol("sum", BigintType.BIGINT),
                                PlanBuilder.expression("sum(x)"),
                                ImmutableList.of(BigintType.BIGINT))
                        .groupingSets(new AggregationNode.GroupingSetDescriptor(
                                ImmutableList.of(pb.symbol("y"), pb.symbol("z")),
                                3,
                                ImmutableSet.of(0)))
                        .source(pb.values(
                                pb.symbol("x", BigintType.BIGINT),
                                pb.symbol("y", BigintType.BIGINT),
                                pb.symbol("z", BigintType.BIGINT)))))
                .withSourceStats(sourceStats)
                .check(PlanNodeStatsAssertion::outputRowsCountUnknown);
    }
}
