package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.AbstractMockMetadata;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.testing.assertions.Assert;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.testng.annotations.Test;

/**
 * Tests for {@link PlanNodeSearcher#findAll()}: builds small plans with {@link PlanBuilder} and
 * checks that searching for a node type returns the ids of every matching node, in the order the
 * tests expect (a node before its sources, left source before right source).
 */
public class TestPlanNodeSearcher {
    // Shared plan builder; node ids come from the PlanNodeIdAllocator, and the anonymous
    // AbstractMockMetadata subclass only exists to supply a metadata instance.
    private static final PlanBuilder BUILDER = new PlanBuilder(new PlanNodeIdAllocator(), new AbstractMockMetadata() {
    });

    /**
     * Builds a linear chain of ten {@link ProjectNode}s on top of an empty {@link ValuesNode} and
     * checks that searching for project nodes returns all of them, starting at the root and
     * walking down towards the leaf.
     */
    @Test
    public void testFindAll() {
        ProjectNode root = BUILDER.project(Assignments.of(), BUILDER.values());
        for (int i = 1; i < 10; i++) {
            root = BUILDER.project(Assignments.of(), root);
        }

        // Expected result: walk down the chain from the root, collecting ids until the ValuesNode leaf.
        List<PlanNodeId> expected = new ArrayList<>();
        PlanNode node = root;
        while (node instanceof ProjectNode) {
            expected.add(node.getId());
            node = ((ProjectNode) node).getSource();
        }

        List<PlanNodeId> actual = PlanNodeSearcher.searchFrom(root)
                .where(ProjectNode.class::isInstance)
                .findAll()
                .stream()
                .map(PlanNode::getId)
                .collect(ImmutableList.toImmutableList());

        Assert.assertEquals(expected, actual);
    }

    /**
     * Builds a balanced tree of {@link JoinNode}s (a root join over two joins, each over two
     * joins of empty {@link ValuesNode}s) and checks that searching for join nodes returns their
     * ids in pre-order: node, then left subtree, then right subtree.
     */
    @Test
    public void testFindAllMultipleSources() {
        List<PlanNode> leafJoins = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            leafJoins.add(BUILDER.join(JoinNode.Type.INNER, BUILDER.values(), BUILDER.values()));
        }
        JoinNode root = BUILDER.join(
                JoinNode.Type.INNER,
                BUILDER.join(JoinNode.Type.INNER, leafJoins.get(0), leafJoins.get(1)),
                BUILDER.join(JoinNode.Type.INNER, leafJoins.get(2), leafJoins.get(3)));

        ImmutableList.Builder<PlanNodeId> expected = ImmutableList.builder();
        joinNodePreorder(root, expected);

        List<PlanNodeId> actual = PlanNodeSearcher.searchFrom(root)
                .where(JoinNode.class::isInstance)
                .findAll()
                .stream()
                .map(PlanNode::getId)
                .collect(ImmutableList.toImmutableList());

        Assert.assertEquals(expected.build(), actual);
    }

    /**
     * Appends to {@code builder} the ids of every {@link JoinNode} in the subtree rooted at
     * {@code planNode}, in pre-order (the node itself, then its left side, then its right side).
     *
     * @param planNode root of the subtree to visit; may only contain join and values nodes
     * @param builder receives the join node ids
     * @throws IllegalArgumentException if the subtree contains a node that is neither a
     *         {@link JoinNode} nor a {@link ValuesNode}
     */
    private static void joinNodePreorder(PlanNode planNode, ImmutableList.Builder<PlanNodeId> builder) {
        // Values nodes are the leaves of the test plans and contribute no ids.
        if (planNode instanceof ValuesNode) {
            return;
        }
        if (!(planNode instanceof JoinNode)) {
            throw new IllegalArgumentException("unsupported node type: " + planNode.getClass().getSimpleName());
        }
        JoinNode joinNode = (JoinNode) planNode;
        builder.add(joinNode.getId());
        joinNodePreorder(joinNode.getLeft(), builder);
        joinNodePreorder(joinNode.getRight(), builder);
    }
}
