package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.connector.CatalogServiceProvider;
import io.trino.cost.TableStatsProvider;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.AnalyzePropertyManager;
import io.trino.metadata.OperatorNotFoundException;
import io.trino.metadata.SessionPropertyManager;
import io.trino.metadata.TableFunctionRegistry;
import io.trino.metadata.TableProceduresPropertyManager;
import io.trino.metadata.TableProceduresRegistry;
import io.trino.metadata.TablePropertyManager;
import io.trino.operator.join.JoinUtils;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.sql.DynamicFilters;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.StatementAnalyzerFactory;
import io.trino.sql.parser.SqlParser;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.plan.ChildReplacer;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SpatialJoinNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.ExpressionRewriter;
import io.trino.sql.tree.ExpressionTreeRewriter;
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.SymbolReference;
import io.trino.transaction.NoOpTransactionManager;
import io.trino.type.TypeCoercion;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Plan optimizer that removes dynamic filter predicates which cannot be used.
 *
 * <p>The plan is walked top-down with the set of dynamic filter ids that are "allowed" at the current
 * position. A {@link JoinNode} adds its own dynamic filter ids to the allowed set of its left (probe)
 * subtree, and a {@link SemiJoinNode} adds its id to the allowed set of its source subtree. Dynamic filter
 * conjuncts are kept only in a {@link FilterNode} whose (rewritten) source is a {@link TableScanNode}. They
 * must also be allowed at that position and have a supported input expression (see
 * {@code isSupportedDynamicFilterExpression}). Every other dynamic filter predicate is removed. Joins and
 * semi-joins then drop the dynamic filters for which no consumer was kept on their probe/source side.
 */
public class RemoveUnsupportedDynamicFilters implements PlanOptimizer {
    private final PlannerContext plannerContext;
    private final TypeAnalyzer typeAnalyzer;

    /**
     * Result of rewriting a subtree: the rewritten node plus the ids of the dynamic filters that are
     * still consumed (kept in a filter over a table scan) somewhere inside that subtree.
     */
    public static class PlanWithConsumedDynamicFilters {
        private final PlanNode node;
        private final Set<DynamicFilterId> consumedDynamicFilterIds;

        PlanWithConsumedDynamicFilters(PlanNode node, Set<DynamicFilterId> consumedDynamicFilterIds) {
            this.node = Objects.requireNonNull(node, "node is null");
            this.consumedDynamicFilterIds = ImmutableSet.copyOf(consumedDynamicFilterIds);
        }

        PlanNode getNode() {
            return node;
        }

        Set<DynamicFilterId> getConsumedDynamicFilterIds() {
            return consumedDynamicFilterIds;
        }
    }

    /**
     * Visitor that rewrites the plan. The visitor context is the set of dynamic filter ids allowed to be
     * consumed within the visited subtree.
     */
    public class Rewriter extends PlanVisitor<PlanWithConsumedDynamicFilters, Set<DynamicFilterId>> {
        private final Session session;
        private final TypeProvider types;
        private final TypeCoercion typeCoercion;

        public Rewriter(Session session, TypeProvider typeProvider) {
            this.session = Objects.requireNonNull(session, "session is null");
            this.types = Objects.requireNonNull(typeProvider, "types is null");
            TypeManager typeManager = plannerContext.getTypeManager();
            this.typeCoercion = new TypeCoercion(typeManager::getType);
        }

        /**
         * Default handling: rewrite every source with the same allowed set, replace the children and
         * report the union of the ids consumed by the children.
         */
        @Override
        public PlanWithConsumedDynamicFilters visitPlan(PlanNode node, Set<DynamicFilterId> allowedDynamicFilterIds) {
            List<PlanWithConsumedDynamicFilters> children = node.getSources().stream()
                    .map(source -> source.accept(this, allowedDynamicFilterIds))
                    .collect(ImmutableList.toImmutableList());

            List<PlanNode> rewrittenSources = children.stream()
                    .map(PlanWithConsumedDynamicFilters::getNode)
                    .collect(Collectors.toList());
            Set<DynamicFilterId> consumed = children.stream()
                    .flatMap(child -> child.getConsumedDynamicFilterIds().stream())
                    .collect(ImmutableSet.toImmutableSet());

            return new PlanWithConsumedDynamicFilters(ChildReplacer.replaceChildren(node, rewrittenSources), consumed);
        }

