package io.trino.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import io.airlift.log.Logger;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.CostComparator;
import io.trino.cost.CostProvider;
import io.trino.cost.PlanCostEstimate;
import io.trino.cost.PlanNodeStatsAndCostSummary;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsProvider;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.EqualityInference;
import io.trino.sql.planner.OptimizerConfig;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SymbolReference;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * Cost-based join reordering rule.
 *
 * <p>Matches an INNER {@link JoinNode} that has no distribution type assigned yet and whose filter
 * (if any) is deterministic. The connected tree of such joins is flattened into a {@link MultiJoinNode}.
 * Every way of splitting its sources into two sides is then enumerated recursively, memoized per set
 * of sources, and the cheapest plan according to the {@link CostComparator} is kept. If flattening pushed
 * projections through joins, the rule also plans the tree without that push-down and returns whichever
 * result is cheaper.
 *
 * <p>The rule is enabled only when the session's join reordering strategy is {@code AUTOMATIC}.
 */
public class ReorderJoins implements Rule<JoinNode> {
    private static final Logger log = Logger.get(ReorderJoins.class);

    private final Pattern<JoinNode> pattern;
    private final TypeAnalyzer typeAnalyzer;
    private final PlannerContext plannerContext;
    private final CostComparator costComparator;

    /**
     * Outcome of planning one set of sources: the chosen plan (if any) and its cost.
     *
     * <p>A plan is present exactly when the cost is neither unknown nor infinite. Unknown and infinite
     * outcomes are always represented by the shared {@link #UNKNOWN_COST_RESULT} and
     * {@link #INFINITE_COST_RESULT} instances, so callers test for them with {@code equals}.
     */
    @VisibleForTesting
    public static class JoinEnumerationResult {
        static final JoinEnumerationResult UNKNOWN_COST_RESULT = new JoinEnumerationResult(Optional.empty(), PlanCostEstimate.unknown());
        static final JoinEnumerationResult INFINITE_COST_RESULT = new JoinEnumerationResult(Optional.empty(), PlanCostEstimate.infinite());

        private final Optional<PlanNode> planNode;
        private final PlanCostEstimate cost;

        private JoinEnumerationResult(Optional<PlanNode> planNode, PlanCostEstimate cost) {
            this.planNode = Objects.requireNonNull(planNode, "planNode is null");
            this.cost = Objects.requireNonNull(cost, "cost is null");
            // Enforce a real "if and only if": a plan must be present exactly when the cost is usable
            // (no unknown components and not the infinite sentinel).
            boolean costIsKnown = !cost.hasUnknownComponents() && !cost.equals(PlanCostEstimate.infinite());
            Preconditions.checkArgument(costIsKnown == planNode.isPresent(), "planNode should be present if and only if cost is known");
        }

        /** Returns the chosen plan; empty when the cost is unknown or infinite. */
        public Optional<PlanNode> getPlanNode() {
            return this.planNode;
        }

        /** Returns the estimated cost of the chosen plan (or the unknown/infinite sentinel cost). */
        public PlanCostEstimate getCost() {
            return this.cost;
        }

        /**
         * Builds a result, collapsing unknown and infinite costs onto the shared sentinel instances
         * (the supplied plan is discarded in those cases).
         */
        static JoinEnumerationResult createJoinEnumerationResult(Optional<PlanNode> planNode, PlanCostEstimate cost) {
            if (cost.hasUnknownComponents()) {
                return UNKNOWN_COST_RESULT;
            }
            if (cost.equals(PlanCostEstimate.infinite())) {
                return INFINITE_COST_RESULT;
            }
            return new JoinEnumerationResult(planNode, cost);
        }
    }

    /**
     * Enumerates join orders for a flattened set of sources and picks the cheapest one.
     *
     * <p>Results are memoized per set of sources. Enumeration stops early, returning
     * {@link JoinEnumerationResult#UNKNOWN_COST_RESULT}, as soon as any candidate has an unknown cost.
     */
    @VisibleForTesting
    public static class JoinEnumerator {
        private final Metadata metadata;
        private final Session session;
        private final StatsProvider statsProvider;
        private final CostProvider costProvider;
        private final Ordering<JoinEnumerationResult> resultComparator;
        private final PlanNodeIdAllocator idAllocator;
        private final Expression allFilter;
        private final EqualityInference allFilterInference;
        private final Lookup lookup;
        private final Rule.Context context;
        private final Map<Set<PlanNode>, JoinEnumerationResult> memo = new HashMap<>();

