package io.trino.sql.planner.optimizations;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.TableStatsProvider;
import io.trino.execution.querystats.PlanOptimizersStatsCollector;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.Metadata;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.Type;
import io.trino.sql.DynamicFilters;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.DomainTranslator;
import io.trino.sql.planner.EffectivePredicateExtractor;
import io.trino.sql.planner.EqualityInference;
import io.trino.sql.planner.ExpressionInterpreter;
import io.trino.sql.planner.ExpressionSymbolInliner;
import io.trino.sql.planner.LiteralEncoder;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.rule.CanonicalizeExpressionRewriter;
import io.trino.sql.planner.iterative.rule.UnwrapCastInComparison;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.GroupIdNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.MarkDistinctNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SampleNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.plan.SortNode;
import io.trino.sql.planner.plan.SpatialJoinNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.tree.BetweenPredicate;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.SymbolReference;
import io.trino.sql.tree.TryExpression;
import io.trino.sql.util.AstUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/optimizations/PredicatePushDown.class */
public class PredicatePushDown implements PlanOptimizer {
    // Comparison operators that are eligible for conversion into dynamic filters.
    // NOTE(review): not referenced in the visible part of this file; presumably consulted by
    // joinDynamicFilteringExpression (called from createDynamicFilters) — confirm.
    private static final Set<ComparisonExpression.Operator> DYNAMIC_FILTERING_SUPPORTED_COMPARISONS = ImmutableSet.of(ComparisonExpression.Operator.EQUAL, ComparisonExpression.Operator.GREATER_THAN, ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, ComparisonExpression.Operator.LESS_THAN, ComparisonExpression.Operator.LESS_THAN_OR_EQUAL);
    // Optimizer configuration. The constructor and optimize() method are not visible in this chunk;
    // presumably these are forwarded to Rewriter (whose constructor takes matching arguments) — confirm.
    private final PlannerContext plannerContext;
    private final TypeAnalyzer typeAnalyzer;
    // Whether effective-predicate extraction may use table properties (Rewriter additionally
    // requires the session property isPredicatePushdownUseTableProperties).
    private final boolean useTableProperties;
    // Optimizer-level switch for dynamic filtering; Rewriter also checks the session property.
    private final boolean dynamicFiltering;

    /* loaded from: input_file:io/trino/sql/planner/optimizations/PredicatePushDown$Rewriter.class */
    private static class Rewriter extends SimplePlanRewriter<Expression> {
        // Source of symbol types (read in createDynamicFilters via getTypes()).
        private final SymbolAllocator symbolAllocator;
        // Allocates ids for every plan node this rewriter creates.
        private final PlanNodeIdAllocator idAllocator;
        private final PlannerContext plannerContext;
        // Cached plannerContext.getMetadata(); used for determinism checks and conjunct combination.
        private final Metadata metadata;
        private final TypeAnalyzer typeAnalyzer;
        private final Session session;
        // Types of the symbols in the plan being rewritten.
        private final TypeProvider types;
        // Built in the constructor from metadata, function manager and typeAnalyzer.
        // NOTE(review): presumably backs areExpressionsEquivalent (not visible here) — confirm.
        private final ExpressionEquivalence expressionEquivalence;
        // When false, createDynamicFilters never produces dynamic filters.
        private final boolean dynamicFiltering;
        // NOTE(review): not used in the visible part of this file.
        private final LiteralEncoder literalEncoder;
        // Derives the predicates already guaranteed by each join input (see visitJoin).
        private final EffectivePredicateExtractor effectivePredicateExtractor;

        /**
         * A candidate dynamic-filter comparison together with a {@code nullAllowed} flag.
         * The single-argument constructor leaves the flag {@code false}.
         */
        public static class DynamicFilterExpression {
            private final ComparisonExpression comparison;
            private final boolean nullAllowed;

            private DynamicFilterExpression(ComparisonExpression filterComparison) {
                this(filterComparison, false);
            }

            private DynamicFilterExpression(ComparisonExpression filterComparison, boolean allowNulls) {
                ComparisonExpression checked = Objects.requireNonNull(filterComparison, "comparison is null");
                this.comparison = checked;
                this.nullAllowed = allowNulls;
            }

            /** Returns the comparison backing this dynamic filter; never null. */
            public ComparisonExpression getComparison() {
                return comparison;
            }

            /** Returns the flag supplied at construction ({@code false} by default). */
            public boolean isNullAllowed() {
                return nullAllowed;
            }
        }

        /**
         * Immutable result of dynamic-filter creation: the filter-id to build-symbol mapping and
         * the predicates to add to the probe side. Both inputs are defensively copied.
         */
        public static class DynamicFiltersResult {
            private final Map<DynamicFilterId, Symbol> dynamicFilters;
            private final List<Expression> predicates;

            public DynamicFiltersResult(Map<DynamicFilterId, Symbol> filters, List<Expression> filterPredicates) {
                Map<DynamicFilterId, Symbol> filtersCopy = ImmutableMap.copyOf(filters);
                List<Expression> predicatesCopy = ImmutableList.copyOf(filterPredicates);
                this.dynamicFilters = filtersCopy;
                this.predicates = predicatesCopy;
            }

            /** Returns an immutable mapping from dynamic filter id to the symbol it filters on. */
            public Map<DynamicFilterId, Symbol> getDynamicFilters() {
                return dynamicFilters;
            }

            /** Returns the immutable list of dynamic-filter predicates. */
            public List<Expression> getPredicates() {
                return predicates;
            }
        }

        /**
         * The four predicates produced when pushing a predicate through an inner join: what goes
         * to the left input, to the right input, into the join itself, and above the join.
         */
        public static class InnerJoinPushDownResult {
            private final Expression leftPredicate;
            private final Expression rightPredicate;
            private final Expression joinPredicate;
            private final Expression postJoinPredicate;

            private InnerJoinPushDownResult(Expression left, Expression right, Expression join, Expression postJoin) {
                leftPredicate = left;
                rightPredicate = right;
                joinPredicate = join;
                postJoinPredicate = postJoin;
            }

            /** Predicate to push into the left input. */
            private Expression getLeftPredicate() {
                return leftPredicate;
            }

            /** Predicate to push into the right input. */
            private Expression getRightPredicate() {
                return rightPredicate;
            }

            /** Predicate to keep as the join condition. */
            private Expression getJoinPredicate() {
                return joinPredicate;
            }

            /** Predicate to evaluate above the join. */
            private Expression getPostJoinPredicate() {
                return postJoinPredicate;
            }
        }

        /**
         * The four predicates produced when pushing a predicate through a LEFT/RIGHT outer join,
         * expressed relative to the outer (row-preserving) and inner sides; visitJoin maps them
         * back onto left/right according to the join type.
         */
        public static class OuterJoinPushDownResult {
            private final Expression outerJoinPredicate;
            private final Expression innerJoinPredicate;
            private final Expression joinPredicate;
            private final Expression postJoinPredicate;

            private OuterJoinPushDownResult(Expression outer, Expression inner, Expression join, Expression postJoin) {
                outerJoinPredicate = outer;
                innerJoinPredicate = inner;
                joinPredicate = join;
                postJoinPredicate = postJoin;
            }

            /** Predicate to push into the outer side's input. */
            private Expression getOuterJoinPredicate() {
                return outerJoinPredicate;
            }

            /** Predicate to push into the inner side's input. */
            private Expression getInnerJoinPredicate() {
                return innerJoinPredicate;
            }

            /** Predicate to keep as the join condition. */
            public Expression getJoinPredicate() {
                return joinPredicate;
            }

            /** Predicate to evaluate above the join. */
            private Expression getPostJoinPredicate() {
                return postJoinPredicate;
            }
        }

        /**
         * Creates a rewriter. All object arguments are null-checked (in declaration order).
         *
         * @param useTableProperties whether effective-predicate extraction may use table properties;
         *        honoured only if the session also enables predicate-pushdown-use-table-properties
         * @param dynamicFiltering whether this rewriter may create dynamic filters
         */
        private Rewriter(SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, Session session, TypeProvider types, boolean useTableProperties, boolean dynamicFiltering) {
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.metadata = plannerContext.getMetadata();
            this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.types = Objects.requireNonNull(types, "types is null");
            this.expressionEquivalence = new ExpressionEquivalence(this.metadata, plannerContext.getFunctionManager(), typeAnalyzer);
            this.dynamicFiltering = dynamicFiltering;
            this.effectivePredicateExtractor = new EffectivePredicateExtractor(
                    new DomainTranslator(plannerContext),
                    plannerContext,
                    useTableProperties && SystemSessionProperties.isPredicatePushdownUseTableProperties(session));
            this.literalEncoder = new LiteralEncoder(plannerContext);
        }

        /**
         * Fallback for nodes with no specific handling: push nothing into the children and keep
         * the whole inherited predicate in a filter directly above this node (unless it is TRUE).
         */
        @Override // io.trino.sql.planner.plan.SimplePlanRewriter, io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitPlan(PlanNode planNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            PlanNode rewritten = rewriteContext.defaultRewrite(planNode, BooleanLiteral.TRUE_LITERAL);
            Expression inherited = rewriteContext.get();
            if (inherited.equals(BooleanLiteral.TRUE_LITERAL)) {
                return rewritten;
            }
            return new FilterNode(this.idAllocator.getNextId(), rewritten, inherited);
        }