        /**
         * Rewrites a join. The left subtree may additionally consume this join's own dynamic filters.
         * Join dynamic filters without a consumer on the left are dropped, and dynamic filters are
         * always stripped from the join's own filter expression.
         */
        @Override
        public PlanWithConsumedDynamicFilters visitJoin(JoinNode node, Set<DynamicFilterId> allowedDynamicFilterIds) {
            Map<DynamicFilterId, Symbol> currentJoinDynamicFilters = JoinUtils.getJoinDynamicFilters(node);
            Set<DynamicFilterId> allowedOnProbeSide = ImmutableSet.<DynamicFilterId>builder()
                    .addAll(currentJoinDynamicFilters.keySet())
                    .addAll(allowedDynamicFilterIds)
                    .build();

            PlanWithConsumedDynamicFilters leftResult = node.getLeft().accept(this, allowedOnProbeSide);
            Set<DynamicFilterId> consumedProbeSide = leftResult.getConsumedDynamicFilterIds();

            // Keep only this join's dynamic filters that still have a consumer in the left subtree.
            Map<DynamicFilterId, Symbol> dynamicFilters = currentJoinDynamicFilters.entrySet().stream()
                    .filter(entry -> consumedProbeSide.contains(entry.getKey()))
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));

            // The right subtree only receives the ids allowed from above, not this join's ids.
            PlanWithConsumedDynamicFilters rightResult = node.getRight().accept(this, allowedDynamicFilterIds);

            // Ids produced by this join are satisfied here and are not reported further up.
            Set<DynamicFilterId> consumed = new HashSet<>(rightResult.getConsumedDynamicFilterIds());
            consumed.addAll(consumedProbeSide);
            consumed.removeAll(dynamicFilters.keySet());

            // Dynamic filters are always removed from the join filter; a filter reduced to TRUE is dropped.
            Optional<Expression> filter = node.getFilter()
                    .map(this::removeAllDynamicFilters)
                    .filter(expression -> !expression.equals(BooleanLiteral.TRUE_LITERAL));

            PlanNode left = leftResult.getNode();
            PlanNode right = rightResult.getNode();
            if (left.equals(node.getLeft())
                    && right.equals(node.getRight())
                    && dynamicFilters.equals(currentJoinDynamicFilters)
                    && filter.equals(node.getFilter())) {
                return new PlanWithConsumedDynamicFilters(node, consumed);
            }