        /**
         * @param metadata metadata used for determinism checks and equality inference
         * @param costComparator comparator used to rank candidate plans for the session
         * @param filter conjunction of all predicates collected while flattening the join tree
         * @param context rule context supplying session, stats, cost, id allocator and lookup
         */
        @VisibleForTesting
        JoinEnumerator(Metadata metadata, CostComparator costComparator, Expression filter, Rule.Context context) {
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.context = Objects.requireNonNull(context, "context is null");
            this.session = Objects.requireNonNull(context.getSession(), "session is null");
            this.statsProvider = Objects.requireNonNull(context.getStatsProvider(), "statsProvider is null");
            this.costProvider = Objects.requireNonNull(context.getCostProvider(), "costProvider is null");
            this.resultComparator = costComparator.forSession(this.session).onResultOf(result -> result.cost);
            this.idAllocator = Objects.requireNonNull(context.getIdAllocator(), "idAllocator is null");
            this.allFilter = Objects.requireNonNull(filter, "filter is null");
            this.allFilterInference = EqualityInference.newInstance(metadata, filter);
            this.lookup = Objects.requireNonNull(context.getLookup(), "lookup is null");
        }

        /**
         * Returns the cheapest join of {@code sources} that produces {@code outputSymbols}.
         * Requires at least two sources. Memoizes the result keyed by the set of sources.
         */
        private JoinEnumerationResult chooseJoinOrder(LinkedHashSet<PlanNode> sources, List<Symbol> outputSymbols) {
            this.context.checkTimeoutNotExhausted();

            Set<PlanNode> memoKey = ImmutableSet.copyOf(sources);
            JoinEnumerationResult bestResult = this.memo.get(memoKey);
            if (bestResult == null) {
                Preconditions.checkState(sources.size() > 1, "sources size is less than or equal to one");

                ImmutableList.Builder<JoinEnumerationResult> candidates = ImmutableList.builder();
                for (Set<Integer> partitioning : generatePartitions(sources.size())) {
                    JoinEnumerationResult candidate = createJoinAccordingToPartitioning(sources, outputSymbols, partitioning);
                    if (candidate.equals(JoinEnumerationResult.UNKNOWN_COST_RESULT)) {
                        // One unknown cost makes the whole comparison meaningless; give up on this set.
                        this.memo.put(memoKey, candidate);
                        return candidate;
                    }
                    if (!candidate.equals(JoinEnumerationResult.INFINITE_COST_RESULT)) {
                        candidates.add(candidate);
                    }
                }

                List<JoinEnumerationResult> viableCandidates = candidates.build();
                if (viableCandidates.isEmpty()) {
                    this.memo.put(memoKey, JoinEnumerationResult.INFINITE_COST_RESULT);
                    return JoinEnumerationResult.INFINITE_COST_RESULT;
                }

                bestResult = this.resultComparator.min(viableCandidates);
                this.memo.put(memoKey, bestResult);
            }

            bestResult.planNode.ifPresent(planNode -> log.debug("Least cost join was: %s", planNode));
            return bestResult;
        }

        /**
         * Generates every proper, non-empty subset of {@code {0, ..., totalNodes - 1}} that contains
         * index 0. Each subset is the left side of a split; fixing 0 on the left avoids enumerating
         * mirror-image splits twice.
         *
         * @throws IllegalArgumentException if {@code totalNodes <= 1}
         */
        @VisibleForTesting
        static Set<Set<Integer>> generatePartitions(int totalNodes) {
            Preconditions.checkArgument(totalNodes > 1, "totalNodes must be greater than 1");
            Set<Integer> allIndexes = IntStream.range(0, totalNodes)
                    .boxed()
                    .collect(ImmutableSet.toImmutableSet());
            return Sets.powerSet(allIndexes).stream()
                    .filter(subset -> subset.contains(0))
                    .filter(subset -> subset.size() < allIndexes.size())
                    .collect(ImmutableSet.toImmutableSet());
        }

