package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.spi.Plugin;
import io.trino.spi.connector.SortOrder;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.ExchangeNode;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import org.testng.annotations.Test;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TestPruneExchangeColumns.class */
/**
 * Rule tests for {@link PruneExchangeColumns}.
 *
 * <p>Each test builds a project node on top of an exchange node and checks either that the
 * rule leaves the plan alone ({@code doesNotFire()}) or that the rewritten plan matches an
 * expected pattern.
 */
public class TestPruneExchangeColumns extends BaseRuleTest {
    public TestPruneExchangeColumns() {
        // No plugins are registered for these tests.
        super(new Plugin[0]);
    }

    /**
     * The exchange outputs only {@code a}, and the parent projection references {@code a}:
     * the rule is expected not to fire.
     */
    @Test
    public void testDoNotPruneReferencedOutputSymbol() {
        tester().assertThat(new PruneExchangeColumns())
                .on(planBuilder -> {
                    Symbol a = planBuilder.symbol("a");
                    Symbol b = planBuilder.symbol("b");
                    return planBuilder.project(
                            Assignments.identity(a),
                            planBuilder.exchange(exchangeBuilder -> exchangeBuilder
                                    .addSource(planBuilder.values(b))
                                    .addInputsSet(b)
                                    .singleDistributionPartitioningScheme(a)));
                })
                .doesNotFire();
    }

    /**
     * The parent projection references nothing, but {@code a} is passed to the fixed hash
     * partitioning scheme both as an output symbol and as a partitioning symbol: the rule is
     * expected not to fire.
     */
    @Test
    public void testDoNotPrunePartitioningSymbol() {
        tester().assertThat(new PruneExchangeColumns())
                .on(planBuilder -> {
                    Symbol a = planBuilder.symbol("a");
                    Symbol b = planBuilder.symbol("b");
                    return planBuilder.project(
                            Assignments.of(),
                            planBuilder.exchange(exchangeBuilder -> exchangeBuilder
                                    .addSource(planBuilder.values(b))
                                    .addInputsSet(b)
                                    .fixedHashDistributionParitioningScheme(ImmutableList.of(a), ImmutableList.of(a))));
                })
                .doesNotFire();
    }

    /**
     * The parent projection references only {@code a}; {@code h} is an exchange output that is
     * also passed as the trailing (hash) argument of the fixed hash partitioning scheme: the
     * rule is expected not to fire.
     */
    @Test
    public void testDoNotPruneHashSymbol() {
        tester().assertThat(new PruneExchangeColumns())
                .on(planBuilder -> {
                    Symbol a = planBuilder.symbol("a");
                    Symbol h = planBuilder.symbol("h");
                    Symbol b = planBuilder.symbol("b");
                    Symbol h1 = planBuilder.symbol("h_1");
                    return planBuilder.project(
                            Assignments.identity(a),
                            planBuilder.exchange(exchangeBuilder -> exchangeBuilder
                                    .addSource(planBuilder.values(b, h1))
                                    .addInputsSet(b, h1)
                                    .fixedHashDistributionParitioningScheme(ImmutableList.of(a, h), ImmutableList.of(a), h)));
                })
                .doesNotFire();
    }

