package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.planner.ExpressionSymbolInliner;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Expression;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Predicate;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TransformFilteringSemiJoinToInnerJoin.class */
/**
 * Rewrites a semi-join that is used purely as a filter into an inner join against the distinct join keys
 * of the filtering source.
 *
 * <p>The rule matches a {@link FilterNode} whose direct source is a {@link SemiJoinNode} and whose predicate
 * contains a reference to the semi-join output symbol as a top-level conjunct. The plan is rewritten to:
 * <pre>
 * Project (identities of the join outputs, semiJoinOutput := TRUE)
 *   Join INNER (sourceJoinSymbol = filteringSourceJoinSymbol, filter = remaining conjuncts)
 *     source
 *     Aggregation (single grouping set on filteringSourceJoinSymbol, no aggregates)
 *       filteringSource
 * </pre>
 * Grouping the filtering source on its join symbol means each source row can match at most one row on
 * the other side of the equi-join. The other conjuncts of the original filter become the join filter,
 * with any reference to the semi-join output replaced by {@code TRUE}.
 */
public class TransformFilteringSemiJoinToInnerJoin implements Rule<FilterNode> {
    private static final Capture<SemiJoinNode> SEMI_JOIN = Capture.newCapture();
    private static final Pattern<FilterNode> PATTERN = Patterns.filter()
            .with(Patterns.source().matching(Patterns.semiJoin().capturedAs(SEMI_JOIN)));

    /** Matches a filter placed directly on top of a semi-join, capturing the semi-join. */
    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /** The rule only runs when {@link SystemSessionProperties#isRewriteFilteringSemiJoinToInnerJoin} is true for the session. */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isRewriteFilteringSemiJoinToInnerJoin(session);
    }

    /**
     * Replaces the matched filter + semi-join with a project over an inner join.
     *
     * @param filterNode the filter whose predicate is inspected and consumed by the rewrite
     * @param captures holds the captured {@link SemiJoinNode} beneath {@code filterNode}
     * @param context supplies the plan lookup and the plan-node id allocator
     * @return the rewritten plan, or {@link Rule.Result#empty()} when the rule does not apply
     */
    @Override
    public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
        SemiJoinNode semiJoinNode = captures.get(SEMI_JOIN);

        // Bail out when the probe side scans a table flagged as an update target.
        // NOTE(review): presumably row-level DELETE/UPDATE planning relies on keeping the semi-join shape — confirm.
        if (sourceContainsUpdateTarget(semiJoinNode, context)) {
            return Rule.Result.empty();
        }

        Symbol semiJoinOutput = semiJoinNode.getSemiJoinOutput();
        Predicate<Expression> isSemiJoinOutput = expression -> expression.equals(semiJoinOutput.toSymbolReference());

        // The rewrite is only valid when the filter demands that the semi-join output itself be true.
        List<Expression> conjuncts = ExpressionUtils.extractConjuncts(filterNode.getPredicate());
        if (conjuncts.stream().noneMatch(isSemiJoinOutput)) {
            return Rule.Result.empty();
        }

        Optional<Expression> joinFilter = remainingJoinFilter(conjuncts, semiJoinOutput, isSemiJoinOutput);
        // Allocates a plan-node id; must stay before the project-node id allocation below to keep id order.
        AggregationNode distinctFilteringSource = distinctFilteringSource(semiJoinNode, context);

        JoinNode joinNode = new JoinNode(
                semiJoinNode.getId(),
                JoinNode.Type.INNER,
                semiJoinNode.getSource(),
                distinctFilteringSource,
                ImmutableList.of(new JoinNode.EquiJoinClause(semiJoinNode.getSourceJoinSymbol(), semiJoinNode.getFilteringSourceJoinSymbol())),
                semiJoinNode.getSource().getOutputSymbols(),
                ImmutableList.of(),
                false,
                joinFilter,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                // Carry over the semi-join's dynamic filter (if any), keyed to the filtering-side join symbol.
                semiJoinNode.getDynamicFilterId()
                        .map(dynamicFilterId -> ImmutableMap.of(dynamicFilterId, semiJoinNode.getFilteringSourceJoinSymbol()))
                        .orElse(ImmutableMap.of()),
                Optional.empty());

        // The join only outputs the source symbols; re-introduce the semi-join output as a constant TRUE
        // so that parents referencing it keep working.
        return Rule.Result.ofPlanNode(new ProjectNode(
                context.getIdAllocator().getNextId(),
                joinNode,
                Assignments.builder()
                        .putIdentities(joinNode.getOutputSymbols())
                        .put(semiJoinOutput, BooleanLiteral.TRUE_LITERAL)
                        .build()));
    }

    /** Returns whether any {@link TableScanNode} reachable from the semi-join's source is marked as an update target. */
    private static boolean sourceContainsUpdateTarget(SemiJoinNode semiJoinNode, Rule.Context context) {
        return PlanNodeSearcher.searchFrom(semiJoinNode.getSource(), context.getLookup())
                .where(planNode -> planNode instanceof TableScanNode && ((TableScanNode) planNode).isUpdateTarget())
                .matches();
    }

    /**
     * Builds the join filter from every conjunct that is not a bare reference to the semi-join output.
     * Remaining references to the semi-join output (nested inside other conjuncts) are inlined as {@code TRUE}.
     *
     * @return empty when the resulting expression is exactly the literal {@code TRUE}, otherwise that expression
     */
    private static Optional<Expression> remainingJoinFilter(List<Expression> conjuncts, Symbol semiJoinOutput, Predicate<Expression> isSemiJoinOutput) {
        List<Expression> remainingConjuncts = conjuncts.stream()
                .filter(Predicate.not(isSemiJoinOutput))
                .collect(ImmutableList.toImmutableList());
        Function<Symbol, Expression> replaceSemiJoinOutput =
                symbol -> symbol.equals(semiJoinOutput) ? BooleanLiteral.TRUE_LITERAL : symbol.toSymbolReference();
        Expression inlined = ExpressionSymbolInliner.inlineSymbols(replaceSemiJoinOutput, ExpressionUtils.and(remainingConjuncts));
        return inlined.equals(BooleanLiteral.TRUE_LITERAL) ? Optional.empty() : Optional.of(inlined);
    }

    /**
     * Wraps the filtering source in a single-step aggregation grouped only by its join symbol (no aggregate
     * functions), producing the distinct join keys of the filtering side.
     */
    private static AggregationNode distinctFilteringSource(SemiJoinNode semiJoinNode, Rule.Context context) {
        return new AggregationNode(
                context.getIdAllocator().getNextId(),
                semiJoinNode.getFilteringSource(),
                ImmutableMap.of(),
                AggregationNode.singleGroupingSet(ImmutableList.of(semiJoinNode.getFilteringSourceJoinSymbol())),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());
    }
}