        /**
         * Splits {@code sources} into the nodes at the positions in {@code partitioning} (left side)
         * and the remaining nodes (right side), preserving source order, then joins the two sides.
         */
        @VisibleForTesting
        JoinEnumerationResult createJoinAccordingToPartitioning(LinkedHashSet<PlanNode> sources, List<Symbol> requiredOutputSymbols, Set<Integer> partitioning) {
            List<PlanNode> sourceList = ImmutableList.copyOf(sources);
            LinkedHashSet<PlanNode> leftSources = partitioning.stream()
                    .map(sourceList::get)
                    .collect(Collectors.toCollection(LinkedHashSet::new));
            LinkedHashSet<PlanNode> rightSources = sources.stream()
                    .filter(source -> !leftSources.contains(source))
                    .collect(Collectors.toCollection(LinkedHashSet::new));
            return createJoin(leftSources, rightSources, requiredOutputSymbols);
        }

        /**
         * Plans an INNER join between the best plan for {@code leftSources} and the best plan for
         * {@code rightSources}. Returns the infinite-cost result when no equi-join condition connects
         * the two sides (i.e. it refuses to introduce cross joins here).
         */
        private JoinEnumerationResult createJoin(LinkedHashSet<PlanNode> leftSources, LinkedHashSet<PlanNode> rightSources, List<Symbol> outputSymbols) {
            Set<Symbol> leftSymbols = leftSources.stream()
                    .flatMap(node -> node.getOutputSymbols().stream())
                    .collect(ImmutableSet.toImmutableSet());
            Set<Symbol> rightSymbols = rightSources.stream()
                    .flatMap(node -> node.getOutputSymbols().stream())
                    .collect(ImmutableSet.toImmutableSet());

            List<Expression> joinPredicates = getJoinPredicates(leftSymbols, rightSymbols);
            List<JoinNode.EquiJoinClause> joinConditions = joinPredicates.stream()
                    .filter(JoinEnumerator::isJoinEqualityCondition)
                    .map(predicate -> toEquiJoinClause((ComparisonExpression) predicate, leftSymbols))
                    .collect(ImmutableList.toImmutableList());
            if (joinConditions.isEmpty()) {
                return JoinEnumerationResult.INFINITE_COST_RESULT;
            }

            List<Expression> joinFilters = joinPredicates.stream()
                    .filter(predicate -> !isJoinEqualityCondition(predicate))
                    .collect(ImmutableList.toImmutableList());

            // Each side must produce the requested outputs plus every symbol used by the join predicates.
            Set<Symbol> requiredJoinSymbols = ImmutableSet.<Symbol>builder()
                    .addAll(outputSymbols)
                    .addAll(SymbolsExtractor.extractUnique(joinPredicates))
                    .build();

            JoinEnumerationResult leftResult = getJoinSource(
                    leftSources,
                    requiredJoinSymbols.stream()
                            .filter(leftSymbols::contains)
                            .collect(ImmutableList.toImmutableList()));
            if (leftResult.equals(JoinEnumerationResult.UNKNOWN_COST_RESULT)) {
                return JoinEnumerationResult.UNKNOWN_COST_RESULT;
            }
            if (leftResult.equals(JoinEnumerationResult.INFINITE_COST_RESULT)) {
                return JoinEnumerationResult.INFINITE_COST_RESULT;
            }
            PlanNode left = leftResult.planNode.orElseThrow(() -> new VerifyException("Plan node is not present"));

            JoinEnumerationResult rightResult = getJoinSource(
                    rightSources,
                    requiredJoinSymbols.stream()
                            .filter(rightSymbols::contains)
                            .collect(ImmutableList.toImmutableList()));
            if (rightResult.equals(JoinEnumerationResult.UNKNOWN_COST_RESULT)) {
                return JoinEnumerationResult.UNKNOWN_COST_RESULT;
            }
            if (rightResult.equals(JoinEnumerationResult.INFINITE_COST_RESULT)) {
                return JoinEnumerationResult.INFINITE_COST_RESULT;
            }
            PlanNode right = rightResult.planNode.orElseThrow(() -> new VerifyException("Plan node is not present"));

            // The join itself only emits the symbols requested by the caller.
            List<Symbol> leftOutputSymbols = left.getOutputSymbols().stream()
                    .filter(outputSymbols::contains)
                    .collect(ImmutableList.toImmutableList());
            List<Symbol> rightOutputSymbols = right.getOutputSymbols().stream()
                    .filter(outputSymbols::contains)
                    .collect(ImmutableList.toImmutableList());

            return setJoinNodeProperties(new JoinNode(
                    this.idAllocator.getNextId(),
                    JoinNode.Type.INNER,
                    left,
                    right,
                    joinConditions,
                    leftOutputSymbols,
                    rightOutputSymbols,
                    false,
                    joinFilters.isEmpty() ? Optional.empty() : Optional.of(ExpressionUtils.and(joinFilters)),
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty(),
                    ImmutableMap.of(),
                    Optional.empty()));
        }

