/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.metadata.Metadata;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.optimizations.ScalarAggregationToJoinRewriter;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.CorrelatedJoinNode;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.BooleanLiteral;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

/**
 * Supplies rules that rewrite a {@link CorrelatedJoinNode} whose subquery is a scalar aggregation
 * (an {@link AggregationNode} with no grouping columns), optionally topped by a {@link ProjectNode}.
 * <p>
 * Both rules fire only when the correlated join has a non-empty correlation list and a
 * {@code TRUE} filter. The actual transformation is delegated to
 * {@link ScalarAggregationToJoinRewriter}. If the rewriter hands back a {@link CorrelatedJoinNode},
 * the rewrite is treated as not applicable and the rule leaves the plan unchanged.
 */
public class TransformCorrelatedScalarAggregationToJoin {
    private final Metadata metadata;

    public TransformCorrelatedScalarAggregationToJoin(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Returns both rule variants: one for a scalar aggregation under a projection, and one for a
     * bare scalar aggregation.
     */
    public Set<Rule<?>> rules() {
        // Explicit type witness: the element type must be Rule<?>, not Object, to match the return type.
        return ImmutableSet.<Rule<?>>of(
                new TransformCorrelatedScalarAggregationWithProjection(this.metadata),
                new TransformCorrelatedScalarAggregationWithoutProjection(this.metadata));
    }

    /**
     * Runs {@link ScalarAggregationToJoinRewriter} on the given correlated join and aggregation,
     * using the allocators and lookup of the current rule context.
     */
    private static PlanNode rewriteScalarAggregation(
            Metadata metadata,
            Rule.Context context,
            CorrelatedJoinNode correlatedJoinNode,
            AggregationNode aggregation) {
        return new ScalarAggregationToJoinRewriter(
                        metadata, context.getSymbolAllocator(), context.getIdAllocator(), context.getLookup())
                .rewriteScalarAggregation(correlatedJoinNode, aggregation);
    }

    /**
     * Returns the output symbols of {@code rewrittenNode} that were also outputs of the original
     * correlated join, in {@code rewrittenNode}'s output order.
     */
    private static List<Symbol> expectedAggregationOutputs(CorrelatedJoinNode correlatedJoinNode, PlanNode rewrittenNode) {
        // Hash set so each membership test is O(1) rather than a linear list scan.
        Set<Symbol> outputSymbols = new HashSet<>(correlatedJoinNode.getOutputSymbols());
        return rewrittenNode.getOutputSymbols().stream()
                .filter(outputSymbols::contains)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Matches a correlated join (non-empty correlation, {@code TRUE} filter) whose subquery is
     * directly a global aggregation. On a successful rewrite it emits an identity projection that
     * restricts the rewritten plan's outputs to those the original correlated join produced.
     */
    @VisibleForTesting
    static final class TransformCorrelatedScalarAggregationWithoutProjection
    implements Rule<CorrelatedJoinNode> {
        private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
        private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin()
                .with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation()))
                .with(Patterns.CorrelatedJoin.filter().equalTo(BooleanLiteral.TRUE_LITERAL))
                .with(Patterns.CorrelatedJoin.subquery().matching(
                        Patterns.aggregation()
                                .with(Pattern.empty(Patterns.Aggregation.groupingColumns()))
                                .capturedAs(AGGREGATION)));
        private final Metadata metadata;

        @VisibleForTesting
        TransformCorrelatedScalarAggregationWithoutProjection(Metadata metadata) {
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        @Override
        public Pattern<CorrelatedJoinNode> getPattern() {
            return PATTERN;
        }

        @Override
        public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
            PlanNode rewrittenNode = rewriteScalarAggregation(this.metadata, context, correlatedJoinNode, captures.get(AGGREGATION));
            // The rewriter returns a CorrelatedJoinNode when it could not perform the transformation.
            if (rewrittenNode instanceof CorrelatedJoinNode) {
                return Rule.Result.empty();
            }
            List<Symbol> outputs = expectedAggregationOutputs(correlatedJoinNode, rewrittenNode);
            return Rule.Result.ofPlanNode(new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    rewrittenNode,
                    Assignments.identity(outputs)));
        }
    }

    /**
     * Matches a correlated join (non-empty correlation, {@code TRUE} filter) whose subquery is a
     * projection over a global aggregation. On a successful rewrite it emits a projection made of
     * identity assignments for the retained aggregation outputs plus the captured projection's
     * own assignments, evaluated over the rewritten plan.
     */
    @VisibleForTesting
    static final class TransformCorrelatedScalarAggregationWithProjection
    implements Rule<CorrelatedJoinNode> {
        private static final Capture<ProjectNode> PROJECTION = Capture.newCapture();
        private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
        private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin()
                .with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation()))
                .with(Patterns.CorrelatedJoin.filter().equalTo(BooleanLiteral.TRUE_LITERAL))
                .with(Patterns.CorrelatedJoin.subquery().matching(
                        Patterns.project()
                                .capturedAs(PROJECTION)
                                .with(Patterns.source().matching(
                                        Patterns.aggregation()
                                                .with(Pattern.empty(Patterns.Aggregation.groupingColumns()))
                                                .capturedAs(AGGREGATION)))));
        private final Metadata metadata;

        @VisibleForTesting
        TransformCorrelatedScalarAggregationWithProjection(Metadata metadata) {
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        @Override
        public Pattern<CorrelatedJoinNode> getPattern() {
            return PATTERN;
        }

        @Override
        public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
            PlanNode rewrittenNode = rewriteScalarAggregation(this.metadata, context, correlatedJoinNode, captures.get(AGGREGATION));
            // The rewriter returns a CorrelatedJoinNode when it could not perform the transformation.
            if (rewrittenNode instanceof CorrelatedJoinNode) {
                return Rule.Result.empty();
            }
            List<Symbol> outputs = expectedAggregationOutputs(correlatedJoinNode, rewrittenNode);
            // Re-apply the captured projection on top of the rewritten plan, alongside pass-through outputs.
            Assignments assignments = Assignments.builder()
                    .putIdentities(outputs)
                    .putAll(captures.get(PROJECTION).getAssignments())
                    .build();
            return Rule.Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), rewrittenNode, assignments));
        }
    }
}