        /**
         * Pushes the inherited predicate into every exchange source, translating exchange output
         * symbols to that source's input symbols. Returns the original node if no source changed.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitExchange(ExchangeNode exchangeNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            List<PlanNode> sources = exchangeNode.getSources();
            List<Symbol> outputs = exchangeNode.getOutputSymbols();
            ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
            boolean anySourceChanged = false;

            for (int sourceIndex = 0; sourceIndex < sources.size(); sourceIndex++) {
                List<Symbol> sourceInputs = exchangeNode.getInputs().get(sourceIndex);
                // Positional mapping: output column k of the exchange reads input column k of this source.
                Map<Symbol, Expression> outputToInput = new HashMap<>();
                for (int column = 0; column < sourceInputs.size(); column++) {
                    outputToInput.put(outputs.get(column), sourceInputs.get(column).toSymbolReference());
                }

                Expression sourcePredicate = ExpressionSymbolInliner.inlineSymbols(outputToInput, rewriteContext.get());
                PlanNode source = sources.get(sourceIndex);
                PlanNode rewrittenSource = rewriteContext.rewrite(source, sourcePredicate);
                anySourceChanged |= rewrittenSource != source;
                rewrittenSources.add(rewrittenSource);
            }

            if (!anySourceChanged) {
                return exchangeNode;
            }
            return new ExchangeNode(
                    exchangeNode.getId(),
                    exchangeNode.getType(),
                    exchangeNode.getScope(),
                    exchangeNode.getPartitioningScheme(),
                    rewrittenSources.build(),
                    exchangeNode.getInputs(),
                    exchangeNode.getOrderingScheme());
        }

        /**
         * Pushes down only deterministic conjuncts that reference partition-by symbols exclusively;
         * all other conjuncts stay in a filter above the window node.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitWindow(WindowNode windowNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            List<Symbol> partitionSymbols = windowNode.getPartitionBy();
            Map<Boolean, List<Expression>> byPushability = ExpressionUtils.extractConjuncts(rewriteContext.get()).stream()
                    .collect(Collectors.partitioningBy(conjunct ->
                            DeterminismEvaluator.isDeterministic(conjunct, this.metadata)
                                    && partitionSymbols.containsAll(SymbolsExtractor.extractUnique(conjunct))));
            List<Expression> pushable = byPushability.get(true);
            List<Expression> retained = byPushability.get(false);

            PlanNode rewritten = rewriteContext.defaultRewrite(windowNode, ExpressionUtils.combineConjuncts(this.metadata, pushable));
            if (retained.isEmpty()) {
                return rewritten;
            }
            return new FilterNode(this.idAllocator.getNextId(), rewritten, ExpressionUtils.combineConjuncts(this.metadata, retained));
        }

        /**
         * Pushes the inherited predicate through a projection.
         *
         * <p>Conjuncts that only reference symbols with deterministic assignments, and that pass
         * {@link #isInliningCandidate}, are rewritten in terms of the projection's inputs. The
         * rewritten conjuncts are then canonicalized and have their casts unwrapped, and are pushed
         * into the source. All remaining conjuncts are kept in a filter above the projection.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitProject(ProjectNode projectNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Set<Symbol> deterministicSymbols = projectNode.getAssignments().entrySet().stream()
                    .filter(entry -> DeterminismEvaluator.isDeterministic(entry.getValue(), this.metadata))
                    .map(Map.Entry::getKey)
                    .collect(Collectors.toSet());

            Map<Boolean, List<Expression>> conjuncts = ExpressionUtils.extractConjuncts(rewriteContext.get()).stream()
                    .collect(Collectors.partitioningBy(conjunct -> deterministicSymbols.containsAll(SymbolsExtractor.extractUnique(conjunct))));

            // Of the deterministic conjuncts, only inline those that pass isInliningCandidate.
            Map<Boolean, List<Expression>> inlineConjuncts = conjuncts.get(true).stream()
                    .collect(Collectors.partitioningBy(conjunct -> isInliningCandidate(conjunct, projectNode)));

            List<Expression> inlinedDeterministicConjuncts = inlineConjuncts.get(true).stream()
                    .map(conjunct -> ExpressionSymbolInliner.inlineSymbols((Map<Symbol, ? extends Expression>) projectNode.getAssignments().getMap(), conjunct))
                    .map(conjunct -> CanonicalizeExpressionRewriter.canonicalizeExpression(conjunct, this.typeAnalyzer.getTypes(this.session, this.types, conjunct), this.plannerContext, this.session))
                    .map(conjunct -> UnwrapCastInComparison.unwrapCasts(this.session, this.plannerContext, this.typeAnalyzer, this.types, conjunct))
                    .collect(Collectors.toList());

            PlanNode rewrittenNode = rewriteContext.defaultRewrite(projectNode, ExpressionUtils.combineConjuncts(this.metadata, inlinedDeterministicConjuncts));

            // Copy before appending: Collectors.partitioningBy makes no guarantee that the lists it
            // produces are mutable, so calling addAll on them directly is not safe.
            List<Expression> nonInliningConjuncts = new ArrayList<>(inlineConjuncts.get(false));
            nonInliningConjuncts.addAll(conjuncts.get(false));

            if (!nonInliningConjuncts.isEmpty()) {
                rewrittenNode = new FilterNode(this.idAllocator.getNextId(), rewrittenNode, ExpressionUtils.combineConjuncts(this.metadata, nonInliningConjuncts));
            }
            return rewrittenNode;
        }

        /**
         * Decides whether a conjunct may be inlined through {@code projectNode}.
         *
         * <p>Only symbols produced by the projection itself are considered. Each such symbol must
         * satisfy at least one condition:
         * <ul>
         *   <li>it appears exactly once in the conjunct, or</li>
         *   <li>its assignment is effectively a literal, or</li>
         *   <li>its assignment is a plain symbol reference.</li>
         * </ul>
         *
         * @throws com.google.common.base.VerifyException if the conjunct contains a TRY expression
         */
        private boolean isInliningCandidate(Expression expression, ProjectNode projectNode) {
            // Conjuncts reaching this point must not contain TRY; fail fast if one does.
            Verify.verify(AstUtils.preOrder(expression).noneMatch(TryExpression.class::isInstance));

            Set<Symbol> childOutputs = ImmutableSet.copyOf(projectNode.getOutputSymbols());
            // Occurrence count of each projection-produced symbol within the conjunct.
            Map<Symbol, Long> dependencies = SymbolsExtractor.extractAll(expression).stream()
                    .filter(childOutputs::contains)
                    .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));

            return dependencies.entrySet().stream()
                    .allMatch(entry -> entry.getValue() == 1
                            || ExpressionUtils.isEffectivelyLiteral(this.plannerContext, this.session, projectNode.getAssignments().get(entry.getKey()))
                            || projectNode.getAssignments().get(entry.getKey()) instanceof SymbolReference);
        }

        /**
         * Pushes down conjuncts that only reference common grouping columns. Each conjunct is
         * rewritten to the corresponding source symbol. Every other conjunct stays in a filter
         * above the GroupId node.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitGroupId(GroupIdNode groupIdNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Map<Symbol, SymbolReference> commonGroupingSymbolMapping = groupIdNode.getGroupingColumns().entrySet().stream()
                    .filter(entry -> groupIdNode.getCommonGroupingColumns().contains(entry.getKey()))
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));

            Map<Boolean, List<Expression>> byPushability = ExpressionUtils.extractConjuncts(rewriteContext.get()).stream()
                    .collect(Collectors.partitioningBy(conjunct -> commonGroupingSymbolMapping.keySet().containsAll(SymbolsExtractor.extractUnique(conjunct))));
            List<Expression> pushable = byPushability.get(true);
            List<Expression> retained = byPushability.get(false);

            Expression pushedPredicate = ExpressionSymbolInliner.inlineSymbols(
                    (Map<Symbol, ? extends Expression>) commonGroupingSymbolMapping,
                    ExpressionUtils.combineConjuncts(this.metadata, pushable));
            PlanNode rewritten = rewriteContext.defaultRewrite(groupIdNode, pushedPredicate);
            if (retained.isEmpty()) {
                return rewritten;
            }
            return new FilterNode(this.idAllocator.getNextId(), rewritten, ExpressionUtils.combineConjuncts(this.metadata, retained));
        }

        /**
         * Pushes down conjuncts that only reference the distinct symbols; everything else stays
         * in a filter above the MarkDistinct node.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitMarkDistinct(MarkDistinctNode markDistinctNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Set<Symbol> distinctSymbols = ImmutableSet.copyOf(markDistinctNode.getDistinctSymbols());
            Map<Boolean, List<Expression>> byPushability = ExpressionUtils.extractConjuncts(rewriteContext.get()).stream()
                    .collect(Collectors.partitioningBy(conjunct -> distinctSymbols.containsAll(SymbolsExtractor.extractUnique(conjunct))));
            List<Expression> pushable = byPushability.get(true);
            List<Expression> retained = byPushability.get(false);

            PlanNode rewritten = rewriteContext.defaultRewrite(markDistinctNode, ExpressionUtils.combineConjuncts(this.metadata, pushable));
            if (retained.isEmpty()) {
                return rewritten;
            }
            return new FilterNode(this.idAllocator.getNextId(), rewritten, ExpressionUtils.combineConjuncts(this.metadata, retained));
        }

        /** Forwards the inherited predicate unchanged to the sort's source. */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitSort(SortNode sortNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression inheritedPredicate = rewriteContext.get();
            return rewriteContext.defaultRewrite(sortNode, inheritedPredicate);
        }

        /**
         * Pushes the inherited predicate into every union branch, rewritten through that branch's
         * symbol mapping. Returns the original node if no branch changed.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitUnion(UnionNode unionNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            List<PlanNode> sources = unionNode.getSources();
            ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
            boolean anySourceChanged = false;

            for (int sourceIndex = 0; sourceIndex < sources.size(); sourceIndex++) {
                Expression sourcePredicate = ExpressionSymbolInliner.inlineSymbols((Map<Symbol, ? extends Expression>) unionNode.sourceSymbolMap(sourceIndex), rewriteContext.get());
                PlanNode source = sources.get(sourceIndex);
                PlanNode rewrittenSource = rewriteContext.rewrite(source, sourcePredicate);
                anySourceChanged |= rewrittenSource != source;
                rewrittenSources.add(rewrittenSource);
            }

            if (!anySourceChanged) {
                return unionNode;
            }
            return new UnionNode(unionNode.getId(), rewrittenSources.build(), unionNode.getSymbolMapping(), unionNode.getOutputSymbols());
        }

        /**
         * Merges this filter's predicate with the inherited one and pushes the result into the
         * source. Returns the original node when the rewrite produced an equivalent filter over
         * the same source.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        @Deprecated
        public PlanNode visitFilter(FilterNode filterNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression combinedPredicate = ExpressionUtils.combineConjuncts(this.metadata, filterNode.getPredicate(), rewriteContext.get());
            PlanNode rewritten = rewriteContext.rewrite(filterNode.getSource(), combinedPredicate);
            if (rewritten instanceof FilterNode) {
                FilterNode rewrittenFilter = (FilterNode) rewritten;
                boolean unchanged = areExpressionsEquivalent(rewrittenFilter.getPredicate(), filterNode.getPredicate())
                        && filterNode.getSource() == rewrittenFilter.getSource();
                if (unchanged) {
                    return filterNode;
                }
            }
            return rewritten;
        }

        /**
         * Pushes the inherited predicate through a join.
         *
         * <p>Steps:
         * <ul>
         *   <li>Try to normalize an outer join to an inner join.</li>
         *   <li>Split the predicates per join type into left, right, join and post-join parts.</li>
         *   <li>Turn qualifying equality conjuncts into equi-join clauses, projecting non-symbol
         *       operands on each side.</li>
         *   <li>Add any dynamic-filter predicates to the left side.</li>
         *   <li>Rebuild the join only if something changed.</li>
         * </ul>
         * A filter is added above the join for the post-join predicate, and a projection restores
         * the original output symbols if they differ.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitJoin(JoinNode joinNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression inheritedPredicate = rewriteContext.get();
            JoinNode node = tryNormalizeToOuterToInnerJoin(joinNode, inheritedPredicate);
            List<Symbol> leftSymbols = node.getLeft().getOutputSymbols();
            List<Symbol> rightSymbols = node.getRight().getOutputSymbols();

            Expression leftEffectivePredicate = this.effectivePredicateExtractor.extract(this.session, node.getLeft(), this.types, this.typeAnalyzer);
            Expression rightEffectivePredicate = this.effectivePredicateExtractor.extract(this.session, node.getRight(), this.types, this.typeAnalyzer);
            Expression joinPredicate = extractJoinPredicate(node);

            Expression leftPredicate;
            Expression rightPredicate;
            Expression postJoinPredicate;
            Expression newJoinPredicate;
            switch (node.getType()) {
                case INNER: {
                    InnerJoinPushDownResult result = processInnerJoin(inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate, leftSymbols, rightSymbols);
                    leftPredicate = result.getLeftPredicate();
                    rightPredicate = result.getRightPredicate();
                    postJoinPredicate = result.getPostJoinPredicate();
                    newJoinPredicate = result.getJoinPredicate();
                    break;
                }
                case LEFT: {
                    OuterJoinPushDownResult result = processLimitedOuterJoin(inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate, leftSymbols, rightSymbols);
                    leftPredicate = result.getOuterJoinPredicate();
                    rightPredicate = result.getInnerJoinPredicate();
                    postJoinPredicate = result.getPostJoinPredicate();
                    newJoinPredicate = result.getJoinPredicate();
                    break;
                }
                case RIGHT: {
                    // Same as LEFT with the sides swapped: the right input is the outer side.
                    OuterJoinPushDownResult result = processLimitedOuterJoin(inheritedPredicate, rightEffectivePredicate, leftEffectivePredicate, joinPredicate, rightSymbols, leftSymbols);
                    leftPredicate = result.getInnerJoinPredicate();
                    rightPredicate = result.getOuterJoinPredicate();
                    postJoinPredicate = result.getPostJoinPredicate();
                    newJoinPredicate = result.getJoinPredicate();
                    break;
                }
                case FULL:
                    // Nothing is pushed into either side; the inherited predicate stays above the join.
                    leftPredicate = BooleanLiteral.TRUE_LITERAL;
                    rightPredicate = BooleanLiteral.TRUE_LITERAL;
                    postJoinPredicate = inheritedPredicate;
                    newJoinPredicate = joinPredicate;
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported join type: " + node.getType());
            }

            newJoinPredicate = simplifyExpression(newJoinPredicate);

            // Identity projections for all existing symbols; equi-clause operands are added below.
            Assignments.Builder leftProjections = Assignments.builder();
            leftProjections.putAll(leftSymbols.stream()
                    .collect(ImmutableMap.<Symbol, Symbol, Expression>toImmutableMap(symbol -> symbol, Symbol::toSymbolReference)));
            Assignments.Builder rightProjections = Assignments.builder();
            rightProjections.putAll(rightSymbols.stream()
                    .collect(ImmutableMap.<Symbol, Symbol, Expression>toImmutableMap(symbol -> symbol, Symbol::toSymbolReference)));

            List<JoinNode.EquiJoinClause> equiJoinClauses = new ArrayList<>();
            ImmutableList.Builder<Expression> joinFilterBuilder = ImmutableList.builder();
            for (Expression conjunct : ExpressionUtils.extractConjuncts(newJoinPredicate)) {
                if (joinEqualityExpression(conjunct, leftSymbols, rightSymbols)) {
                    // Only cast once the conjunct is known to be an equality; other conjuncts
                    // (arbitrary expressions) go to the join filter untouched.
                    ComparisonExpression equality = (ComparisonExpression) conjunct;
                    boolean alignedComparison = leftSymbols.containsAll(SymbolsExtractor.extractUnique(equality.getLeft()));
                    Expression leftExpression = alignedComparison ? equality.getLeft() : equality.getRight();
                    Expression rightExpression = alignedComparison ? equality.getRight() : equality.getLeft();

                    Symbol leftSymbol = symbolForExpression(leftExpression);
                    if (!leftSymbols.contains(leftSymbol)) {
                        leftProjections.put(leftSymbol, leftExpression);
                    }
                    Symbol rightSymbol = symbolForExpression(rightExpression);
                    if (!rightSymbols.contains(rightSymbol)) {
                        rightProjections.put(rightSymbol, rightExpression);
                    }
                    equiJoinClauses.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol));
                }
                else {
                    joinFilterBuilder.add(conjunct);
                }
            }
            List<Expression> joinFilter = joinFilterBuilder.build();

            DynamicFiltersResult dynamicFiltersResult = createDynamicFilters(node, equiJoinClauses, joinFilter, this.session, this.idAllocator);
            Map<DynamicFilterId, Symbol> dynamicFilters = dynamicFiltersResult.getDynamicFilters();
            Expression leftPushDownPredicate = ExpressionUtils.combineConjuncts(this.metadata, leftPredicate, ExpressionUtils.combineConjuncts(this.metadata, dynamicFiltersResult.getPredicates()));

            boolean equiJoinClausesUnmodified = ImmutableSet.copyOf(equiJoinClauses).equals(ImmutableSet.copyOf(node.getCriteria()));

            PlanNode leftSource;
            PlanNode rightSource;
            if (equiJoinClausesUnmodified) {
                leftSource = rewriteContext.rewrite(node.getLeft(), leftPushDownPredicate);
                rightSource = rewriteContext.rewrite(node.getRight(), rightPredicate);
            }
            else {
                leftSource = rewriteContext.rewrite(new ProjectNode(this.idAllocator.getNextId(), node.getLeft(), leftProjections.build()), leftPushDownPredicate);
                rightSource = rewriteContext.rewrite(new ProjectNode(this.idAllocator.getNextId(), node.getRight(), rightProjections.build()), rightPredicate);
            }

            Optional<Expression> newJoinFilter = Optional.of(ExpressionUtils.combineConjuncts(this.metadata, joinFilter));
            if (newJoinFilter.get().equals(BooleanLiteral.TRUE_LITERAL)) {
                newJoinFilter = Optional.empty();
            }
            // An INNER join with no equi-clauses does not keep a join filter; the filter is
            // evaluated above the join instead.
            if (node.getType() == JoinNode.Type.INNER && newJoinFilter.isPresent() && equiJoinClauses.isEmpty()) {
                postJoinPredicate = ExpressionUtils.combineConjuncts(this.metadata, postJoinPredicate, newJoinFilter.get());
                newJoinFilter = Optional.empty();
            }

            boolean filtersEquivalent = newJoinFilter.isPresent() == node.getFilter().isPresent()
                    && (newJoinFilter.isEmpty() || areExpressionsEquivalent(newJoinFilter.get(), node.getFilter().get()));

            PlanNode output = node;
            if (leftSource != node.getLeft()
                    || rightSource != node.getRight()
                    || !filtersEquivalent
                    || !dynamicFilters.equals(node.getDynamicFilters())
                    || !equiJoinClausesUnmodified) {
                ProjectNode leftProject = new ProjectNode(this.idAllocator.getNextId(), leftSource, leftProjections.build());
                ProjectNode rightProject = new ProjectNode(this.idAllocator.getNextId(), rightSource, rightProjections.build());
                output = new JoinNode(
                        node.getId(),
                        node.getType(),
                        leftProject,
                        rightProject,
                        equiJoinClauses,
                        leftProject.getOutputSymbols(),
                        rightProject.getOutputSymbols(),
                        node.isMaySkipOutputDuplicates(),
                        newJoinFilter,
                        node.getLeftHashSymbol(),
                        node.getRightHashSymbol(),
                        node.getDistributionType(),
                        node.isSpillable(),
                        dynamicFilters,
                        node.getReorderJoinStatsAndCost());
            }

            if (!postJoinPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
                output = new FilterNode(this.idAllocator.getNextId(), output, postJoinPredicate);
            }

            // Restore the (possibly normalized) join's original output symbols.
            if (!node.getOutputSymbols().equals(output.getOutputSymbols())) {
                output = new ProjectNode(this.idAllocator.getNextId(), output, Assignments.identity(node.getOutputSymbols()));
            }
            return output;
        }

        /**
         * Derives dynamic filters for a join.
         *
         * <p>An empty result is returned when any of these holds:
         * <ul>
         *   <li>the join type is neither INNER nor RIGHT;</li>
         *   <li>dynamic filtering is disabled for the session;</li>
         *   <li>this rewriter has {@code dynamicFiltering} turned off.</li>
         * </ul>
         *
         * <p>Otherwise, candidate comparisons are collected from two sources:
         * <ul>
         *   <li>every equi-join clause;</li>
         *   <li>every join filter conjunct that {@link #joinDynamicFilteringExpression} accepts. BETWEEN
         *       conjuncts are first split into two range comparisons.</li>
         * </ul>
         *
         * <p>The comparison's right operand is a right-side symbol. Each distinct right-side symbol is
         * mapped to a {@link DynamicFilterId}:
         * <ul>
         *   <li>an ID already present in {@code joinNode.getDynamicFilters()} is reused;</li>
         *   <li>otherwise a new ID {@code "df_" + nextPlanNodeId} is allocated.</li>
         * </ul>
         *
         * @param joinNode the join being rewritten
         * @param list equi-join clauses to derive filters from
         * @param list2 remaining join filter conjuncts to derive filters from
         * @param session session consulted for the dynamic-filtering toggle
         * @param planNodeIdAllocator source of IDs for newly allocated dynamic filters
         * @return the dynamic-filter-ID-to-symbol mapping and the dynamic filter predicates
         */
        private DynamicFiltersResult createDynamicFilters(JoinNode joinNode, List<JoinNode.EquiJoinClause> list, List<Expression> list2, Session session, PlanNodeIdAllocator planNodeIdAllocator) {
            if ((joinNode.getType() != JoinNode.Type.INNER && joinNode.getType() != JoinNode.Type.RIGHT) || !SystemSessionProperties.isEnableDynamicFiltering(session) || !this.dynamicFiltering) {
                return new DynamicFiltersResult(ImmutableMap.of(), ImmutableList.of());
            }
            List<Symbol> leftSymbols = joinNode.getLeft().getOutputSymbols();
            List<Symbol> rightSymbols = joinNode.getRight().getOutputSymbols();

            // Equi-join clauses are used as-is: clause.getLeft() on the left, clause.getRight() on the right.
            Stream<DynamicFilterExpression> equiJoinFilters = list.stream()
                    .map(clause -> new DynamicFilterExpression(new ComparisonExpression(
                            ComparisonExpression.Operator.EQUAL,
                            clause.getLeft().toSymbolReference(),
                            clause.getRight().toSymbolReference())));

            Stream<DynamicFilterExpression> joinFilterFilters = list2.stream()
                    .flatMap(Rewriter::tryConvertBetweenIntoComparisons)
                    .filter(conjunct -> joinDynamicFilteringExpression(conjunct, leftSymbols, rightSymbols))
                    .map(conjunct -> {
                        if (conjunct instanceof NotExpression) {
                            // NOT (a IS DISTINCT FROM b) becomes a = b, marked as null-allowed.
                            ComparisonExpression distinctFrom = (ComparisonExpression) ((NotExpression) conjunct).getValue();
                            return new DynamicFilterExpression(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, distinctFrom.getLeft(), distinctFrom.getRight()), true);
                        }
                        return new DynamicFilterExpression((ComparisonExpression) conjunct);
                    })
                    .map(dynamicFilter -> {
                        // Orient the comparison so that its left operand only references left-side symbols.
                        ComparisonExpression comparison = dynamicFilter.getComparison();
                        Expression left = comparison.getLeft();
                        Expression right = comparison.getRight();
                        boolean aligned = leftSymbols.containsAll(SymbolsExtractor.extractUnique(left));
                        return new DynamicFilterExpression(
                                new ComparisonExpression(
                                        aligned ? comparison.getOperator() : comparison.getOperator().flip(),
                                        aligned ? left : right,
                                        aligned ? right : left),
                                dynamicFilter.isNullAllowed());
                    });

            List<DynamicFilterExpression> clauses = Stream.concat(equiJoinFilters, joinFilterFilters)
                    .collect(ImmutableList.toImmutableList());

            Set<Symbol> buildSymbols = clauses.stream()
                    .map(clause -> Symbol.from(clause.getComparison().getRight()))
                    .collect(ImmutableSet.toImmutableSet());

            // Reuse existing dynamic filter IDs where possible; allocate new ones for the remaining symbols.
            BiMap<Symbol, DynamicFilterId> buildSymbolToFilterId = HashBiMap.create(joinNode.getDynamicFilters()).inverse();
            for (Symbol buildSymbol : buildSymbols) {
                buildSymbolToFilterId.computeIfAbsent(buildSymbol, symbol -> new DynamicFilterId("df_" + planNodeIdAllocator.getNextId().toString()));
            }

            // Clauses sharing the same right-side symbol share a single dynamic filter ID.
            List<Expression> predicates = clauses.stream()
                    .map(clause -> {
                        ComparisonExpression comparison = clause.getComparison();
                        Symbol buildSymbol = Symbol.from(comparison.getRight());
                        Type type = this.symbolAllocator.getTypes().get(buildSymbol);
                        DynamicFilterId filterId = Objects.requireNonNull(buildSymbolToFilterId.get(buildSymbol), () -> "missing dynamic filter for symbol " + buildSymbol);
                        return DynamicFilters.createDynamicFilterExpression(this.metadata, filterId, type, comparison.getLeft(), comparison.getOperator(), clause.isNullAllowed());
                    })
                    .collect(ImmutableList.toImmutableList());

            return new DynamicFiltersResult(ImmutableMap.copyOf(buildSymbolToFilterId.inverse()), predicates);
        }

        /**
         * Splits a {@code value BETWEEN min AND max} predicate into
         * {@code value >= min} and {@code value <= max}, so each half can be considered separately for
         * dynamic filtering. Any other expression is returned unchanged as a single-element stream.
         *
         * @param expression a join filter conjunct
         * @return the two range comparisons for a BETWEEN predicate, otherwise {@code expression} alone
         */
        private static Stream<Expression> tryConvertBetweenIntoComparisons(Expression expression) {
            if (!(expression instanceof BetweenPredicate)) {
                return Stream.of(expression);
            }
            BetweenPredicate between = (BetweenPredicate) expression;
            // Let Stream.of infer Expression from the return type; casting to Object[] would yield Stream<Object>.
            return Stream.of(
                    new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, between.getValue(), between.getMin()),
                    new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, between.getValue(), between.getMax()));
        }

        /**
         * Pushes the inherited predicate through a spatial join.
         *
         * <p>First, a LEFT spatial join is turned into an INNER one when
         * {@link #canConvertOuterToInner} holds for the right side's output symbols.
         *
         * <p>Next, the predicate is partitioned into three parts:
         * <ul>
         *   <li>predicates pushed into the left and right sources;</li>
         *   <li>a new join filter;</li>
         *   <li>a post-join filter.</li>
         * </ul>
         * INNER joins use {@link #processInnerJoin} and LEFT joins use {@link #processLimitedOuterJoin}.
         *
         * <p>The join node is rebuilt only when a source changed or the simplified filter is not
         * equivalent to the original filter. In that case, each rewritten source is wrapped in an
         * identity projection.
         *
         * @throws IllegalArgumentException for spatial join types other than INNER and LEFT
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitSpatialJoin(SpatialJoinNode spatialJoinNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression inheritedPredicate = rewriteContext.get();
            SpatialJoinNode node = spatialJoinNode;

            boolean leftJoinRejectsNulls = node.getType() == SpatialJoinNode.Type.LEFT
                    && canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate);
            if (leftJoinRejectsNulls) {
                node = new SpatialJoinNode(node.getId(), SpatialJoinNode.Type.INNER, node.getLeft(), node.getRight(), node.getOutputSymbols(), node.getFilter(), node.getLeftPartitionSymbol(), node.getRightPartitionSymbol(), node.getKdbTree());
            }

            Expression leftEffectivePredicate = this.effectivePredicateExtractor.extract(this.session, node.getLeft(), this.types, this.typeAnalyzer);
            Expression rightEffectivePredicate = this.effectivePredicateExtractor.extract(this.session, node.getRight(), this.types, this.typeAnalyzer);
            Expression currentFilter = node.getFilter();

            Expression leftPushdown;
            Expression rightPushdown;
            Expression postJoinFilter;
            Expression newJoinPredicate;
            switch (node.getType()) {
                case INNER: {
                    InnerJoinPushDownResult inner = processInnerJoin(inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, currentFilter, node.getLeft().getOutputSymbols(), node.getRight().getOutputSymbols());
                    leftPushdown = inner.getLeftPredicate();
                    rightPushdown = inner.getRightPredicate();
                    postJoinFilter = inner.getPostJoinPredicate();
                    newJoinPredicate = inner.getJoinPredicate();
                    break;
                }
                case LEFT: {
                    OuterJoinPushDownResult outer = processLimitedOuterJoin(inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, currentFilter, node.getLeft().getOutputSymbols(), node.getRight().getOutputSymbols());
                    leftPushdown = outer.getOuterJoinPredicate();
                    rightPushdown = outer.getInnerJoinPredicate();
                    postJoinFilter = outer.getPostJoinPredicate();
                    newJoinPredicate = outer.getJoinPredicate();
                    break;
                }
                default:
                    throw new IllegalArgumentException("Unsupported spatial join type: " + node.getType());
            }

            Expression newFilter = simplifyExpression(newJoinPredicate);
            // A spatial join cannot run without its spatial predicate.
            Verify.verify(!newFilter.equals(BooleanLiteral.FALSE_LITERAL), "Spatial join predicate is missing", new Object[0]);

            PlanNode newLeft = rewriteContext.rewrite(node.getLeft(), leftPushdown);
            PlanNode newRight = rewriteContext.rewrite(node.getRight(), rightPushdown);

            PlanNode result = node;
            boolean changed = newLeft != node.getLeft() || newRight != node.getRight() || !areExpressionsEquivalent(newFilter, currentFilter);
            if (changed) {
                Map<Symbol, Expression> leftIdentity = node.getLeft().getOutputSymbols().stream()
                        .collect(ImmutableMap.toImmutableMap(symbol -> symbol, Symbol::toSymbolReference));
                Assignments.Builder leftAssignments = Assignments.builder();
                leftAssignments.putAll(leftIdentity);

                Map<Symbol, Expression> rightIdentity = node.getRight().getOutputSymbols().stream()
                        .collect(ImmutableMap.toImmutableMap(symbol -> symbol, Symbol::toSymbolReference));
                Assignments.Builder rightAssignments = Assignments.builder();
                rightAssignments.putAll(rightIdentity);

                ProjectNode leftProject = new ProjectNode(this.idAllocator.getNextId(), newLeft, leftAssignments.build());
                ProjectNode rightProject = new ProjectNode(this.idAllocator.getNextId(), newRight, rightAssignments.build());
                result = new SpatialJoinNode(node.getId(), node.getType(), leftProject, rightProject, node.getOutputSymbols(), newFilter, node.getLeftPartitionSymbol(), node.getRightPartitionSymbol(), node.getKdbTree());
            }

            if (postJoinFilter.equals(BooleanLiteral.TRUE_LITERAL)) {
                return result;
            }
            return new FilterNode(this.idAllocator.getNextId(), result, postJoinFilter);
        }

        /**
         * Returns the symbol an expression can be referred to by.
         *
         * <p>A plain symbol reference maps to its own symbol. Any other expression gets a newly
         * allocated symbol, typed with the expression's analyzed type.
         */
        private Symbol symbolForExpression(Expression expression) {
            if (expression instanceof SymbolReference) {
                return Symbol.from(expression);
            }
            Type expressionType = this.typeAnalyzer.getType(this.session, this.symbolAllocator.getTypes(), expression);
            return this.symbolAllocator.newSymbol(expression, expressionType);
        }

        /**
         * Splits predicates around an outer join into four parts:
         * <ul>
         *   <li>conjuncts pushed into the outer (preserved) side;</li>
         *   <li>conjuncts pushed into the inner side;</li>
         *   <li>conjuncts that stay in the join condition;</li>
         *   <li>conjuncts evaluated after the join.</li>
         * </ul>
         *
         * <p>Non-deterministic conjuncts are never pushed. Those from {@code inheritedPredicate} go to the
         * post-join part, and those from {@code joinPredicate} stay in the join part.
         *
         * @param inheritedPredicate predicate arriving from above the join
         * @param outerEffectivePredicate effective predicate of the outer side; must only use {@code outerSymbols}
         * @param innerEffectivePredicate effective predicate of the inner side; must only use {@code innerSymbols}
         * @param joinPredicate the current join condition
         * @param outerSymbols output symbols of the outer side
         * @param innerSymbols output symbols of the inner side
         * @return outer pushdown, inner pushdown, join, and post-join predicates, in that constructor order
         * @throws IllegalArgumentException if an effective predicate references symbols outside its side
         */
        private OuterJoinPushDownResult processLimitedOuterJoin(Expression inheritedPredicate, Expression outerEffectivePredicate, Expression innerEffectivePredicate, Expression joinPredicate, Collection<Symbol> outerSymbols, Collection<Symbol> innerSymbols) {
            Preconditions.checkArgument(outerSymbols.containsAll(SymbolsExtractor.extractUnique(outerEffectivePredicate)), "outerEffectivePredicate must only contain symbols from outerSymbols");
            Preconditions.checkArgument(innerSymbols.containsAll(SymbolsExtractor.extractUnique(innerEffectivePredicate)), "innerEffectivePredicate must only contain symbols from innerSymbols");

            ImmutableList.Builder<Expression> outerPushdownConjuncts = ImmutableList.builder();
            ImmutableList.Builder<Expression> innerPushdownConjuncts = ImmutableList.builder();
            ImmutableList.Builder<Expression> postJoinConjuncts = ImmutableList.builder();
            ImmutableList.Builder<Expression> joinConjuncts = ImmutableList.builder();

            // Strip non-deterministic conjuncts: inherited ones are evaluated after the join.
            ExpressionUtils.extractConjuncts(inheritedPredicate).stream()
                    .filter(conjunct -> !DeterminismEvaluator.isDeterministic(conjunct, this.metadata))
                    .forEach(postJoinConjuncts::add);
            Expression deterministicInherited = ExpressionUtils.filterDeterministicConjuncts(this.metadata, inheritedPredicate);
            Expression deterministicOuterEffective = ExpressionUtils.filterDeterministicConjuncts(this.metadata, outerEffectivePredicate);
            Expression deterministicInnerEffective = ExpressionUtils.filterDeterministicConjuncts(this.metadata, innerEffectivePredicate);
            // Non-deterministic join conjuncts stay in the join condition.
            ExpressionUtils.extractConjuncts(joinPredicate).stream()
                    .filter(conjunct -> !DeterminismEvaluator.isDeterministic(conjunct, this.metadata))
                    .forEach(joinConjuncts::add);
            Expression deterministicJoin = ExpressionUtils.filterDeterministicConjuncts(this.metadata, joinPredicate);

            EqualityInference inheritedInference = new EqualityInference(this.metadata, deterministicInherited);
            EqualityInference outerInference = new EqualityInference(this.metadata, deterministicInherited, deterministicOuterEffective);

            Set<Symbol> innerScope = ImmutableSet.copyOf(innerSymbols);
            Set<Symbol> outerScope = ImmutableSet.copyOf(outerSymbols);

            EqualityInference.EqualityPartition inheritedPartition = inheritedInference.generateEqualitiesPartitionedBy(outerScope);
            Expression outerOnlyInheritedEqualities = ExpressionUtils.combineConjuncts(this.metadata, inheritedPartition.getScopeEqualities());
            EqualityInference potentialNullSymbolInference = new EqualityInference(this.metadata, outerOnlyInheritedEqualities, deterministicOuterEffective, deterministicInnerEffective, deterministicJoin);

            // Push equalities derived from outer-side and join information (without the inner effective predicate) into the inner side.
            EqualityInference withoutInnerInferred = new EqualityInference(this.metadata, outerOnlyInheritedEqualities, deterministicOuterEffective, deterministicJoin);
            innerPushdownConjuncts.addAll(withoutInnerInferred.generateEqualitiesPartitionedBy(innerScope).getScopeEqualities());

            EqualityInference.EqualityPartition joinPartition = new EqualityInference(this.metadata, deterministicJoin).generateEqualitiesPartitionedBy(innerScope);
            innerPushdownConjuncts.addAll(joinPartition.getScopeEqualities());
            joinConjuncts.addAll(joinPartition.getScopeComplementEqualities())
                    .addAll(joinPartition.getScopeStraddlingEqualities());

            // Add the inherited equalities back in.
            outerPushdownConjuncts.addAll(inheritedPartition.getScopeEqualities());
            postJoinConjuncts.addAll(inheritedPartition.getScopeComplementEqualities());
            postJoinConjuncts.addAll(inheritedPartition.getScopeStraddlingEqualities());

            // An inherited conjunct reaches the inner side only via its rewrite in terms of the outer side.
            EqualityInference.nonInferrableConjuncts(this.metadata, deterministicInherited).forEach(conjunct -> {
                Expression outerRewritten = outerInference.rewrite(conjunct, outerScope);
                if (outerRewritten == null) {
                    postJoinConjuncts.add(conjunct);
                    return;
                }
                outerPushdownConjuncts.add(outerRewritten);
                Expression innerRewritten = potentialNullSymbolInference.rewrite(outerRewritten, innerScope);
                if (innerRewritten != null) {
                    innerPushdownConjuncts.add(innerRewritten);
                }
            });

            // Push whatever part of the outer effective predicate can be expressed on the inner side.
            EqualityInference.nonInferrableConjuncts(this.metadata, deterministicOuterEffective)
                    .map(conjunct -> potentialNullSymbolInference.rewrite(conjunct, innerScope))
                    .filter(Objects::nonNull)
                    .forEach(innerPushdownConjuncts::add);

            // Push join conjuncts into the inner side when possible; otherwise keep them in the join.
            EqualityInference.nonInferrableConjuncts(this.metadata, deterministicJoin).forEach(conjunct -> {
                Expression innerRewritten = potentialNullSymbolInference.rewrite(conjunct, innerScope);
                if (innerRewritten != null) {
                    innerPushdownConjuncts.add(innerRewritten);
                }
                else {
                    joinConjuncts.add(conjunct);
                }
            });

            return new OuterJoinPushDownResult(
                    ExpressionUtils.combineConjuncts(this.metadata, outerPushdownConjuncts.build()),
                    ExpressionUtils.combineConjuncts(this.metadata, innerPushdownConjuncts.build()),
                    ExpressionUtils.combineConjuncts(this.metadata, joinConjuncts.build()),
                    ExpressionUtils.combineConjuncts(this.metadata, postJoinConjuncts.build()));
        }

        /**
         * Splits predicates around an inner join into three parts:
         * <ul>
         *   <li>conjuncts pushed into the left side;</li>
         *   <li>conjuncts pushed into the right side;</li>
         *   <li>conjuncts that cannot be pushed to either side.</li>
         * </ul>
         *
         * <p>Non-deterministic conjuncts from {@code inheritedPredicate} and {@code joinPredicate} are never
         * pushed; they go to the third part. Each side's effective predicate is first simplified using
         * equalities from the inherited predicate, and then contributes to pushdown on the opposite side.
         *
         * @param inheritedPredicate predicate arriving from above the join
         * @param leftEffectivePredicate effective predicate of the left side; must only use {@code leftSymbols}
         * @param rightEffectivePredicate effective predicate of the right side; must only use {@code rightSymbols}
         * @param joinPredicate the current join condition
         * @param leftSymbols output symbols of the left side
         * @param rightSymbols output symbols of the right side
         * @return the result built from, in constructor order: left pushdown, right pushdown, the
         *         non-pushable conjuncts, and {@code TRUE}
         * @throws IllegalArgumentException if an effective predicate references symbols outside its side
         */
        private InnerJoinPushDownResult processInnerJoin(Expression inheritedPredicate, Expression leftEffectivePredicate, Expression rightEffectivePredicate, Expression joinPredicate, Collection<Symbol> leftSymbols, Collection<Symbol> rightSymbols) {
            Preconditions.checkArgument(leftSymbols.containsAll(SymbolsExtractor.extractUnique(leftEffectivePredicate)), "leftEffectivePredicate must only contain symbols from leftSymbols");
            Preconditions.checkArgument(rightSymbols.containsAll(SymbolsExtractor.extractUnique(rightEffectivePredicate)), "rightEffectivePredicate must only contain symbols from rightSymbols");

            ImmutableList.Builder<Expression> leftPushdownConjuncts = ImmutableList.builder();
            ImmutableList.Builder<Expression> rightPushdownConjuncts = ImmutableList.builder();
            ImmutableList.Builder<Expression> joinConjuncts = ImmutableList.builder();

            // Non-deterministic conjuncts (inherited and join) are never pushed.
            ExpressionUtils.extractConjuncts(inheritedPredicate).stream()
                    .filter(conjunct -> !DeterminismEvaluator.isDeterministic(conjunct, this.metadata))
                    .forEach(joinConjuncts::add);
            Expression deterministicInherited = ExpressionUtils.filterDeterministicConjuncts(this.metadata, inheritedPredicate);
            ExpressionUtils.extractConjuncts(joinPredicate).stream()
                    .filter(conjunct -> !DeterminismEvaluator.isDeterministic(conjunct, this.metadata))
                    .forEach(joinConjuncts::add);
            Expression deterministicJoin = ExpressionUtils.filterDeterministicConjuncts(this.metadata, joinPredicate);
            Expression deterministicLeftEffective = ExpressionUtils.filterDeterministicConjuncts(this.metadata, leftEffectivePredicate);
            Expression deterministicRightEffective = ExpressionUtils.filterDeterministicConjuncts(this.metadata, rightEffectivePredicate);

            Set<Symbol> leftScope = ImmutableSet.copyOf(leftSymbols);
            Set<Symbol> rightScope = ImmutableSet.copyOf(rightSymbols);

            // Simplify each side's effective predicate with equalities implied by the inherited predicate.
            EqualityInference predicateInference = new EqualityInference(this.metadata, deterministicInherited);
            Expression simplifiedLeftEffective = predicateInference.rewrite(deterministicLeftEffective, leftScope);
            Expression simplifiedRightEffective = predicateInference.rewrite(deterministicRightEffective, rightScope);

            EqualityInference allInference = new EqualityInference(this.metadata, deterministicInherited, deterministicLeftEffective, deterministicRightEffective, deterministicJoin, simplifiedLeftEffective, simplifiedRightEffective);
            EqualityInference allInferenceWithoutLeftInferred = new EqualityInference(this.metadata, deterministicInherited, deterministicRightEffective, deterministicJoin, simplifiedRightEffective);
            EqualityInference allInferenceWithoutRightInferred = new EqualityInference(this.metadata, deterministicInherited, deterministicLeftEffective, deterministicJoin, simplifiedLeftEffective);

            // Add inferred equalities back in; equalities straddling both sides stay with the join.
            leftPushdownConjuncts.addAll(allInferenceWithoutLeftInferred.generateEqualitiesPartitionedBy(leftScope).getScopeEqualities());
            rightPushdownConjuncts.addAll(allInferenceWithoutRightInferred.generateEqualitiesPartitionedBy(rightScope).getScopeEqualities());
            joinConjuncts.addAll(allInference.generateEqualitiesPartitionedBy(leftScope).getScopeStraddlingEqualities());

            // Inherited conjuncts go to whichever sides they can be rewritten for; otherwise they stay with the join.
            EqualityInference.nonInferrableConjuncts(this.metadata, deterministicInherited).forEach(conjunct -> {
                Expression leftRewritten = allInference.rewrite(conjunct, leftScope);
                if (leftRewritten != null) {
                    leftPushdownConjuncts.add(leftRewritten);
                }
                Expression rightRewritten = allInference.rewrite(conjunct, rightScope);
                if (rightRewritten != null) {
                    rightPushdownConjuncts.add(rightRewritten);
                }
                if (leftRewritten == null && rightRewritten == null) {
                    joinConjuncts.add(conjunct);
                }
            });

            // Right effective predicate -> left side.
            EqualityInference.nonInferrableConjuncts(this.metadata, simplifiedRightEffective)
                    .map(conjunct -> allInference.rewrite(conjunct, leftScope))
                    .filter(Objects::nonNull)
                    .forEach(leftPushdownConjuncts::add);

            // Left effective predicate -> right side.
            EqualityInference.nonInferrableConjuncts(this.metadata, simplifiedLeftEffective)
                    .map(conjunct -> allInference.rewrite(conjunct, rightScope))
                    .filter(Objects::nonNull)
                    .forEach(rightPushdownConjuncts::add);

            // Join conjuncts go to whichever sides they can be rewritten for; otherwise they stay with the join.
            EqualityInference.nonInferrableConjuncts(this.metadata, deterministicJoin).forEach(conjunct -> {
                Expression leftRewritten = allInference.rewrite(conjunct, leftScope);
                if (leftRewritten != null) {
                    leftPushdownConjuncts.add(leftRewritten);
                }
                Expression rightRewritten = allInference.rewrite(conjunct, rightScope);
                if (rightRewritten != null) {
                    rightPushdownConjuncts.add(rightRewritten);
                }
                if (leftRewritten == null && rightRewritten == null) {
                    joinConjuncts.add(conjunct);
                }
            });

            return new InnerJoinPushDownResult(
                    ExpressionUtils.combineConjuncts(this.metadata, leftPushdownConjuncts.build()),
                    ExpressionUtils.combineConjuncts(this.metadata, rightPushdownConjuncts.build()),
                    ExpressionUtils.combineConjuncts(this.metadata, joinConjuncts.build()),
                    BooleanLiteral.TRUE_LITERAL);
        }

        /**
         * Rebuilds the full join condition of {@code joinNode} as one conjunction.
         *
         * <p>The conjunction contains every equi-join clause (as produced by
         * {@code EquiJoinClause.toExpression()}), followed by the join filter when present.
         *
         * @param joinNode the join whose condition is extracted
         * @return the combined conjunction of criteria and filter
         */
        private Expression extractJoinPredicate(JoinNode joinNode) {
            ImmutableList.Builder<Expression> conjuncts = ImmutableList.builder();
            for (JoinNode.EquiJoinClause clause : joinNode.getCriteria()) {
                conjuncts.add(clause.toExpression());
            }
            joinNode.getFilter().ifPresent(conjuncts::add);
            return ExpressionUtils.combineConjuncts(this.metadata, conjuncts.build());
        }

        /**
         * Weakens an outer join when {@code expression} filters out its null-extended rows.
         *
         * <p>Here a side is "null-rejected" when {@link #canConvertOuterToInner} returns true for that
         * side's output symbols.
         * <ul>
         *   <li>INNER is returned as-is.</li>
         *   <li>LEFT becomes INNER if the right side is null-rejected.</li>
         *   <li>RIGHT becomes INNER if the left side is null-rejected.</li>
         *   <li>FULL becomes INNER if both sides are null-rejected, LEFT if only the left side is, and
         *       RIGHT if only the right side is.</li>
         * </ul>
         * When nothing changes, the original node instance is returned. All other node properties are
         * copied unchanged.
         *
         * @throws IllegalArgumentException if the join type is not INNER, LEFT, RIGHT or FULL
         */
        private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode joinNode, Expression expression) {
            JoinNode.Type joinType = joinNode.getType();
            Preconditions.checkArgument(EnumSet.of(JoinNode.Type.INNER, JoinNode.Type.RIGHT, JoinNode.Type.LEFT, JoinNode.Type.FULL).contains(joinType), "Unsupported join type: %s", joinType);

            JoinNode.Type normalizedType;
            switch (joinType) {
                case INNER:
                    return joinNode;
                case LEFT:
                    if (!canConvertOuterToInner(joinNode.getRight().getOutputSymbols(), expression)) {
                        return joinNode;
                    }
                    normalizedType = JoinNode.Type.INNER;
                    break;
                case RIGHT:
                    if (!canConvertOuterToInner(joinNode.getLeft().getOutputSymbols(), expression)) {
                        return joinNode;
                    }
                    normalizedType = JoinNode.Type.INNER;
                    break;
                default: {
                    // FULL: each null-rejected side removes the rows padded with nulls on that side.
                    boolean leftNullRejected = canConvertOuterToInner(joinNode.getLeft().getOutputSymbols(), expression);
                    boolean rightNullRejected = canConvertOuterToInner(joinNode.getRight().getOutputSymbols(), expression);
                    if (leftNullRejected && rightNullRejected) {
                        normalizedType = JoinNode.Type.INNER;
                    }
                    else if (leftNullRejected) {
                        normalizedType = JoinNode.Type.LEFT;
                    }
                    else if (rightNullRejected) {
                        normalizedType = JoinNode.Type.RIGHT;
                    }
                    else {
                        return joinNode;
                    }
                    break;
                }
            }
            return new JoinNode(joinNode.getId(), normalizedType, joinNode.getLeft(), joinNode.getRight(), joinNode.getCriteria(), joinNode.getLeftOutputSymbols(), joinNode.getRightOutputSymbols(), joinNode.isMaySkipOutputDuplicates(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters(), joinNode.getReorderJoinStatsAndCost());
        }

        /**
         * Returns true if some deterministic conjunct of {@code expression} cannot be satisfied when
         * every symbol in {@code list} is null.
         *
         * <p>Each such conjunct is evaluated via {@link #nullInputEvaluator} with the symbols bound to
         * null. A result of Java {@code null}, a {@code NullLiteral}, or {@code Boolean.FALSE} counts as
         * "cannot be satisfied".
         */
        private boolean canConvertOuterToInner(List<Symbol> list, Expression expression) {
            Set<Symbol> nullSymbols = ImmutableSet.copyOf(list);
            for (Expression conjunct : ExpressionUtils.extractConjuncts(expression)) {
                if (!DeterminismEvaluator.isDeterministic(conjunct, this.metadata)) {
                    continue;
                }
                Object evaluated = nullInputEvaluator(nullSymbols, conjunct);
                boolean rejectsNulls = evaluated == null || evaluated instanceof NullLiteral || Boolean.FALSE.equals(evaluated);
                if (rejectsNulls) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Constant-folds {@code expression} without resolving any symbols.
         *
         * <p>The optimized value is encoded back into an expression of the original expression's type.
         */
        private Expression simplifyExpression(Expression expression) {
            Map<NodeRef<Expression>, Type> expressionTypes = this.typeAnalyzer.getTypes(this.session, this.symbolAllocator.getTypes(), expression);
            ExpressionInterpreter interpreter = new ExpressionInterpreter(expression, this.plannerContext, this.session, expressionTypes);
            Object optimized = interpreter.optimize(NoOpSymbolResolver.INSTANCE);
            Type resultType = expressionTypes.get(NodeRef.of(expression));
            return this.literalEncoder.toExpression(optimized, resultType);
        }

        /**
         * Checks semantic equivalence of two expressions.
         *
         * <p>Delegates to {@code expressionEquivalence}, using this rewriter's session and
         * {@code types}.
         */
        private boolean areExpressionsEquivalent(Expression first, Expression second) {
            return this.expressionEquivalence.areExpressionsEquivalent(this.session, first, second, this.types);
        }

        /**
         * Partially evaluates {@code expression} with every symbol in {@code collection} bound to null.
         *
         * <p>Other symbols resolve to their own symbol references.
         *
         * @return the interpreter's optimized result (a value, {@code null}, or a residual expression)
         */
        private Object nullInputEvaluator(Collection<Symbol> collection, Expression expression) {
            Map<NodeRef<Expression>, Type> expressionTypes = this.typeAnalyzer.getTypes(this.session, this.symbolAllocator.getTypes(), expression);
            ExpressionInterpreter interpreter = new ExpressionInterpreter(expression, this.plannerContext, this.session, expressionTypes);
            return interpreter.optimize(symbol -> collection.contains(symbol) ? null : symbol.toSymbolReference());
        }

        /**
         * Returns true if {@code expression} is a deterministic {@code =} comparison whose two operands
         * come from opposite sides ({@code leftSymbols} and {@code rightSymbols}).
         */
        private boolean joinEqualityExpression(Expression expression, Collection<Symbol> leftSymbols, Collection<Symbol> rightSymbols) {
            Set<ComparisonExpression.Operator> equalityOnly = ImmutableSet.of(ComparisonExpression.Operator.EQUAL);
            return joinComparisonExpression(expression, leftSymbols, rightSymbols, equalityOnly);
        }

        /**
         * Decides whether a join filter conjunct can become a dynamic filter.
         *
         * <p>Two shapes are accepted:
         * <ul>
         *   <li>{@code NOT (a IS DISTINCT FROM b)} across both sides, where neither operand is REAL or
         *       DOUBLE;</li>
         *   <li>a cross-side comparison using one of {@code DYNAMIC_FILTERING_SUPPORTED_COMPARISONS}.</li>
         * </ul>
         * In both cases, one operand must be a plain symbol reference belonging to {@code rightSymbols}.
         */
        private boolean joinDynamicFilteringExpression(Expression expression, Collection<Symbol> leftSymbols, Collection<Symbol> rightSymbols) {
            ComparisonExpression comparison;
            if (!(expression instanceof NotExpression)) {
                if (!joinComparisonExpression(expression, leftSymbols, rightSymbols, PredicatePushDown.DYNAMIC_FILTERING_SUPPORTED_COMPARISONS)) {
                    return false;
                }
                comparison = (ComparisonExpression) expression;
            }
            else {
                Expression negated = ((NotExpression) expression).getValue();
                if (!joinComparisonExpression(negated, leftSymbols, rightSymbols, ImmutableSet.of(ComparisonExpression.Operator.IS_DISTINCT_FROM))) {
                    return false;
                }
                comparison = (ComparisonExpression) negated;
                Set<Type> operandTypes = ImmutableSet.of(
                        this.typeAnalyzer.getType(this.session, this.types, comparison.getLeft()),
                        this.typeAnalyzer.getType(this.session, this.types, comparison.getRight()));
                // Floating-point operands are excluded for the IS NOT DISTINCT FROM form.
                if (operandTypes.contains(RealType.REAL) || operandTypes.contains(DoubleType.DOUBLE)) {
                    return false;
                }
            }
            // The right-side operand must be a bare symbol so it can be keyed to a dynamic filter ID.
            boolean rightOperandIsRightSymbol = comparison.getRight() instanceof SymbolReference && rightSymbols.contains(Symbol.from(comparison.getRight()));
            if (rightOperandIsRightSymbol) {
                return true;
            }
            return comparison.getLeft() instanceof SymbolReference && rightSymbols.contains(Symbol.from(comparison.getLeft()));
        }

        /**
         * Returns true if {@code expression} is a deterministic comparison that links the two join sides.
         *
         * <p>All of the following must hold:
         * <ul>
         *   <li>the operator is in {@code operators};</li>
         *   <li>both operands reference at least one symbol;</li>
         *   <li>one operand's symbols all belong to {@code leftSymbols} and the other's all belong to
         *       {@code rightSymbols}, in either order.</li>
         * </ul>
         */
        private boolean joinComparisonExpression(Expression expression, Collection<Symbol> leftSymbols, Collection<Symbol> rightSymbols, Set<ComparisonExpression.Operator> operators) {
            if (!(expression instanceof ComparisonExpression)) {
                return false;
            }
            ComparisonExpression comparison = (ComparisonExpression) expression;
            boolean supported = DeterminismEvaluator.isDeterministic(expression, this.metadata) && operators.contains(comparison.getOperator());
            if (!supported) {
                return false;
            }
            Set<Symbol> leftOperandSymbols = SymbolsExtractor.extractUnique(comparison.getLeft());
            Set<Symbol> rightOperandSymbols = SymbolsExtractor.extractUnique(comparison.getRight());
            if (leftOperandSymbols.isEmpty() || rightOperandSymbols.isEmpty()) {
                return false;
            }
            if (leftSymbols.containsAll(leftOperandSymbols) && rightSymbols.containsAll(rightOperandSymbols)) {
                return true;
            }
            return rightSymbols.containsAll(leftOperandSymbols) && leftSymbols.containsAll(rightOperandSymbols);
        }

        /**
         * Dispatches on whether the inherited predicate has the semi-join output symbol itself as one
         * of its top-level conjuncts.
         *
         * <ul>
         *   <li>If it does, the node is handled by {@code visitFilteringSemiJoin}.</li>
         *   <li>Otherwise it is handled by {@code visitNonFilteringSemiJoin}.</li>
         * </ul>
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitSemiJoin(SemiJoinNode semiJoinNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            boolean filtersOnSemiJoinOutput = ExpressionUtils.extractConjuncts(rewriteContext.get())
                    .contains(semiJoinNode.getSemiJoinOutput().toSymbolReference());
            if (filtersOnSemiJoinOutput) {
                return visitFilteringSemiJoin(semiJoinNode, rewriteContext);
            }
            return visitNonFilteringSemiJoin(semiJoinNode, rewriteContext);
        }

        /**
         * Pushes the inherited predicate into the semi join's source when the predicate does not filter
         * on the semi-join output.
         *
         * <p>Steps:
         * <ul>
         *   <li>The filtering source is rewritten with {@code TRUE}.</li>
         *   <li>Conjuncts (and inferred equalities) expressible over the source's output symbols are
         *       pushed into the source.</li>
         *   <li>The rest is kept in a {@link FilterNode} above the semi join.</li>
         *   <li>The semi join is rebuilt, with {@code Optional.empty()} as its last argument, only if
         *       either child changed.</li>
         * </ul>
         * No determinism filtering is applied to the inherited conjuncts here.
         */
        private PlanNode visitNonFilteringSemiJoin(SemiJoinNode semiJoinNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression inheritedPredicate = rewriteContext.get();
            List<Expression> sourceConjuncts = new ArrayList<>();
            List<Expression> postJoinConjuncts = new ArrayList<>();

            PlanNode rewrittenFilteringSource = rewriteContext.defaultRewrite(semiJoinNode.getFilteringSource(), BooleanLiteral.TRUE_LITERAL);

            Set<Symbol> sourceScope = ImmutableSet.copyOf(semiJoinNode.getSource().getOutputSymbols());
            EqualityInference inheritedInference = new EqualityInference(this.metadata, inheritedPredicate);
            EqualityInference.nonInferrableConjuncts(this.metadata, inheritedPredicate).forEach(conjunct -> {
                Expression rewritten = inheritedInference.rewrite(conjunct, sourceScope);
                if (rewritten != null) {
                    sourceConjuncts.add(rewritten);
                }
                else {
                    postJoinConjuncts.add(conjunct);
                }
            });

            // Add the inherited equalities back in.
            EqualityInference.EqualityPartition partition = inheritedInference.generateEqualitiesPartitionedBy(sourceScope);
            sourceConjuncts.addAll(partition.getScopeEqualities());
            postJoinConjuncts.addAll(partition.getScopeComplementEqualities());
            postJoinConjuncts.addAll(partition.getScopeStraddlingEqualities());

            PlanNode rewrittenSource = rewriteContext.rewrite(semiJoinNode.getSource(), ExpressionUtils.combineConjuncts(this.metadata, sourceConjuncts));

            PlanNode output = semiJoinNode;
            if (rewrittenSource != semiJoinNode.getSource() || rewrittenFilteringSource != semiJoinNode.getFilteringSource()) {
                output = new SemiJoinNode(semiJoinNode.getId(), rewrittenSource, rewrittenFilteringSource, semiJoinNode.getSourceJoinSymbol(), semiJoinNode.getFilteringSourceJoinSymbol(), semiJoinNode.getSemiJoinOutput(), semiJoinNode.getSourceHashSymbol(), semiJoinNode.getFilteringSourceHashSymbol(), semiJoinNode.getDistributionType(), Optional.empty());
            }
            if (!postJoinConjuncts.isEmpty()) {
                output = new FilterNode(this.idAllocator.getNextId(), output, ExpressionUtils.combineConjuncts(this.metadata, postJoinConjuncts));
            }
            return output;
        }

        /**
         * Pushes the inherited predicate through the "filtering" variant of a semi join and
         * propagates inferred conjuncts between the source and the filtering source.
         *
         * <p>Conjunct routing, in the order it is accumulated:
         * <ul>
         *   <li>Inherited conjuncts (deterministic or not) that can be rewritten over the source
         *       output symbols go to the source; the rest are kept in a FilterNode above the join.</li>
         *   <li>Deterministic inherited conjuncts that can be rewritten over the filtering-source
         *       output symbols are also pushed to the filtering source.</li>
         *   <li>The deterministic effective predicate of each side is rewritten, where possible,
         *       onto the other side.</li>
         *   <li>Scope equalities from the equality inference are added to each side.</li>
         *   <li>If the node has no dynamic filter yet and dynamic filtering is enabled (both the
         *       session property and this rewriter's flag), a new dynamic filter id is allocated and
         *       a dynamic-filter conjunct on the source join symbol is added to the source side.</li>
         * </ul>
         *
         * @param semiJoinNode the semi join being rewritten
         * @param rewriteContext carries the predicate inherited from the parent
         * @return the original node if nothing changed, otherwise a rebuilt SemiJoinNode, optionally
         *         wrapped in a FilterNode holding the conjuncts that could not be pushed down
         */
        private PlanNode visitFilteringSemiJoin(SemiJoinNode semiJoinNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression inheritedPredicate = rewriteContext.get();
            Expression deterministicInheritedPredicate = ExpressionUtils.filterDeterministicConjuncts(this.metadata, inheritedPredicate);
            Expression sourceEffectivePredicate = ExpressionUtils.filterDeterministicConjuncts(this.metadata, this.effectivePredicateExtractor.extract(this.session, semiJoinNode.getSource(), this.types, this.typeAnalyzer));
            Expression filteringSourceEffectivePredicate = ExpressionUtils.filterDeterministicConjuncts(this.metadata, this.effectivePredicateExtractor.extract(this.session, semiJoinNode.getFilteringSource(), this.types, this.typeAnalyzer));
            Expression joinExpression = new ComparisonExpression(ComparisonExpression.Operator.EQUAL, semiJoinNode.getSourceJoinSymbol().toSymbolReference(), semiJoinNode.getFilteringSourceJoinSymbol().toSymbolReference());

            List<Expression> sourceConjuncts = new ArrayList<>();
            List<Expression> filteringSourceConjuncts = new ArrayList<>();
            List<Expression> postJoinConjuncts = new ArrayList<>();

            EqualityInference allInference = new EqualityInference(this.metadata, deterministicInheritedPredicate, sourceEffectivePredicate, filteringSourceEffectivePredicate, joinExpression);
            // Each side's scope equalities are generated from an inference that omits that side's own
            // effective predicate, so equalities the side already guarantees are not re-added to it.
            EqualityInference allInferenceWithoutSourceInferred = new EqualityInference(this.metadata, deterministicInheritedPredicate, filteringSourceEffectivePredicate, joinExpression);
            EqualityInference allInferenceWithoutFilteringSourceInferred = new EqualityInference(this.metadata, deterministicInheritedPredicate, sourceEffectivePredicate, joinExpression);

            ImmutableSet<Symbol> sourceScope = ImmutableSet.copyOf(semiJoinNode.getSource().getOutputSymbols());
            ImmutableSet<Symbol> filteringSourceScope = ImmutableSet.copyOf(semiJoinNode.getFilteringSource().getOutputSymbols());

            // The full inherited predicate (including nondeterministic conjuncts) may move to the source side.
            EqualityInference.nonInferrableConjuncts(this.metadata, inheritedPredicate).forEach(conjunct -> {
                Expression rewritten = allInference.rewrite(conjunct, sourceScope);
                if (rewritten != null) {
                    sourceConjuncts.add(rewritten);
                }
                else {
                    postJoinConjuncts.add(conjunct);
                }
            });

            // Only deterministic inherited conjuncts are considered for the filtering side. Conjuncts that
            // cannot be rewritten here are not lost: the loop above already routed every inherited conjunct.
            EqualityInference.nonInferrableConjuncts(this.metadata, deterministicInheritedPredicate).forEach(conjunct -> {
                Expression rewritten = allInference.rewrite(conjunct, filteringSourceScope);
                if (rewritten != null) {
                    filteringSourceConjuncts.add(rewritten);
                }
            });

            // Move the filtering source's effective predicate onto the source side where expressible.
            EqualityInference.nonInferrableConjuncts(this.metadata, filteringSourceEffectivePredicate)
                    .map(conjunct -> allInference.rewrite(conjunct, sourceScope))
                    .filter(Objects::nonNull)
                    .forEach(sourceConjuncts::add);

            // Move the source's effective predicate onto the filtering-source side where expressible.
            EqualityInference.nonInferrableConjuncts(this.metadata, sourceEffectivePredicate)
                    .map(conjunct -> allInference.rewrite(conjunct, filteringSourceScope))
                    .filter(Objects::nonNull)
                    .forEach(filteringSourceConjuncts::add);

            sourceConjuncts.addAll(allInferenceWithoutSourceInferred.generateEqualitiesPartitionedBy(sourceScope).getScopeEqualities());
            filteringSourceConjuncts.addAll(allInferenceWithoutFilteringSourceInferred.generateEqualitiesPartitionedBy(filteringSourceScope).getScopeEqualities());

            Optional<DynamicFilterId> dynamicFilterId = semiJoinNode.getDynamicFilterId();
            if (dynamicFilterId.isEmpty() && SystemSessionProperties.isEnableDynamicFiltering(this.session) && this.dynamicFiltering) {
                dynamicFilterId = Optional.of(new DynamicFilterId("df_" + this.idAllocator.getNextId().toString()));
                Symbol sourceJoinSymbol = semiJoinNode.getSourceJoinSymbol();
                sourceConjuncts.add(DynamicFilters.createDynamicFilterExpression(this.metadata, dynamicFilterId.get(), this.symbolAllocator.getTypes().get(sourceJoinSymbol), sourceJoinSymbol.toSymbolReference(), ComparisonExpression.Operator.EQUAL));
            }

            PlanNode rewrittenSource = rewriteContext.rewrite(semiJoinNode.getSource(), ExpressionUtils.combineConjuncts(this.metadata, sourceConjuncts));
            PlanNode rewrittenFilteringSource = rewriteContext.rewrite(semiJoinNode.getFilteringSource(), ExpressionUtils.combineConjuncts(this.metadata, filteringSourceConjuncts));

            PlanNode output = semiJoinNode;
            // Reference comparison on purpose: the rewriter returns the same instance when nothing changed.
            if (rewrittenSource != semiJoinNode.getSource() || rewrittenFilteringSource != semiJoinNode.getFilteringSource() || !dynamicFilterId.equals(semiJoinNode.getDynamicFilterId())) {
                output = new SemiJoinNode(semiJoinNode.getId(), rewrittenSource, rewrittenFilteringSource, semiJoinNode.getSourceJoinSymbol(), semiJoinNode.getFilteringSourceJoinSymbol(), semiJoinNode.getSemiJoinOutput(), semiJoinNode.getSourceHashSymbol(), semiJoinNode.getFilteringSourceHashSymbol(), semiJoinNode.getDistributionType(), dynamicFilterId);
            }
            if (!postJoinConjuncts.isEmpty()) {
                output = new FilterNode(this.idAllocator.getNextId(), output, ExpressionUtils.combineConjuncts(this.metadata, postJoinConjuncts));
            }
            return output;
        }

        /**
         * Pushes the inherited predicate below an aggregation when the conjuncts can be expressed
         * over the grouping keys.
         *
         * <p>If the aggregation has an empty grouping set, nothing is pushed and the node is handled
         * by {@code visitPlan}. NOTE(review): presumably because an empty grouping set produces a
         * row even for empty input, so filtering below it would change results. Confirm.
         *
         * <p>The following conjuncts stay above the aggregation in a FilterNode:
         * <ul>
         *   <li>nondeterministic conjuncts;</li>
         *   <li>conjuncts that reference the group-id symbol (when present);</li>
         *   <li>conjuncts that cannot be rewritten over the grouping keys;</li>
         *   <li>scope-complement and scope-straddling equalities.</li>
         * </ul>
         *
         * @param aggregationNode the aggregation being rewritten
         * @param rewriteContext carries the predicate inherited from the parent
         * @return the (possibly rebuilt) aggregation, optionally wrapped in a FilterNode
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitAggregation(AggregationNode aggregationNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            if (aggregationNode.hasEmptyGroupingSet()) {
                return visitPlan((PlanNode) aggregationNode, rewriteContext);
            }
            Expression inheritedPredicate = rewriteContext.get();
            EqualityInference equalityInference = new EqualityInference(this.metadata, inheritedPredicate);
            List<Expression> pushdownConjuncts = new ArrayList<>();
            List<Expression> postAggregationConjuncts = new ArrayList<>();

            // Nondeterministic conjuncts are never pushed below the aggregation.
            ExpressionUtils.extractConjuncts(inheritedPredicate).stream()
                    .filter(conjunct -> !DeterminismEvaluator.isDeterministic(conjunct, this.metadata))
                    .forEach(postAggregationConjuncts::add);

            Expression deterministicInheritedPredicate = ExpressionUtils.filterDeterministicConjuncts(this.metadata, inheritedPredicate);
            ImmutableSet<Symbol> groupingKeys = ImmutableSet.copyOf(aggregationNode.getGroupingKeys());
            Optional<Symbol> groupIdSymbol = aggregationNode.getGroupIdSymbol();
            EqualityInference.nonInferrableConjuncts(this.metadata, deterministicInheritedPredicate).forEach(conjunct -> {
                // Conjuncts on the group-id symbol are kept above the aggregation.
                if (groupIdSymbol.isPresent() && SymbolsExtractor.extractUnique(conjunct).contains(groupIdSymbol.get())) {
                    postAggregationConjuncts.add(conjunct);
                    return;
                }
                Expression rewritten = equalityInference.rewrite(conjunct, groupingKeys);
                if (rewritten != null) {
                    pushdownConjuncts.add(rewritten);
                }
                else {
                    postAggregationConjuncts.add(conjunct);
                }
            });

            EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(groupingKeys);
            pushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
            postAggregationConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
            postAggregationConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());

            PlanNode rewrittenSource = rewriteContext.rewrite(aggregationNode.getSource(), ExpressionUtils.combineConjuncts(this.metadata, pushdownConjuncts));
            PlanNode output = aggregationNode;
            if (rewrittenSource != aggregationNode.getSource()) {
                // The new source may no longer be pre-grouped, so the pre-grouped hint is cleared.
                output = AggregationNode.builderFrom(aggregationNode).setSource(rewrittenSource).setPreGroupedSymbols(ImmutableList.of()).build();
            }
            if (!postAggregationConjuncts.isEmpty()) {
                output = new FilterNode(this.idAllocator.getNextId(), output, ExpressionUtils.combineConjuncts(this.metadata, postAggregationConjuncts));
            }
            return output;
        }

        /**
         * Pushes the inherited predicate below an unnest when the conjuncts can be expressed over
         * the replicate symbols.
         *
         * <p>For RIGHT and FULL unnest join types, nothing is pushed. The whole inherited predicate
         * is placed in a FilterNode directly above the unnest. NOTE(review): this path wraps the
         * node even when the predicate is TRUE and does not recurse into the unnest's source.
         *
         * <p>Otherwise, the following conjuncts stay above the unnest in a FilterNode:
         * <ul>
         *   <li>nondeterministic conjuncts;</li>
         *   <li>conjuncts that cannot be rewritten over the replicate symbols;</li>
         *   <li>scope-complement and scope-straddling equalities.</li>
         * </ul>
         *
         * @param unnestNode the unnest being rewritten
         * @param rewriteContext carries the predicate inherited from the parent
         * @return the (possibly rebuilt) unnest, optionally wrapped in a FilterNode
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitUnnest(UnnestNode unnestNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression inheritedPredicate = rewriteContext.get();
            if (unnestNode.getJoinType() == JoinNode.Type.RIGHT || unnestNode.getJoinType() == JoinNode.Type.FULL) {
                return new FilterNode(this.idAllocator.getNextId(), unnestNode, inheritedPredicate);
            }
            EqualityInference equalityInference = new EqualityInference(this.metadata, inheritedPredicate);
            List<Expression> pushdownConjuncts = new ArrayList<>();
            List<Expression> postUnnestConjuncts = new ArrayList<>();

            // Nondeterministic conjuncts are never pushed below the unnest.
            ExpressionUtils.extractConjuncts(inheritedPredicate).stream()
                    .filter(conjunct -> !DeterminismEvaluator.isDeterministic(conjunct, this.metadata))
                    .forEach(postUnnestConjuncts::add);

            Expression deterministicInheritedPredicate = ExpressionUtils.filterDeterministicConjuncts(this.metadata, inheritedPredicate);
            ImmutableSet<Symbol> replicateSymbols = ImmutableSet.copyOf(unnestNode.getReplicateSymbols());
            EqualityInference.nonInferrableConjuncts(this.metadata, deterministicInheritedPredicate).forEach(conjunct -> {
                Expression rewritten = equalityInference.rewrite(conjunct, replicateSymbols);
                if (rewritten != null) {
                    pushdownConjuncts.add(rewritten);
                }
                else {
                    postUnnestConjuncts.add(conjunct);
                }
            });

            EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(replicateSymbols);
            pushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
            postUnnestConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
            postUnnestConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());

            PlanNode rewrittenSource = rewriteContext.rewrite(unnestNode.getSource(), ExpressionUtils.combineConjuncts(this.metadata, pushdownConjuncts));
            PlanNode output = unnestNode;
            if (rewrittenSource != unnestNode.getSource()) {
                output = new UnnestNode(unnestNode.getId(), rewrittenSource, unnestNode.getReplicateSymbols(), unnestNode.getMappings(), unnestNode.getOrdinalitySymbol(), unnestNode.getJoinType(), unnestNode.getFilter());
            }
            if (!postUnnestConjuncts.isEmpty()) {
                output = new FilterNode(this.idAllocator.getNextId(), output, ExpressionUtils.combineConjuncts(this.metadata, postUnnestConjuncts));
            }
            return output;
        }

        /**
         * Forwards the inherited predicate unchanged to the sample node's child via the
         * default rewrite.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitSample(SampleNode sampleNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression inheritedPredicate = rewriteContext.get();
            return rewriteContext.defaultRewrite(sampleNode, inheritedPredicate);
        }

        /**
         * Stops pushdown at a table scan: the simplified inherited predicate is placed in a
         * FilterNode on top of the scan, unless it simplifies to TRUE, in which case the scan is
         * returned as-is.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitTableScan(TableScanNode tableScanNode, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression remainingPredicate = simplifyExpression(rewriteContext.get());
            if (BooleanLiteral.TRUE_LITERAL.equals(remainingPredicate)) {
                return tableScanNode;
            }
            return new FilterNode(this.idAllocator.getNextId(), tableScanNode, remainingPredicate);
        }

        /**
         * Forwards the inherited predicate through an AssignUniqueId node.
         *
         * @throws IllegalStateException if the predicate references the generated id column,
         *         which this rule does not support
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitAssignUniqueId(AssignUniqueId assignUniqueId, SimplePlanRewriter.RewriteContext<Expression> rewriteContext) {
            Expression inheritedPredicate = rewriteContext.get();
            boolean referencesIdColumn = SymbolsExtractor.extractUnique(inheritedPredicate).contains(assignUniqueId.getIdColumn());
            Preconditions.checkState(!referencesIdColumn, "UniqueId in predicate is not yet supported");
            return rewriteContext.defaultRewrite(assignUniqueId, inheritedPredicate);
        }
    }

    /**
     * Creates the predicate push-down optimizer.
     *
     * @param plannerContext planner services; must not be null
     * @param typeAnalyzer expression type analyzer; must not be null
     * @param useTableProperties flag stored on this optimizer and handed to every Rewriter it creates
     * @param dynamicFiltering flag stored on this optimizer and handed to every Rewriter it creates;
     *        the rewriter only adds dynamic filters when this flag and the session property are both on
     * @throws NullPointerException if {@code plannerContext} or {@code typeAnalyzer} is null
     */
    public PredicatePushDown(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, boolean useTableProperties, boolean dynamicFiltering) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
        this.useTableProperties = useTableProperties;
        this.dynamicFiltering = dynamicFiltering;
    }

    /**
     * Rewrites {@code plan}, starting from an inherited predicate of TRUE at the root, using a
     * fresh {@code Rewriter} configured with this optimizer's settings.
     *
     * @throws NullPointerException if {@code plan}, {@code session}, {@code types} or
     *         {@code idAllocator} is null (checked in that order)
     */
    @Override // io.trino.sql.planner.optimizations.PlanOptimizer
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, PlanOptimizersStatsCollector planOptimizersStatsCollector, TableStatsProvider tableStatsProvider) {
        Objects.requireNonNull(plan, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(types, "types is null");
        Objects.requireNonNull(idAllocator, "idAllocator is null");

        Rewriter rewriter = new Rewriter(symbolAllocator, idAllocator, this.plannerContext, this.typeAnalyzer, session, types, this.useTableProperties, this.dynamicFiltering);
        return SimplePlanRewriter.rewriteWith(rewriter, plan, BooleanLiteral.TRUE_LITERAL);
    }
}