        /**
         * Collects the predicates that connect the two sides. These are the non-inferrable conjuncts
         * that can be expressed over both sides but over neither side alone, plus the inferred
         * equalities that straddle the left/right boundary.
         */
        private List<Expression> getJoinPredicates(Set<Symbol> leftSymbols, Set<Symbol> rightSymbols) {
            Set<Symbol> allSymbols = Sets.union(leftSymbols, rightSymbols);
            ImmutableList.Builder<Expression> joinPredicates = ImmutableList.builder();

            EqualityInference.nonInferrableConjuncts(this.metadata, this.allFilter)
                    .map(conjunct -> this.allFilterInference.rewrite(conjunct, allSymbols))
                    .filter(Objects::nonNull)
                    // Conjuncts expressible on a single side belong to that side, not to the join.
                    .filter(conjunct -> this.allFilterInference.rewrite(conjunct, leftSymbols) == null)
                    .filter(conjunct -> this.allFilterInference.rewrite(conjunct, rightSymbols) == null)
                    .forEach(joinPredicates::add);

            joinPredicates.addAll(EqualityInference.newInstance(this.metadata, this.allFilterInference.generateEqualitiesPartitionedBy(allSymbols).getScopeEqualities())
                    .generateEqualitiesPartitionedBy(leftSymbols)
                    .getScopeStraddlingEqualities());

            return joinPredicates.build();
        }

        /**
         * Plans one side of a join. Several sources are reordered recursively. A single source is
         * wrapped in a {@link FilterNode} holding every predicate that can be expressed over
         * {@code outputSymbols}, unless those predicates combine to TRUE.
         */
        private JoinEnumerationResult getJoinSource(LinkedHashSet<PlanNode> nodes, List<Symbol> outputSymbols) {
            if (nodes.size() != 1) {
                return chooseJoinOrder(nodes, outputSymbols);
            }

            PlanNode planNode = Iterables.getOnlyElement(nodes);
            Set<Symbol> scope = ImmutableSet.copyOf(outputSymbols);

            ImmutableList.Builder<Expression> predicates = ImmutableList.builder();
            predicates.addAll(this.allFilterInference.generateEqualitiesPartitionedBy(scope).getScopeEqualities());
            EqualityInference.nonInferrableConjuncts(this.metadata, this.allFilter)
                    .map(conjunct -> this.allFilterInference.rewrite(conjunct, scope))
                    .filter(Objects::nonNull)
                    .forEach(predicates::add);

            Collection<Expression> conjuncts = predicates.build();
            Expression filter = ExpressionUtils.combineConjuncts(this.metadata, conjuncts);
            if (!BooleanLiteral.TRUE_LITERAL.equals(filter)) {
                planNode = new FilterNode(this.idAllocator.getNextId(), planNode, filter);
            }
            return createJoinEnumerationResult(planNode);
        }

        /** Returns true for {@code symbol = symbol} comparisons, which can become equi-join clauses. */
        private static boolean isJoinEqualityCondition(Expression expression) {
            if (!(expression instanceof ComparisonExpression)) {
                return false;
            }
            ComparisonExpression comparison = (ComparisonExpression) expression;
            return comparison.getOperator() == ComparisonExpression.Operator.EQUAL
                    && comparison.getLeft() instanceof SymbolReference
                    && comparison.getRight() instanceof SymbolReference;
        }