            return new PlanWithConsumedDynamicFilters(
                    new JoinNode(
                            node.getId(),
                            node.getType(),
                            left,
                            right,
                            node.getCriteria(),
                            node.getLeftOutputSymbols(),
                            node.getRightOutputSymbols(),
                            node.isMaySkipOutputDuplicates(),
                            filter,
                            node.getLeftHashSymbol(),
                            node.getRightHashSymbol(),
                            node.getDistributionType(),
                            node.isSpillable(),
                            // A join that had no dynamic filters keeps an empty map.
                            node.getDynamicFilters().isEmpty() ? ImmutableMap.of() : dynamicFilters,
                            node.getReorderJoinStatsAndCost()),
                    consumed);
        }

        /**
         * Rewrites a spatial join. Both sides see the same allowed set, and dynamic filters are always
         * stripped from the spatial join filter.
         */
        @Override
        public PlanWithConsumedDynamicFilters visitSpatialJoin(SpatialJoinNode node, Set<DynamicFilterId> allowedDynamicFilterIds) {
            PlanWithConsumedDynamicFilters leftResult = node.getLeft().accept(this, allowedDynamicFilterIds);
            PlanWithConsumedDynamicFilters rightResult = node.getRight().accept(this, allowedDynamicFilterIds);

            Set<DynamicFilterId> consumed = ImmutableSet.<DynamicFilterId>builder()
                    .addAll(leftResult.getConsumedDynamicFilterIds())
                    .addAll(rightResult.getConsumedDynamicFilterIds())
                    .build();

            Expression filter = removeAllDynamicFilters(node.getFilter());

            if (node.getFilter().equals(filter)
                    && leftResult.getNode() == node.getLeft()
                    && rightResult.getNode() == node.getRight()) {
                return new PlanWithConsumedDynamicFilters(node, consumed);
            }

            return new PlanWithConsumedDynamicFilters(
                    new SpatialJoinNode(
                            node.getId(),
                            node.getType(),
                            leftResult.getNode(),
                            rightResult.getNode(),
                            node.getOutputSymbols(),
                            filter,
                            node.getLeftPartitionSymbol(),
                            node.getRightPartitionSymbol(),
                            node.getKdbTree()),
                    consumed);
        }

        /**
         * Rewrites a semi-join. If it produces a dynamic filter, the source subtree may additionally
         * consume that id. The id is dropped from the node when no consumer was kept.
         */
        @Override
        public PlanWithConsumedDynamicFilters visitSemiJoin(SemiJoinNode node, Set<DynamicFilterId> allowedDynamicFilterIds) {
            Optional<DynamicFilterId> originalDynamicFilterId = JoinUtils.getSemiJoinDynamicFilterId(node);
            if (originalDynamicFilterId.isEmpty()) {
                return visitPlan(node, allowedDynamicFilterIds);
            }
            DynamicFilterId dynamicFilterId = originalDynamicFilterId.get();

            Set<DynamicFilterId> allowedOnSourceSide = ImmutableSet.<DynamicFilterId>builder()
                    .add(dynamicFilterId)
                    .addAll(allowedDynamicFilterIds)
                    .build();
            PlanWithConsumedDynamicFilters sourceResult = node.getSource().accept(this, allowedOnSourceSide);
            PlanWithConsumedDynamicFilters filteringSourceResult = node.getFilteringSource().accept(this, allowedDynamicFilterIds);

            Set<DynamicFilterId> consumed = new HashSet<>(filteringSourceResult.getConsumedDynamicFilterIds());
            consumed.addAll(sourceResult.getConsumedDynamicFilterIds());

            // The semi-join's own id is satisfied here: keep it only if something consumed it,
            // and never report it further up.
            Optional<DynamicFilterId> newDynamicFilterId;
            if (consumed.remove(dynamicFilterId)) {
                newDynamicFilterId = Optional.of(dynamicFilterId);
            } else {
                newDynamicFilterId = Optional.empty();
            }

            PlanNode source = sourceResult.getNode();
            PlanNode filteringSource = filteringSourceResult.getNode();
            if (source.equals(node.getSource())
                    && filteringSource.equals(node.getFilteringSource())
                    && newDynamicFilterId.equals(originalDynamicFilterId)) {
                return new PlanWithConsumedDynamicFilters(node, consumed);
            }

            return new PlanWithConsumedDynamicFilters(
                    new SemiJoinNode(
                            node.getId(),
                            source,
                            filteringSource,
                            node.getSourceJoinSymbol(),
                            node.getFilteringSourceJoinSymbol(),
                            node.getSemiJoinOutput(),
                            node.getSourceHashSymbol(),
                            node.getFilteringSourceHashSymbol(),
                            node.getDistributionType(),
                            node.getDynamicFilterId().isEmpty() ? Optional.empty() : newDynamicFilterId),
                    consumed);
        }

        /**
         * Rewrites a filter. Allowed, supported dynamic filter conjuncts are kept only when the rewritten
         * source is a table scan; otherwise all dynamic filters are removed. A predicate reduced to TRUE
         * removes the filter node entirely.
         */
        @Override
        public PlanWithConsumedDynamicFilters visitFilter(FilterNode node, Set<DynamicFilterId> allowedDynamicFilterIds) {
            PlanWithConsumedDynamicFilters result = node.getSource().accept(this, allowedDynamicFilterIds);

            Expression original = node.getPredicate();
            ImmutableSet.Builder<DynamicFilterId> consumedDynamicFilterIds = ImmutableSet.<DynamicFilterId>builder()
                    .addAll(result.getConsumedDynamicFilterIds());

            PlanNode source = result.getNode();
            Expression modified;
            if (source instanceof TableScanNode) {
                modified = removeDynamicFilters(original, allowedDynamicFilterIds, consumedDynamicFilterIds);
            } else {
                modified = removeAllDynamicFilters(original);
            }

            if (BooleanLiteral.TRUE_LITERAL.equals(modified)) {
                return new PlanWithConsumedDynamicFilters(source, consumedDynamicFilterIds.build());
            }
            if (original.equals(modified) && source == node.getSource()) {
                return new PlanWithConsumedDynamicFilters(node, consumedDynamicFilterIds.build());
            }
            return new PlanWithConsumedDynamicFilters(new FilterNode(node.getId(), source, modified), consumedDynamicFilterIds.build());
        }

        /**
         * Removes nested dynamic filters from every conjunct. Top-level dynamic filter conjuncts are kept
         * only if their id is allowed and their input is supported; the ids of kept filters are added to
         * {@code consumedDynamicFilterIds}.
         */
        private Expression removeDynamicFilters(
                Expression expression,
                Set<DynamicFilterId> allowedDynamicFilterIds,
                ImmutableSet.Builder<DynamicFilterId> consumedDynamicFilterIds) {
            ImmutableList.Builder<Expression> retained = ImmutableList.builder();
            for (Expression conjunct : ExpressionUtils.extractConjuncts(expression)) {
                Expression rewritten = removeNestedDynamicFilters(conjunct);
                Optional<DynamicFilters.Descriptor> descriptor = DynamicFilters.getDescriptor(rewritten);
                if (descriptor.isEmpty()) {
                    // Static conjunct: always kept.
                    retained.add(rewritten);
                    continue;
                }
                DynamicFilterId id = descriptor.get().getId();
                if (allowedDynamicFilterIds.contains(id) && isSupportedDynamicFilterExpression(descriptor.get().getInput())) {
                    consumedDynamicFilterIds.add(id);
                    retained.add(rewritten);
                }
            }
            Collection<Expression> conjuncts = retained.build();
            return ExpressionUtils.combineConjuncts(plannerContext.getMetadata(), conjuncts);
        }

        /**
         * Returns true for a bare symbol reference, or for a CAST of a symbol reference where the
         * symbol's type can be coerced to the cast target type and a SATURATED_FLOOR_CAST from the
         * target type back to the symbol's type can be resolved.
         */
        private boolean isSupportedDynamicFilterExpression(Expression expression) {
            if (expression instanceof SymbolReference) {
                return true;
            }
            if (!(expression instanceof Cast)) {
                return false;
            }
            Cast cast = (Cast) expression;
            if (!(cast.getExpression() instanceof SymbolReference)) {
                return false;
            }
            Map<NodeRef<Expression>, Type> expressionTypes = typeAnalyzer.getTypes(session, types, expression);
            Type castSourceType = expressionTypes.get(NodeRef.of(cast.getExpression()));
            Type castTargetType = expressionTypes.get(NodeRef.of(cast));
            if (!typeCoercion.canCoerce(castSourceType, castTargetType)) {
                return false;
            }
            return doesSaturatedFloorCastOperatorExist(castTargetType, castSourceType);
        }

        /**
         * Returns whether metadata resolves a SATURATED_FLOOR_CAST coercion from {@code fromType}
         * to {@code toType}.
         */
        private boolean doesSaturatedFloorCastOperatorExist(Type fromType, Type toType) {
            try {
                plannerContext.getMetadata().getCoercion(session, OperatorType.SATURATED_FLOOR_CAST, fromType, toType);
                return true;
            } catch (OperatorNotFoundException ignored) {
                // Absence of the operator is an expected outcome, not an error.
                return false;
            }
        }

        /**
         * Removes every dynamic filter from {@code expression}: nested ones are replaced by TRUE, and
         * top-level dynamic conjuncts are dropped in favor of the static conjuncts.
         */
        private Expression removeAllDynamicFilters(Expression expression) {
            Expression withoutNested = removeNestedDynamicFilters(expression);
            DynamicFilters.ExtractResult extracted = DynamicFilters.extractDynamicFilters(withoutNested);
            if (extracted.getDynamicConjuncts().isEmpty()) {
                return withoutNested;
            }
            return ExpressionUtils.combineConjuncts(plannerContext.getMetadata(), extracted.getStaticConjuncts());
        }

        /**
         * Replaces each dynamic filter that is a term of a logical expression (AND/OR) with TRUE,
         * recursively. The input is returned as-is when nothing changed.
         */
        private Expression removeNestedDynamicFilters(Expression expression) {
            return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() {
                @Override
                public Expression rewriteLogicalExpression(LogicalExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                    LogicalExpression rewrittenNode = treeRewriter.defaultRewrite(node, context);

                    boolean modified = node != rewrittenNode;
                    ImmutableList.Builder<Expression> terms = ImmutableList.builder();
                    for (Expression term : rewrittenNode.getTerms()) {
                        if (DynamicFilters.isDynamicFilter(term)) {
                            terms.add(BooleanLiteral.TRUE_LITERAL);
                            modified = true;
                        } else {
                            terms.add(term);
                        }
                    }

                    if (!modified) {
                        return node;
                    }
                    Collection<Expression> combined = terms.build();
                    return ExpressionUtils.combinePredicates(plannerContext.getMetadata(), node.getOperator(), combined);
                }
            }, expression);
        }
    }

    public RemoveUnsupportedDynamicFilters(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        // NOTE(review): builds a "testing analyzer" (messages below) only to type CAST expressions in
        // dynamic filters; catalog-dependent features fail fast if they are ever reached.
        this.typeAnalyzer = new TypeAnalyzer(plannerContext, new StatementAnalyzerFactory(
                plannerContext,
                new SqlParser(),
                new AllowAllAccessControl(),
                new NoOpTransactionManager(),
                user -> ImmutableSet.of(),
                new TableProceduresRegistry(CatalogServiceProvider.fail("procedures are not supported in testing analyzer")),
                new TableFunctionRegistry(CatalogServiceProvider.fail("table functions are not supported in testing analyzer")),
                new SessionPropertyManager(),
                new TablePropertyManager(CatalogServiceProvider.fail("table properties not supported in testing analyzer")),
                new AnalyzePropertyManager(CatalogServiceProvider.fail("analyze properties not supported in testing analyzer")),
                new TableProceduresPropertyManager(CatalogServiceProvider.fail("procedures are not supported in testing analyzer"))));
    }

    /**
     * Rewrites {@code planNode}, starting with no allowed dynamic filter ids at the root.
     */
    @Override
    public PlanNode optimize(PlanNode planNode, Session session, TypeProvider typeProvider, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) {
        Set<DynamicFilterId> noneAllowed = ImmutableSet.of();
        PlanWithConsumedDynamicFilters result = planNode.accept(new Rewriter(session, typeProvider), noneAllowed);
        return result.getNode();
    }
}
