package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.trino.spi.Plugin;
import io.trino.spi.type.TypeOperators;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.assertions.RowNumberSymbolMatcher;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import java.util.Optional;
import org.testng.annotations.Test;

/**
 * Rule tests for {@link PushdownFilterIntoRowNumber}.
 *
 * <p>The scenarios below check that an upper bound on a row-number symbol in a filter
 * above a row-number node is folded into the node's {@code maxRowCountPerPartition},
 * and that the rule leaves plans alone when nothing can be folded.
 */
public class TestPushdownFilterIntoRowNumber extends BaseRuleTest {
    public TestPushdownFilterIntoRowNumber() {
        // These rule tests do not require any additional plugins.
        super(new Plugin[0]);
    }

    /** Builds a fresh instance of the rule under test. */
    private PushdownFilterIntoRowNumber newRule() {
        return new PushdownFilterIntoRowNumber(tester().getMetadata(), new TypeOperators());
    }

    @Test
    public void testSourceRowNumber() {
        // No existing limit: "rn < 100" is fully absorbed as a limit of 99 and the filter is removed.
        tester().assertThat(newRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumber = p.symbol("row_number_1");
                    return p.filter(
                            PlanBuilder.expression("row_number_1 < cast(100 as bigint)"),
                            p.rowNumber(ImmutableList.of(a), Optional.empty(), rowNumber, p.values(a)));
                })
                .matches(PlanMatchPattern.rowNumber(
                        spec -> spec.maxRowCountPerPartition(Optional.of(99)).partitionBy(ImmutableList.of("a")),
                        PlanMatchPattern.values("a")));

        // An existing limit of 10 is tighter than 99, so the expected limit stays at 10.
        tester().assertThat(newRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumber = p.symbol("row_number_1");
                    return p.filter(
                            PlanBuilder.expression("row_number_1 < cast(100 as bigint)"),
                            p.rowNumber(ImmutableList.of(a), Optional.of(10), rowNumber, p.values(a)));
                })
                .matches(PlanMatchPattern.rowNumber(
                        spec -> spec.maxRowCountPerPartition(Optional.of(10)).partitionBy(ImmutableList.of("a")),
                        PlanMatchPattern.values("a")));

        // A lower bound is present too, so the filter is retained while the upper bound tightens the limit to 4.
        tester().assertThat(newRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumber = p.symbol("row_number_1");
                    return p.filter(
                            PlanBuilder.expression("cast(3 as bigint) < row_number_1 and row_number_1 < cast(5 as bigint)"),
                            p.rowNumber(ImmutableList.of(a), Optional.of(10), rowNumber, p.values(a)));
                })
                .matches(PlanMatchPattern.filter(
                        "cast(3 as bigint) < row_number_1 and row_number_1 < cast(5 as bigint)",
                        PlanMatchPattern.rowNumber(
                                        spec -> spec.maxRowCountPerPartition(Optional.of(4)).partitionBy(ImmutableList.of("a")),
                                        PlanMatchPattern.values("a"))
                                .withAlias("row_number_1", new RowNumberSymbolMatcher())));

        // The conjunct on a non-row-number column survives in the filter; the row-number bound becomes limit 4.
        tester().assertThat(newRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumber = p.symbol("row_number_1");
                    return p.filter(
                            PlanBuilder.expression("row_number_1 < cast(5 as bigint) and a = 1"),
                            p.rowNumber(ImmutableList.of(a), Optional.of(10), rowNumber, p.values(a)));
                })
                .matches(PlanMatchPattern.filter(
                        "a = 1",
                        PlanMatchPattern.rowNumber(
                                        spec -> spec.maxRowCountPerPartition(Optional.of(4)).partitionBy(ImmutableList.of("a")),
                                        PlanMatchPattern.values("a"))
                                .withAlias("row_number_1", new RowNumberSymbolMatcher())));
    }

    @Test
    public void testNoOutputsThroughRowNumber() {
        // No row number can satisfy "rn < -100", so the subtree is expected to be replaced by a Values node
        // exposing a and row_number_1 (presumably with no rows).
        tester().assertThat(newRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumber = p.symbol("row_number_1");
                    return p.filter(
                            PlanBuilder.expression("row_number_1 < cast(-100 as bigint)"),
                            p.rowNumber(ImmutableList.of(a), Optional.empty(), rowNumber, p.values(a)));
                })
                .matches(PlanMatchPattern.values("a", "row_number_1"));
    }

    @Test
    public void testDoNotFire() {
        // The filter constrains a different column, not the row number.
        tester().assertThat(newRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumber = p.symbol("row_number_1");
                    Symbol notRowNumber = p.symbol("not_row_number");
                    return p.filter(
                            PlanBuilder.expression("not_row_number < cast(100 as bigint)"),
                            p.rowNumber(ImmutableList.of(a), Optional.empty(), rowNumber, p.values(a, notRowNumber)));
                })
                .doesNotFire();

        // Only a lower bound on the row number: there is no limit to derive.
        tester().assertThat(newRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumber = p.symbol("row_number_1");
                    return p.filter(
                            PlanBuilder.expression("row_number_1 > cast(100 as bigint)"),
                            p.rowNumber(ImmutableList.of(a), Optional.empty(), rowNumber, p.values(a)));
                })
                .doesNotFire();

        // The upper bound implies a limit of 4, which the node already has, so there is nothing to change.
        tester().assertThat(newRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumber = p.symbol("row_number_1");
                    return p.filter(
                            PlanBuilder.expression("cast(3 as bigint) < row_number_1 and row_number_1 < cast(5 as bigint)"),
                            p.rowNumber(ImmutableList.of(a), Optional.of(4), rowNumber, p.values(a)));
                })
                .doesNotFire();
    }
}