        /**
         * Converts {@code a = b} into an equi-join clause. The clause is flipped when needed so that
         * its left symbol belongs to {@code leftSymbols}.
         */
        public static JoinNode.EquiJoinClause toEquiJoinClause(ComparisonExpression equality, Set<Symbol> leftSymbols) {
            Symbol leftSymbol = Symbol.from(equality.getLeft());
            JoinNode.EquiJoinClause clause = new JoinNode.EquiJoinClause(leftSymbol, Symbol.from(equality.getRight()));
            return leftSymbols.contains(leftSymbol) ? clause : clause.flip();
        }

        /**
         * Chooses the distribution type (and child order) for a join. If either side produces at most
         * one row, that side is broadcast. Otherwise every candidate allowed by the session's join
         * distribution type is costed and the cheapest is returned. If any candidate has an unknown
         * cost, the unknown result is returned.
         */
        private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode) {
            if (QueryCardinalityUtil.isAtMostScalar(joinNode.getRight(), this.lookup)) {
                return createJoinEnumerationResult(joinNode.withDistributionType(JoinNode.DistributionType.REPLICATED));
            }
            if (QueryCardinalityUtil.isAtMostScalar(joinNode.getLeft(), this.lookup)) {
                return createJoinEnumerationResult(joinNode.flipChildren().withDistributionType(JoinNode.DistributionType.REPLICATED));
            }

            List<JoinEnumerationResult> possibleJoinNodes = getPossibleJoinNodes(joinNode, SystemSessionProperties.getJoinDistributionType(this.session));
            Verify.verify(!possibleJoinNodes.isEmpty(), "possibleJoinNodes is empty");
            if (possibleJoinNodes.stream().anyMatch(JoinEnumerationResult.UNKNOWN_COST_RESULT::equals)) {
                return JoinEnumerationResult.UNKNOWN_COST_RESULT;
            }
            return this.resultComparator.min(possibleJoinNodes);
        }

        /**
         * Lists candidate variants of {@code joinNode} for the configured distribution type. Cross
         * joins are always replicated.
         *
         * @throws IllegalArgumentException if the distribution type is not handled
         */
        private List<JoinEnumerationResult> getPossibleJoinNodes(JoinNode joinNode, OptimizerConfig.JoinDistributionType joinDistributionType) {
            Preconditions.checkArgument(joinNode.getType() == JoinNode.Type.INNER, "unexpected join node type: %s", joinNode.getType());

            if (joinNode.isCrossJoin()) {
                return getPossibleJoinNodes(joinNode, JoinNode.DistributionType.REPLICATED);
            }

            switch (joinDistributionType) {
                case PARTITIONED:
                    return getPossibleJoinNodes(joinNode, JoinNode.DistributionType.PARTITIONED);
                case BROADCAST:
                    return getPossibleJoinNodes(joinNode, JoinNode.DistributionType.REPLICATED);
                case AUTOMATIC:
                    return ImmutableList.<JoinEnumerationResult>builder()
                            .addAll(getPossibleJoinNodes(joinNode, JoinNode.DistributionType.PARTITIONED))
                            .addAll(getPossibleJoinNodes(
                                    joinNode,
                                    JoinNode.DistributionType.REPLICATED,
                                    node -> DetermineJoinDistributionType.canReplicate(node, this.context)))
                            .build();
                default:
                    throw new IllegalArgumentException("unexpected join distribution type: " + joinDistributionType);
            }
        }

        /** Both child orders of {@code joinNode} with the given distribution type. */
        private List<JoinEnumerationResult> getPossibleJoinNodes(JoinNode joinNode, JoinNode.DistributionType distributionType) {
            return getPossibleJoinNodes(joinNode, distributionType, node -> true);
        }

