package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.ResolvedFunction;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.OrderBy;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.SortItem;
import io.trino.sql.tree.SymbolReference;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.class */
public class ExpressionRewriteRuleSet {
    private final ExpressionRewriter rewriter;

    /**
     * Rule that runs the configured {@link ExpressionRewriter} over every aggregation of an
     * {@link AggregationNode}.
     *
     * <p>Each aggregation is presented to the rewriter as a {@link FunctionCall} built from its
     * function name, filter, ordering, distinct flag and arguments. The rewritten call is then
     * converted back into an {@link AggregationNode.Aggregation}, keeping the original resolved
     * function and mask.
     */
    public static final class AggregationExpressionRewrite implements Rule<AggregationNode> {
        private final ExpressionRewriter rewriter;

        AggregationExpressionRewrite(ExpressionRewriter expressionRewriter) {
            this.rewriter = expressionRewriter;
        }

        @Override
        public Pattern<AggregationNode> getPattern() {
            return Patterns.aggregation();
        }

        /**
         * Rewrites all aggregations of the node.
         *
         * @return a result holding a new {@link AggregationNode} when at least one aggregation
         *         differs from its original, otherwise {@link Rule.Result#empty()}
         * @throws ClassCastException if the rewriter returns something other than a {@link FunctionCall}
         */
        @Override
        public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
            boolean anyRewritten = false;
            ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> rewrittenAggregations = ImmutableMap.builder();
            for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
                AggregationNode.Aggregation aggregation = entry.getValue();
                // The rewriter is declared to return Expression; an aggregation must stay a function call.
                FunctionCall call = (FunctionCall) this.rewriter.rewrite(toFunctionCall(aggregation), context);
                // The resolved function is carried over unchanged below, so the rewriter must not rename the call.
                Verify.verify(
                        QualifiedName.of(ResolvedFunction.extractFunctionName(call.getName()))
                                .equals(QualifiedName.of(aggregation.getResolvedFunction().getSignature().getName())),
                        "Aggregation function name changed");
                AggregationNode.Aggregation rewritten = new AggregationNode.Aggregation(
                        aggregation.getResolvedFunction(),
                        call.getArguments(),
                        call.isDistinct(),
                        call.getFilter().map(Symbol::from),
                        call.getOrderBy().map(OrderingScheme::fromOrderBy),
                        aggregation.getMask());
                rewrittenAggregations.put(entry.getKey(), rewritten);
                if (!aggregation.equals(rewritten)) {
                    anyRewritten = true;
                }
            }
            if (!anyRewritten) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new AggregationNode(
                    aggregationNode.getId(),
                    aggregationNode.getSource(),
                    rewrittenAggregations.build(),
                    aggregationNode.getGroupingSets(),
                    aggregationNode.getPreGroupedSymbols(),
                    aggregationNode.getStep(),
                    aggregationNode.getHashSymbol(),
                    aggregationNode.getGroupIdSymbol()));
        }

        /** Builds the {@link FunctionCall} form of an aggregation that is handed to the rewriter. */
        private static FunctionCall toFunctionCall(AggregationNode.Aggregation aggregation) {
            Optional<Expression> filter = aggregation.getFilter()
                    .map(symbol -> new SymbolReference(symbol.getName()));
            Optional<OrderBy> orderBy = aggregation.getOrderingScheme()
                    .map(AggregationExpressionRewrite::toOrderBy);
            return new FunctionCall(
                    Optional.empty(),
                    QualifiedName.of(aggregation.getResolvedFunction().getSignature().getName()),
                    Optional.empty(),
                    filter,
                    orderBy,
                    aggregation.isDistinct(),
                    Optional.empty(),
                    Optional.empty(),
                    aggregation.getArguments());
        }

        /** Converts an {@link OrderingScheme} into an {@link OrderBy} with one sort item per ordering symbol. */
        private static OrderBy toOrderBy(OrderingScheme orderingScheme) {
            List<SortItem> sortItems = orderingScheme.getOrderBy().stream()
                    .map(symbol -> new SortItem(
                            new SymbolReference(symbol.getName()),
                            orderingScheme.getOrdering(symbol).isAscending() ? SortItem.Ordering.ASCENDING : SortItem.Ordering.DESCENDING,
                            orderingScheme.getOrdering(symbol).isNullsFirst() ? SortItem.NullOrdering.FIRST : SortItem.NullOrdering.LAST))
                    .collect(ImmutableList.toImmutableList());
            return new OrderBy(sortItems);
        }
    }

    /**
     * Strategy applied by the rules of this set to each expression they visit.
     *
     * <p>Declared as a functional interface so implementations can be supplied as lambdas or
     * method references.
     */
    @FunctionalInterface
    public interface ExpressionRewriter {
        /**
         * Rewrites a single expression.
         *
         * @param expression the expression to rewrite
         * @param context the context of the rule invocation that triggered the rewrite
         * @return the rewritten expression; the rules in this file compare it to the input with
         *         {@code equals} to decide whether anything changed
         */
        Expression rewrite(Expression expression, Rule.Context context);
    }

    /**
     * Rule that runs the configured {@link ExpressionRewriter} over the predicate of a
     * {@link FilterNode}.
     */
    public static final class FilterExpressionRewrite implements Rule<FilterNode> {
        private final ExpressionRewriter rewriter;

        FilterExpressionRewrite(ExpressionRewriter expressionRewriter) {
            this.rewriter = expressionRewriter;
        }

        @Override
        public Pattern<FilterNode> getPattern() {
            return Patterns.filter();
        }

        /**
         * Rewrites the filter predicate.
         *
         * @return {@link Rule.Result#empty()} when the rewritten predicate equals the original,
         *         otherwise a result holding a new {@link FilterNode} with the rewritten predicate
         */
        @Override
        public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
            Expression rewrittenPredicate = this.rewriter.rewrite(filterNode.getPredicate(), context);
            if (filterNode.getPredicate().equals(rewrittenPredicate)) {
                // Nothing changed; leave the plan untouched.
                return Rule.Result.empty();
            }
            FilterNode replacement = new FilterNode(filterNode.getId(), filterNode.getSource(), rewrittenPredicate);
            return Rule.Result.ofPlanNode(replacement);
        }
    }

    /**
     * Rule that runs the configured {@link ExpressionRewriter} over the optional filter of a
     * {@link JoinNode}. All other properties of the join are carried over unchanged.
     */
    public static final class JoinExpressionRewrite implements Rule<JoinNode> {
        private final ExpressionRewriter rewriter;

        JoinExpressionRewrite(ExpressionRewriter expressionRewriter) {
            this.rewriter = expressionRewriter;
        }

        @Override
        public Pattern<JoinNode> getPattern() {
            return Patterns.join();
        }

        /**
         * Rewrites the join filter, if one is present.
         *
         * @return {@link Rule.Result#empty()} when the filter is absent or the rewritten filter
         *         equals the original, otherwise a result holding a new {@link JoinNode}
         */
        @Override
        public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
            Optional<Expression> rewrittenFilter = joinNode.getFilter()
                    .map(expression -> this.rewriter.rewrite(expression, context));
            // An absent filter maps to an empty Optional, which equals the original empty Optional.
            if (joinNode.getFilter().equals(rewrittenFilter)) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new JoinNode(
                    joinNode.getId(),
                    joinNode.getType(),
                    joinNode.getLeft(),
                    joinNode.getRight(),
                    joinNode.getCriteria(),
                    joinNode.getLeftOutputSymbols(),
                    joinNode.getRightOutputSymbols(),
                    joinNode.isMaySkipOutputDuplicates(),
                    rewrittenFilter,
                    joinNode.getLeftHashSymbol(),
                    joinNode.getRightHashSymbol(),
                    joinNode.getDistributionType(),
                    joinNode.isSpillable(),
                    joinNode.getDynamicFilters(),
                    joinNode.getReorderJoinStatsAndCost()));
        }
    }

    /**
     * Rule that runs the configured {@link ExpressionRewriter} over the assignments of a
     * {@link ProjectNode}.
     */
    public static final class ProjectExpressionRewrite implements Rule<ProjectNode> {
        private final ExpressionRewriter rewriter;

        ProjectExpressionRewrite(ExpressionRewriter expressionRewriter) {
            this.rewriter = expressionRewriter;
        }

        @Override
        public Pattern<ProjectNode> getPattern() {
            return Patterns.project();
        }

        /**
         * Rewrites the projection assignments.
         *
         * @return {@link Rule.Result#empty()} when the rewritten assignments equal the originals,
         *         otherwise a result holding a new {@link ProjectNode} with the rewritten assignments
         */
        @Override
        public Rule.Result apply(ProjectNode projectNode, Captures captures, Rule.Context context) {
            Assignments rewrittenAssignments = projectNode.getAssignments()
                    .rewrite(expression -> this.rewriter.rewrite(expression, context));
            if (projectNode.getAssignments().equals(rewrittenAssignments)) {
                // Nothing changed; leave the plan untouched.
                return Rule.Result.empty();
            }
            ProjectNode replacement = new ProjectNode(projectNode.getId(), projectNode.getSource(), rewrittenAssignments);
            return Rule.Result.ofPlanNode(replacement);
        }
    }

    /**
     * Rule that runs the configured {@link ExpressionRewriter} over the row expressions of a
     * {@link ValuesNode}.
     */
    public static final class ValuesExpressionRewrite implements Rule<ValuesNode> {
        private final ExpressionRewriter rewriter;

        ValuesExpressionRewrite(ExpressionRewriter expressionRewriter) {
            this.rewriter = expressionRewriter;
        }

        @Override
        public Pattern<ValuesNode> getPattern() {
            return Patterns.values();
        }

        /**
         * Rewrites every row of the values node.
         *
         * @return {@link Rule.Result#empty()} when the node has no rows or no row changed,
         *         otherwise a result holding a new {@link ValuesNode} with the rewritten rows
         */
        @Override
        public Rule.Result apply(ValuesNode valuesNode, Captures captures, Rule.Context context) {
            if (valuesNode.getRows().isEmpty()) {
                return Rule.Result.empty();
            }
            boolean anyRewritten = false;
            ImmutableList.Builder<Expression> rewrittenRows = ImmutableList.builder();
            for (Expression row : valuesNode.getRows().get()) {
                Expression rewritten;
                if (row instanceof Row) {
                    // Rewrite item by item so the result is still a Row with the same number of items.
                    rewritten = new Row(((Row) row).getItems().stream()
                            .map(item -> this.rewriter.rewrite(item, context))
                            .collect(ImmutableList.toImmutableList()));
                } else {
                    rewritten = this.rewriter.rewrite(row, context);
                }
                if (!row.equals(rewritten)) {
                    anyRewritten = true;
                }
                rewrittenRows.add(rewritten);
            }
            if (!anyRewritten) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new ValuesNode(valuesNode.getId(), valuesNode.getOutputSymbols(), rewrittenRows.build()));
        }
    }

    /**
     * Creates a rule set whose rules all delegate to the given rewriter.
     *
     * @param expressionRewriter the rewriter shared by every rule of this set
     * @throws NullPointerException with message {@code "rewriter is null"} if the argument is null
     */
    public ExpressionRewriteRuleSet(ExpressionRewriter expressionRewriter) {
        Objects.requireNonNull(expressionRewriter, "rewriter is null");
        this.rewriter = expressionRewriter;
    }

    /**
     * Returns one freshly created rule of each kind: project, aggregation, filter, join and values.
     *
     * @return an immutable set holding the five rules, in that order
     */
    public Set<Rule<?>> rules() {
        ImmutableSet.Builder<Rule<?>> ruleSet = ImmutableSet.builder();
        ruleSet.add(projectExpressionRewrite());
        ruleSet.add(aggregationExpressionRewrite());
        ruleSet.add(filterExpressionRewrite());
        ruleSet.add(joinExpressionRewrite());
        ruleSet.add(valuesExpressionRewrite());
        return ruleSet.build();
    }

    /**
     * Creates a new rule that rewrites the assignments of project nodes with this set's rewriter.
     *
     * @return a new {@link ProjectExpressionRewrite}
     */
    public Rule<?> projectExpressionRewrite() {
        ProjectExpressionRewrite rule = new ProjectExpressionRewrite(rewriter);
        return rule;
    }

    /**
     * Creates a new rule that rewrites the aggregations of aggregation nodes with this set's rewriter.
     *
     * @return a new {@link AggregationExpressionRewrite}
     */
    public Rule<?> aggregationExpressionRewrite() {
        AggregationExpressionRewrite rule = new AggregationExpressionRewrite(rewriter);
        return rule;
    }

    /**
     * Creates a new rule that rewrites the predicate of filter nodes with this set's rewriter.
     *
     * @return a new {@link FilterExpressionRewrite}
     */
    public Rule<?> filterExpressionRewrite() {
        FilterExpressionRewrite rule = new FilterExpressionRewrite(rewriter);
        return rule;
    }

    /**
     * Creates a new rule that rewrites the filter of join nodes with this set's rewriter.
     *
     * @return a new {@link JoinExpressionRewrite}
     */
    public Rule<?> joinExpressionRewrite() {
        JoinExpressionRewrite rule = new JoinExpressionRewrite(rewriter);
        return rule;
    }

    /**
     * Creates a new rule that rewrites the rows of values nodes with this set's rewriter.
     *
     * @return a new {@link ValuesExpressionRewrite}
     */
    public Rule<?> valuesExpressionRewrite() {
        ValuesExpressionRewrite rule = new ValuesExpressionRewrite(rewriter);
        return rule;
    }
}
