package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.spi.Plugin;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.tree.Expression;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.testng.annotations.Test;

/**
 * Rule tests for {@link PushAggregationThroughOuterJoin}: plans the rule is expected to rewrite,
 * and plans on which it is expected not to fire.
 */
public class TestPushAggregationThroughOuterJoin extends BaseRuleTest {
    /** Creates the test, passing an empty plugin array to the base rule-test class. */
    public TestPushAggregationThroughOuterJoin() {
        super(new Plugin[] {});
    }

    /**
     * Input: {@code avg(COL2)} grouped by {@code COL1} on top of a LEFT join between a one-row
     * values node ({@code COL1}) and a values node exposing {@code COL2}.
     *
     * <p>Expected: a projection of {@code COL1} and {@code coalesce(AVG, AVG_NULL)} over an INNER
     * join without criteria. Its left input is the LEFT join with the aggregation moved onto the
     * {@code COL2} side (grouped by {@code COL2}); its right input is a global {@code avg} over a
     * {@code null_literal} values node.
     */
    @Test
    public void testPushesAggregationThroughLeftJoin() {
        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(PlanBuilder.expressions("10"))),
                                p.values(p.symbol("COL2")),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL2"))),
                                ImmutableList.of(p.symbol("COL1")),
                                ImmutableList.of(p.symbol("COL2")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(p.symbol("AVG", DoubleType.DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DoubleType.DOUBLE))
                        .singleGroupingSet(p.symbol("COL1"))))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "COL1", PlanMatchPattern.expression("COL1"),
                                "COALESCE", PlanMatchPattern.expression("coalesce(AVG, AVG_NULL)")),
                        PlanMatchPattern.join(
                                JoinNode.Type.INNER,
                                ImmutableList.of(),
                                PlanMatchPattern.join(
                                        JoinNode.Type.LEFT,
                                        ImmutableList.of(PlanMatchPattern.equiJoinClause("COL1", "COL2")),
                                        PlanMatchPattern.values(ImmutableMap.of("COL1", 0)),
                                        PlanMatchPattern.aggregation(
                                                PlanMatchPattern.singleGroupingSet("COL2"),
                                                ImmutableMap.of(Optional.of("AVG"), PlanMatchPattern.functionCall("avg", ImmutableList.of("COL2"))),
                                                Optional.empty(),
                                                AggregationNode.Step.SINGLE,
                                                PlanMatchPattern.values(ImmutableMap.of("COL2", 0)))),
                                PlanMatchPattern.aggregation(
                                        PlanMatchPattern.globalAggregation(),
                                        ImmutableMap.of(Optional.of("AVG_NULL"), PlanMatchPattern.functionCall("avg", ImmutableList.of("null_literal"))),
                                        Optional.empty(),
                                        AggregationNode.Step.SINGLE,
                                        PlanMatchPattern.values(ImmutableMap.of("null_literal", 0))))));
    }

    /**
     * Mirror of the LEFT join case: {@code avg(COL2)} grouped by {@code COL1} on top of a RIGHT
     * join whose left input exposes {@code COL2} and whose right input is a one-row values node
     * ({@code COL1}).
     *
     * <p>Expected: a projection of {@code coalesce(AVG, AVG_NULL)} and {@code COL1} over an INNER
     * join without criteria. Its left input is the RIGHT join with the aggregation on its left
     * ({@code COL2}) input; its right input is a global {@code avg} over a {@code null_literal}
     * values node.
     */
    @Test
    public void testPushesAggregationThroughRightJoin() {
        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.RIGHT,
                                p.values(p.symbol("COL2")),
                                p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(PlanBuilder.expressions("10"))),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL2"), p.symbol("COL1"))),
                                ImmutableList.of(p.symbol("COL2")),
                                ImmutableList.of(p.symbol("COL1")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(p.symbol("AVG", DoubleType.DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DoubleType.DOUBLE))
                        .singleGroupingSet(p.symbol("COL1"))))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "COALESCE", PlanMatchPattern.expression("coalesce(AVG, AVG_NULL)"),
                                "COL1", PlanMatchPattern.expression("COL1")),
                        PlanMatchPattern.join(
                                JoinNode.Type.INNER,
                                ImmutableList.of(),
                                PlanMatchPattern.join(
                                        JoinNode.Type.RIGHT,
                                        ImmutableList.of(PlanMatchPattern.equiJoinClause("COL2", "COL1")),
                                        PlanMatchPattern.aggregation(
                                                PlanMatchPattern.singleGroupingSet("COL2"),
                                                ImmutableMap.of(Optional.of("AVG"), PlanMatchPattern.functionCall("avg", ImmutableList.of("COL2"))),
                                                Optional.empty(),
                                                AggregationNode.Step.SINGLE,
                                                PlanMatchPattern.values(ImmutableMap.of("COL2", 0))),
                                        PlanMatchPattern.values(ImmutableMap.of("COL1", 0))),
                                PlanMatchPattern.aggregation(
                                        PlanMatchPattern.globalAggregation(),
                                        ImmutableMap.of(Optional.of("AVG_NULL"), PlanMatchPattern.functionCall("avg", ImmutableList.of("null_literal"))),
                                        Optional.empty(),
                                        AggregationNode.Step.SINGLE,
                                        PlanMatchPattern.values(ImmutableMap.of("null_literal", 0))))));
    }

    /**
     * Input: {@code avg(COL2)} masked by {@code MASK} and grouped by {@code COL1}, on top of a
     * LEFT join whose right input exposes {@code COL2} and {@code MASK}.
     *
     * <p>Expected: the same shape as the unmasked LEFT join case, with the pushed-down aggregation
     * carrying mask {@code MASK} and the global aggregation carrying mask {@code MASK_NULL} over a
     * values node exposing {@code null_literal} and {@code MASK_NULL}.
     */
    @Test
    public void testPushesAggregationWithMask() {
        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(PlanBuilder.expressions("10"))),
                                p.values(p.symbol("COL2"), p.symbol("MASK")),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL2"))),
                                ImmutableList.of(p.symbol("COL1")),
                                ImmutableList.of(p.symbol("COL2"), p.symbol("MASK")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        // No explicit cast on the input-type list: as a plain method argument it is
                        // target-typed to the parameter type, whereas a cast from a list of
                        // DoubleType to List<Type> is rejected by the compiler.
                        .addAggregation(
                                p.symbol("AVG", DoubleType.DOUBLE),
                                PlanBuilder.expression("avg(COL2)"),
                                ImmutableList.of(DoubleType.DOUBLE),
                                p.symbol("MASK"))
                        .singleGroupingSet(p.symbol("COL1"))))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "COL1", PlanMatchPattern.expression("COL1"),
                                "COALESCE", PlanMatchPattern.expression("coalesce(AVG, AVG_NULL)")),
                        PlanMatchPattern.join(
                                JoinNode.Type.INNER,
                                ImmutableList.of(),
                                PlanMatchPattern.join(
                                        JoinNode.Type.LEFT,
                                        ImmutableList.of(PlanMatchPattern.equiJoinClause("COL1", "COL2")),
                                        PlanMatchPattern.values(ImmutableMap.of("COL1", 0)),
                                        PlanMatchPattern.aggregation(
                                                PlanMatchPattern.singleGroupingSet("COL2"),
                                                ImmutableMap.of(Optional.of("AVG"), PlanMatchPattern.functionCall("avg", ImmutableList.of("COL2"))),
                                                ImmutableList.of(),
                                                ImmutableList.of("MASK"),
                                                Optional.empty(),
                                                AggregationNode.Step.SINGLE,
                                                PlanMatchPattern.values(ImmutableMap.of("COL2", 0, "MASK", 1)))),
                                PlanMatchPattern.aggregation(
                                        PlanMatchPattern.globalAggregation(),
                                        ImmutableMap.of(Optional.of("AVG_NULL"), PlanMatchPattern.functionCall("avg", ImmutableList.of("null_literal"))),
                                        ImmutableList.of(),
                                        ImmutableList.of("MASK_NULL"),
                                        Optional.empty(),
                                        AggregationNode.Step.SINGLE,
                                        PlanMatchPattern.values(ImmutableMap.of("null_literal", 0, "MASK_NULL", 1))))));
    }

    /**
     * Input: {@code count(*)} (no input types) grouped by {@code COL1} on top of a LEFT join
     * between a one-row values node ({@code COL1}) and a values node exposing {@code COL2}.
     *
     * <p>Expected: a projection of {@code COL1} and {@code coalesce(COUNT, COUNT_NULL)} over an
     * INNER join without criteria. Its left input is the LEFT join with an argument-less
     * {@code count} pushed onto the {@code COL2} side; its right input is a global argument-less
     * {@code count} over a {@code null_literal} values node.
     */
    @Test
    public void testPushCountAllAggregation() {
        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(PlanBuilder.expressions("10"))),
                                p.values(p.symbol("COL2")),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL2"))),
                                ImmutableList.of(p.symbol("COL1")),
                                ImmutableList.of(p.symbol("COL2")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(p.symbol("COUNT"), PlanBuilder.expression("count(*)"), ImmutableList.of())
                        .singleGroupingSet(p.symbol("COL1"))))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "COL1", PlanMatchPattern.expression("COL1"),
                                "COALESCE", PlanMatchPattern.expression("coalesce(COUNT, COUNT_NULL)")),
                        PlanMatchPattern.join(
                                JoinNode.Type.INNER,
                                ImmutableList.of(),
                                PlanMatchPattern.join(
                                        JoinNode.Type.LEFT,
                                        ImmutableList.of(PlanMatchPattern.equiJoinClause("COL1", "COL2")),
                                        PlanMatchPattern.values(ImmutableMap.of("COL1", 0)),
                                        PlanMatchPattern.aggregation(
                                                PlanMatchPattern.singleGroupingSet("COL2"),
                                                ImmutableMap.of(Optional.of("COUNT"), PlanMatchPattern.functionCall("count", ImmutableList.of())),
                                                Optional.empty(),
                                                AggregationNode.Step.SINGLE,
                                                PlanMatchPattern.values(ImmutableMap.of("COL2", 0)))),
                                PlanMatchPattern.aggregation(
                                        PlanMatchPattern.globalAggregation(),
                                        ImmutableMap.of(Optional.of("COUNT_NULL"), PlanMatchPattern.functionCall("count", ImmutableList.of())),
                                        Optional.empty(),
                                        AggregationNode.Step.SINGLE,
                                        PlanMatchPattern.values(ImmutableMap.of("null_literal", 0))))));
    }

    /**
     * The aggregation ({@code count(*)}) is built with
     * {@code AggregationNode.groupingSets(..., 2, ...)} over {@code COL1} and {@code COL2} instead
     * of a single grouping set; the rule is expected not to fire on it.
     */
    @Test
    public void testDoesNotFireWhenMultipleGroupingSets() {
        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.values(
                                        ImmutableList.of(p.symbol("COL1"), p.symbol("COL2")),
                                        ImmutableList.of(PlanBuilder.expressions("1", "2"))),
                                p.values(
                                        ImmutableList.of(p.symbol("COL3")),
                                        ImmutableList.of(PlanBuilder.expressions("1"))),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL3"))),
                                ImmutableList.of(p.symbol("COL1")),
                                ImmutableList.of(p.symbol("COL3")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(p.symbol("COUNT"), PlanBuilder.expression("count(*)"), ImmutableList.of())
                        .groupingSets(AggregationNode.groupingSets(
                                ImmutableList.of(p.symbol("COL1"), p.symbol("COL2")),
                                2,
                                ImmutableSet.of()))))
                .doesNotFire();
    }

    /**
     * Two plans on which the rule is expected not to fire.
     *
     * <ol>
     *   <li>The left input of the LEFT join is a values node with two rows ({@code 10} and
     *       {@code 11}) for {@code COL1}.
     *   <li>The left input is a projection of {@code COL1} over an aggregation grouped by
     *       {@code COL1} and {@code unused}, fed by rows {@code (10, 1)} and {@code (10, 2)}.
     * </ol>
     *
     * <p>NOTE(review): judging by the test name, both inputs stand for an outer side whose rows are
     * not known to be distinct — confirm against the rule implementation.
     */
    @Test
    public void testDoesNotFireWhenNotDistinct() {
        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.values(
                                        ImmutableList.of(p.symbol("COL1")),
                                        ImmutableList.of(PlanBuilder.expressions("10"), PlanBuilder.expressions("11"))),
                                p.values(new Symbol("COL2")),
                                ImmutableList.of(new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))),
                                ImmutableList.of(p.symbol("COL1")),
                                ImmutableList.of(p.symbol("COL2")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(new Symbol("AVG"), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DoubleType.DOUBLE))
                        .singleGroupingSet(new Symbol("COL1"))))
                .doesNotFire();

        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(p -> p.aggregation(outerAggregation -> outerAggregation
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.project(
                                        Assignments.builder().putIdentity(p.symbol("COL1", BigintType.BIGINT)).build(),
                                        // The nested builder parameter needs its own name: a lambda
                                        // parameter may not redeclare one from an enclosing lambda.
                                        p.aggregation(innerAggregation -> innerAggregation
                                                .singleGroupingSet(p.symbol("COL1"), p.symbol("unused"))
                                                .source(p.values(
                                                        ImmutableList.of(p.symbol("COL1"), p.symbol("unused")),
                                                        ImmutableList.of(
                                                                PlanBuilder.expressions("10", "1"),
                                                                PlanBuilder.expressions("10", "2")))))),
                                p.values(p.symbol("COL2")),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL2"))),
                                ImmutableList.of(p.symbol("COL1")),
                                ImmutableList.of(p.symbol("COL2")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(p.symbol("AVG", DoubleType.DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DoubleType.DOUBLE))
                        .singleGroupingSet(p.symbol("COL1"))))
                .doesNotFire();
    }

    /**
     * The aggregation groups by {@code COL1} (from the left input of the LEFT join) and
     * {@code COL3} (a column of the right input's values node); the rule is expected not to fire.
     */
    @Test
    public void testDoesNotFireWhenGroupingOnInner() {
        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(PlanBuilder.expressions("10"))),
                                p.values(new Symbol("COL2"), new Symbol("COL3")),
                                ImmutableList.of(new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))),
                                ImmutableList.of(p.symbol("COL1")),
                                ImmutableList.of(p.symbol("COL2")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(new Symbol("AVG"), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DoubleType.DOUBLE))
                        .singleGroupingSet(new Symbol("COL1"), new Symbol("COL3"))))
                .doesNotFire();
    }

    /**
     * The aggregation is {@code sum(COL1)}, where {@code COL1} comes from the left input of the
     * LEFT join rather than from the right input ({@code COL2}); the rule is expected not to fire.
     */
    @Test
    public void testDoesNotFireWhenAggregationDoesNotHaveSymbols() {
        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(PlanBuilder.expressions("10"))),
                                p.values(ImmutableList.of(p.symbol("COL2")), ImmutableList.of(PlanBuilder.expressions("20"))),
                                ImmutableList.of(new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))),
                                ImmutableList.of(p.symbol("COL1")),
                                ImmutableList.of(p.symbol("COL2")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(new Symbol("SUM"), PlanBuilder.expression("sum(COL1)"), ImmutableList.of(DoubleType.DOUBLE))
                        .singleGroupingSet(new Symbol("COL1"))))
                .doesNotFire();
    }

    /**
     * Two plans on which the rule is expected not to fire. Both use a LEFT join of a values node
     * for {@code COL1} with a values node for {@code COL2} and {@code COL3}.
     *
     * <ol>
     *   <li>A single aggregation {@code min_by(COL2, COL1)}, which mixes a right-side column with
     *       the left-side {@code COL1}.
     *   <li>Three aggregations: {@code sum(COL2)} and {@code min_by(COL2, COL3)}, which use only
     *       right-side columns, plus {@code max_by(COL2, COL1)}, which also uses {@code COL1}.
     * </ol>
     */
    @Test
    public void testDoesNotFireWhenAggregationOnMultipleSymbolsDoesNotHaveSomeSymbols() {
        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(PlanBuilder.expressions("10"))),
                                p.values(
                                        ImmutableList.of(p.symbol("COL2"), p.symbol("COL3")),
                                        ImmutableList.of(PlanBuilder.expressions("20", "30"))),
                                ImmutableList.of(new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))),
                                ImmutableList.of(new Symbol("COL1")),
                                ImmutableList.of(new Symbol("COL2")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(
                                new Symbol("MIN_BY"),
                                PlanBuilder.expression("min_by(COL2, COL1)"),
                                ImmutableList.of(DoubleType.DOUBLE, DoubleType.DOUBLE))
                        .singleGroupingSet(new Symbol("COL1"))))
                .doesNotFire();

        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(PlanBuilder.expressions("10"))),
                                p.values(
                                        ImmutableList.of(p.symbol("COL2"), p.symbol("COL3")),
                                        ImmutableList.of(PlanBuilder.expressions("20", "30"))),
                                ImmutableList.of(new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))),
                                ImmutableList.of(new Symbol("COL1")),
                                ImmutableList.of(new Symbol("COL2")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(
                                new Symbol("SUM"),
                                PlanBuilder.expression("sum(COL2)"),
                                ImmutableList.of(DoubleType.DOUBLE))
                        .addAggregation(
                                new Symbol("MIN_BY"),
                                PlanBuilder.expression("min_by(COL2, COL3)"),
                                ImmutableList.of(DoubleType.DOUBLE, DoubleType.DOUBLE))
                        .addAggregation(
                                new Symbol("MAX_BY"),
                                PlanBuilder.expression("max_by(COL2, COL1)"),
                                ImmutableList.of(DoubleType.DOUBLE, DoubleType.DOUBLE))
                        .singleGroupingSet(new Symbol("COL1"))))
                .doesNotFire();
    }
}