        /**
         * Both child orders of {@code joinNode} with the given distribution type, keeping only those
         * accepted by {@code isAllowed}, each costed.
         */
        private List<JoinEnumerationResult> getPossibleJoinNodes(JoinNode joinNode, JoinNode.DistributionType distributionType, Predicate<JoinNode> isAllowed) {
            return Stream.of(
                            joinNode.withDistributionType(distributionType),
                            joinNode.flipChildren().withDistributionType(distributionType))
                    .filter(isAllowed)
                    .map(this::createJoinEnumerationResult)
                    .collect(ImmutableList.toImmutableList());
        }

        /** Costs a join and attaches the stats/cost summary used by the reordering to the node. */
        private JoinEnumerationResult createJoinEnumerationResult(JoinNode joinNode) {
            PlanCostEstimate cost = this.costProvider.getCost(joinNode);
            PlanNodeStatsEstimate stats = this.statsProvider.getStats(joinNode);
            PlanNodeStatsAndCostSummary summary = new PlanNodeStatsAndCostSummary(
                    stats.getOutputRowCount(),
                    stats.getOutputSizeInBytes(joinNode.getOutputSymbols(), this.context.getSymbolAllocator().getTypes()),
                    cost.getCpuCost(),
                    cost.getMaxMemory(),
                    cost.getNetworkCost());
            return JoinEnumerationResult.createJoinEnumerationResult(Optional.of(joinNode.withReorderJoinStatsAndCost(summary)), cost);
        }