    /**
     * The parent projection references nothing, but {@code a} is used by the exchange's
     * ordering scheme: the rule is expected not to fire.
     */
    @Test
    public void testDoNotPruneOrderingSymbol() {
        tester().assertThat(new PruneExchangeColumns())
                .on(planBuilder -> {
                    Symbol a = planBuilder.symbol("a");
                    Symbol b = planBuilder.symbol("b");
                    return planBuilder.project(
                            Assignments.of(),
                            planBuilder.exchange(exchangeBuilder -> exchangeBuilder
                                    .addSource(planBuilder.values(b))
                                    .addInputsSet(b)
                                    .singleDistributionPartitioningScheme(a)
                                    .orderingScheme(new OrderingScheme(
                                            ImmutableList.of(a),
                                            ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)))));
                })
                .doesNotFire();
    }

    /**
     * The exchange outputs {@code a} and {@code b} but the parent projection references only
     * {@code a}. The expected plan has an exchange with exactly the output {@code a} and the
     * single input list {@code [a]}, while the values source still produces both columns.
     */
    @Test
    public void testPruneUnreferencedSymbol() {
        tester().assertThat(new PruneExchangeColumns())
                .on(planBuilder -> {
                    Symbol a = planBuilder.symbol("a");
                    Symbol b = planBuilder.symbol("b");
                    return planBuilder.project(
                            Assignments.identity(a),
                            planBuilder.exchange(exchangeBuilder -> exchangeBuilder
                                    .addSource(planBuilder.values(a, b))
                                    .addInputsSet(a, b)
                                    .singleDistributionPartitioningScheme(a, b)));
                })
                .matches(
                        PlanMatchPattern.project(
                                ImmutableMap.of("a", PlanMatchPattern.expression("a")),
                                // Explicit type arguments (rather than casts) pin the argument types:
                                // casting e.g. ImmutableList<Object> to List<Ordering> is not legal Java.
                                PlanMatchPattern.exchange(
                                        ExchangeNode.Scope.REMOTE,
                                        ExchangeNode.Type.GATHER,
                                        ImmutableList.<PlanMatchPattern.Ordering>of(),
                                        ImmutableSet.<String>of(),
                                        Optional.<List<List<String>>>of(ImmutableList.of(ImmutableList.of("a"))),
                                        PlanMatchPattern.values(ImmutableList.of("a", "b")))
                                        .withExactOutputs("a")));
    }

    /**
     * Same scenario with two sources: the exchange outputs {@code a_1} and {@code a_2}, fed by
     * {@code (b_1, b_2)} and {@code (c_1, c_2)}, and the parent projection references only
     * {@code a_1}. The expected exchange has one output column and the input lists
     * {@code [b_1]} and {@code [c_1]}; both values sources keep their two columns.
     */
    @Test
    public void testPruneUnreferencedSymbolMultipleSources() {
        tester().assertThat(new PruneExchangeColumns())
                .on(planBuilder -> {
                    Symbol a1 = planBuilder.symbol("a_1");
                    Symbol a2 = planBuilder.symbol("a_2");
                    Symbol b1 = planBuilder.symbol("b_1");
                    Symbol b2 = planBuilder.symbol("b_2");
                    Symbol c1 = planBuilder.symbol("c_1");
                    Symbol c2 = planBuilder.symbol("c_2");
                    return planBuilder.project(
                            Assignments.identity(a1),
                            planBuilder.exchange(exchangeBuilder -> exchangeBuilder
                                    .addSource(planBuilder.values(b1, b2))
                                    .addInputsSet(b1, b2)
                                    .addSource(planBuilder.values(c1, c2))
                                    .addInputsSet(c1, c2)
                                    .singleDistributionPartitioningScheme(a1, a2)));
                })
                .matches(
                        PlanMatchPattern.project(
                                // Explicit type arguments (rather than casts) pin the argument types:
                                // casting e.g. ImmutableList<Object> to List<Ordering> is not legal Java.
                                PlanMatchPattern.exchange(
                                        ExchangeNode.Scope.REMOTE,
                                        ExchangeNode.Type.GATHER,
                                        ImmutableList.<PlanMatchPattern.Ordering>of(),
                                        ImmutableSet.<String>of(),
                                        Optional.<List<List<String>>>of(ImmutableList.of(ImmutableList.of("b_1"), ImmutableList.of("c_1"))),
                                        PlanMatchPattern.values(ImmutableList.of("b_1", "b_2")),
                                        PlanMatchPattern.values(ImmutableList.of("c_1", "c_2")))
                                        .withNumberOfOutputColumns(1)));
    }
}