        /** Costs an arbitrary plan node without annotating it. */
        private JoinEnumerationResult createJoinEnumerationResult(PlanNode planNode) {
            return JoinEnumerationResult.createJoinEnumerationResult(Optional.of(planNode), this.costProvider.getCost(planNode));
        }
    }

    /**
     * A flattened tree of INNER joins: the leaf sources (in discovery order), the conjunction of all
     * join criteria and filters, and the symbols the original tree produced.
     */
    @VisibleForTesting
    public static class MultiJoinNode {
        private final LinkedHashSet<PlanNode> sources;
        private final Expression filter;
        private final List<Symbol> outputSymbols;
        private final boolean pushedProjectionThroughJoin;

        /** Test-oriented builder; the result never records a projection push-down. */
        static class Builder {
            private List<PlanNode> sources;
            private Expression filter;
            private List<Symbol> outputSymbols;

            Builder() {
            }

            public Builder setSources(PlanNode... sources) {
                this.sources = ImmutableList.copyOf(sources);
                return this;
            }

            public Builder setFilter(Expression filter) {
                this.filter = filter;
                return this;
            }

            public Builder setOutputSymbols(Symbol... outputSymbols) {
                this.outputSymbols = ImmutableList.copyOf(outputSymbols);
                return this;
            }

            public MultiJoinNode build() {
                return new MultiJoinNode(new LinkedHashSet<>(this.sources), this.filter, this.outputSymbols, false);
            }
        }

        /**
         * Walks a join tree and collects its leaf sources and predicates. Recursion stops at any node
         * that is not a flattenable INNER join, or once the source limit would be exceeded.
         */
        public static class JoinNodeFlattener {
            private final PlannerContext plannerContext;
            private final Session session;
            private final TypeAnalyzer typeAnalyzer;
            private final TypeProvider types;
            private final Lookup lookup;
            private final PlanNodeIdAllocator planNodeIdAllocator;
            private final LinkedHashSet<PlanNode> sources = new LinkedHashSet<>();
            private final List<Expression> filters = new ArrayList<>();
            private final List<Symbol> outputSymbols;
            private final boolean pushProjectionsThroughJoin;
            private boolean pushedProjectionThroughJoin;

            /**
             * @param sourceLimit maximum number of sources to collect
             * @param pushProjectionsThroughJoin whether projections between joins may be pushed down
             *        so that the joins beneath them can also be flattened
             */
            JoinNodeFlattener(PlannerContext plannerContext, JoinNode node, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator, int sourceLimit, boolean pushProjectionsThroughJoin, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) {
                this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
                Objects.requireNonNull(node, "node is null");
                Preconditions.checkState(node.getType() == JoinNode.Type.INNER, "join type must be INNER");
                this.outputSymbols = node.getOutputSymbols();
                this.lookup = Objects.requireNonNull(lookup, "lookup is null");
                this.planNodeIdAllocator = Objects.requireNonNull(planNodeIdAllocator, "planNodeIdAllocator is null");
                this.pushProjectionsThroughJoin = pushProjectionsThroughJoin;
                this.session = Objects.requireNonNull(session, "session is null");
                this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
                this.types = Objects.requireNonNull(types, "types is null");
                flattenNode(node, sourceLimit);
            }

            private void flattenNode(PlanNode node, int limit) {
                PlanNode resolved = this.lookup.resolve(node);

                if (resolved instanceof ProjectNode) {
                    if (!this.pushProjectionsThroughJoin) {
                        this.sources.add(node);
                        return;
                    }
                    Optional<PlanNode> rewritten = PushProjectionThroughJoin.pushProjectionThroughJoin(this.plannerContext, (ProjectNode) resolved, this.lookup, this.planNodeIdAllocator, this.session, this.typeAnalyzer, this.types);
                    if (rewritten.isEmpty()) {
                        this.sources.add(node);
                        return;
                    }
                    this.pushedProjectionThroughJoin = true;
                    flattenNode(rewritten.get(), limit);
                    return;
                }

                // (limit - 2) leaves room for adding both the left and the right side.
                if (!(resolved instanceof JoinNode) || this.sources.size() > limit - 2) {
                    this.sources.add(node);
                    return;
                }

                JoinNode joinNode = (JoinNode) resolved;
                if (joinNode.getType() != JoinNode.Type.INNER
                        || !DeterminismEvaluator.isDeterministic(joinNode.getFilter().orElse(BooleanLiteral.TRUE_LITERAL), this.plannerContext.getMetadata())
                        || joinNode.getDistributionType().isPresent()) {
                    this.sources.add(node);
                    return;
                }

                // The left subtree gets limit - 1 to reserve a slot for the right side.
                flattenNode(joinNode.getLeft(), limit - 1);
                flattenNode(joinNode.getRight(), limit);
                joinNode.getCriteria().stream()
                        .map(JoinNode.EquiJoinClause::toExpression)
                        .forEach(this.filters::add);
                joinNode.getFilter().ifPresent(this.filters::add);
            }

            MultiJoinNode toMultiJoinNode() {
                return new MultiJoinNode(this.sources, ExpressionUtils.and(this.filters), this.outputSymbols, this.pushedProjectionThroughJoin);
            }
        }

        /**
         * @throws IllegalArgumentException if there are fewer than two sources or the sources do not
         *         produce every output symbol
         */
        MultiJoinNode(LinkedHashSet<PlanNode> sources, Expression filter, List<Symbol> outputSymbols, boolean pushedProjectionThroughJoin) {
            Objects.requireNonNull(sources, "sources is null");
            Preconditions.checkArgument(sources.size() > 1, "sources size is <= 1");
            Objects.requireNonNull(filter, "filter is null");
            Objects.requireNonNull(outputSymbols, "outputSymbols is null");

            this.sources = sources;
            this.filter = filter;
            this.outputSymbols = ImmutableList.copyOf(outputSymbols);
            this.pushedProjectionThroughJoin = pushedProjectionThroughJoin;

            Set<Symbol> inputSymbols = sources.stream()
                    .flatMap(source -> source.getOutputSymbols().stream())
                    .collect(ImmutableSet.toImmutableSet());
            Preconditions.checkArgument(inputSymbols.containsAll(outputSymbols), "inputs do not contain all output symbols");
        }

        public Expression getFilter() {
            return this.filter;
        }

        public LinkedHashSet<PlanNode> getSources() {
            return this.sources;
        }

        public List<Symbol> getOutputSymbols() {
            return this.outputSymbols;
        }

        public boolean isPushedProjectionThroughJoin() {
            return this.pushedProjectionThroughJoin;
        }

        public static Builder builder() {
            return new Builder();
        }

        // Filters are compared as unordered sets of conjuncts, so conjunct order does not matter.
        @Override
        public int hashCode() {
            return Objects.hash(this.sources, ImmutableSet.copyOf(ExpressionUtils.extractConjuncts(this.filter)), this.outputSymbols, this.pushedProjectionThroughJoin);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof MultiJoinNode)) {
                return false;
            }
            MultiJoinNode other = (MultiJoinNode) obj;
            return this.sources.equals(other.sources)
                    && ImmutableSet.copyOf(ExpressionUtils.extractConjuncts(this.filter)).equals(ImmutableSet.copyOf(ExpressionUtils.extractConjuncts(other.filter)))
                    && this.outputSymbols.equals(other.outputSymbols)
                    && this.pushedProjectionThroughJoin == other.pushedProjectionThroughJoin;
        }

        /** Flattens {@code joinNode} using the rule context and the session's max reordered joins. */
        static MultiJoinNode toMultiJoinNode(PlannerContext plannerContext, JoinNode joinNode, Rule.Context context, boolean pushProjectionsThroughJoin, TypeAnalyzer typeAnalyzer) {
            return toMultiJoinNode(plannerContext, joinNode, context.getLookup(), context.getIdAllocator(), SystemSessionProperties.getMaxReorderedJoins(context.getSession()), pushProjectionsThroughJoin, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes());
        }

        /** Flattens {@code joinNode} into at most {@code joinLimit + 1} sources. */
        static MultiJoinNode toMultiJoinNode(PlannerContext plannerContext, JoinNode joinNode, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator, int joinLimit, boolean pushProjectionsThroughJoin, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) {
            // n joins connect n + 1 sources
            return new JoinNodeFlattener(plannerContext, joinNode, lookup, planNodeIdAllocator, joinLimit + 1, pushProjectionsThroughJoin, session, typeAnalyzer, types).toMultiJoinNode();
        }
    }

    public ReorderJoins(PlannerContext plannerContext, CostComparator costComparator, TypeAnalyzer typeAnalyzer) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.costComparator = Objects.requireNonNull(costComparator, "costComparator is null");
        this.pattern = Patterns.join().matching(joinNode -> joinNode.getDistributionType().isEmpty()
                && joinNode.getType() == JoinNode.Type.INNER
                && DeterminismEvaluator.isDeterministic(joinNode.getFilter().orElse(BooleanLiteral.TRUE_LITERAL), plannerContext.getMetadata()));
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return this.pattern;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.getJoinReorderingStrategy(session) == OptimizerConfig.JoinReorderingStrategy.AUTOMATIC;
    }

    /**
     * Reorders the join tree rooted at {@code joinNode}. If flattening pushed projections through
     * joins, the tree is also planned without push-down, and the push-down plan is kept only when it
     * is strictly cheaper or the alternative has no plan.
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        MultiJoinNode multiJoinNode = MultiJoinNode.toMultiJoinNode(this.plannerContext, joinNode, context, true, this.typeAnalyzer);
        JoinEnumerationResult resultWithPushDown = chooseJoinOrder(multiJoinNode, context);
        if (resultWithPushDown.getPlanNode().isEmpty()) {
            return Rule.Result.empty();
        }
        if (!multiJoinNode.isPushedProjectionThroughJoin()) {
            return Rule.Result.ofPlanNode(resultWithPushDown.getPlanNode().get());
        }

        MultiJoinNode multiJoinNodeWithoutPushDown = MultiJoinNode.toMultiJoinNode(this.plannerContext, joinNode, context, false, this.typeAnalyzer);
        JoinEnumerationResult resultWithoutPushDown = chooseJoinOrder(multiJoinNodeWithoutPushDown, context);
        if (resultWithoutPushDown.getPlanNode().isEmpty()
                || this.costComparator.compare(context.getSession(), resultWithPushDown.cost, resultWithoutPushDown.cost) < 0) {
            return Rule.Result.ofPlanNode(resultWithPushDown.getPlanNode().get());
        }
        return Rule.Result.ofPlanNode(resultWithoutPushDown.getPlanNode().get());
    }

    private JoinEnumerationResult chooseJoinOrder(MultiJoinNode multiJoinNode, Rule.Context context) {
        JoinEnumerator enumerator = new JoinEnumerator(this.plannerContext.getMetadata(), this.costComparator, multiJoinNode.getFilter(), context);
        return enumerator.chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols());
    }
}
