/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.CorrelatedJoinNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.IfExpression;
import io.prestosql.sql.tree.NullLiteral;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * Rewrites a {@link CorrelatedJoinNode} whose correlation list is empty (an uncorrelated
 * subquery) into a plain {@link JoinNode} between the correlated join's input and subquery.
 *
 * <p>Handled shapes:
 * <ul>
 *   <li>INNER / LEFT: become an INNER / LEFT join that carries the original filter.</li>
 *   <li>RIGHT / FULL with a {@code TRUE} filter: become an INNER / LEFT join with no filter.</li>
 *   <li>RIGHT with any other filter: becomes an unfiltered INNER join topped by a projection that
 *       replaces every input symbol with {@code NULL} on rows where the filter does not hold.</li>
 *   <li>FULL with any other filter: not supported; the rule returns an empty result.</li>
 * </ul>
 */
public class TransformUncorrelatedSubqueryToJoin
implements Rule<CorrelatedJoinNode> {
    // Matches only correlated joins whose correlation list is empty.
    private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin().with(Pattern.empty(Patterns.CorrelatedJoin.correlation()));

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Applies the rewrite described in the class documentation.
     *
     * @param correlatedJoinNode the matched correlated join (guaranteed uncorrelated by {@link #PATTERN})
     * @param captures unused
     * @param context rule context; its id allocator is used only for the RIGHT-with-filter projection
     * @return the replacement plan node, or {@link Rule.Result#empty()} for FULL with a non-TRUE filter
     * @throws IllegalStateException if the correlated join type is not INNER, LEFT, RIGHT or FULL
     */
    @Override
    public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        CorrelatedJoinNode.Type correlatedType = correlatedJoinNode.getType();
        Expression filter = correlatedJoinNode.getFilter();

        // INNER and LEFT map directly onto join types and can keep the filter as the join filter.
        if (correlatedType == CorrelatedJoinNode.Type.INNER || correlatedType == CorrelatedJoinNode.Type.LEFT) {
            return Rule.Result.ofPlanNode(rewriteToJoin(correlatedJoinNode, correlatedType.toJoinNodeType(), filter));
        }

        Preconditions.checkState(
                correlatedType == CorrelatedJoinNode.Type.RIGHT || correlatedType == CorrelatedJoinNode.Type.FULL,
                "unexpected CorrelatedJoin type: %s",
                correlatedType);

        boolean trivialFilter = filter.equals(BooleanLiteral.TRUE_LITERAL);

        // FULL with a non-trivial filter has no rewrite here; bail out before building any plan nodes.
        if (correlatedType == CorrelatedJoinNode.Type.FULL && !trivialFilter) {
            return Rule.Result.empty();
        }

        // RIGHT and FULL are rewritten without a join filter: RIGHT -> INNER, FULL -> LEFT.
        JoinNode.Type joinType = correlatedType == CorrelatedJoinNode.Type.RIGHT ? JoinNode.Type.INNER : JoinNode.Type.LEFT;
        JoinNode joinNode = rewriteToJoin(correlatedJoinNode, joinType, BooleanLiteral.TRUE_LITERAL);
        if (trivialFilter) {
            return Rule.Result.ofPlanNode(joinNode);
        }

        // Only RIGHT with a non-trivial filter reaches this point.
        return Rule.Result.ofPlanNode(nullifyInputWhereFilterFails(correlatedJoinNode, joinNode, context));
    }

    /**
     * Places a projection over the unfiltered INNER join used for a RIGHT correlated join with a
     * non-trivial filter. Subquery output symbols pass through unchanged; each input output symbol
     * becomes {@code IF(filter, symbol, NULL)}, so rows that fail the filter keep only the subquery side.
     *
     * @param correlatedJoinNode the original RIGHT correlated join
     * @param joinNode the unfiltered INNER join built from it
     * @param context rule context supplying the id for the new projection node
     * @return the projection node restricted to the correlated join's output symbols
     */
    private static ProjectNode nullifyInputWhereFilterFails(CorrelatedJoinNode correlatedJoinNode, JoinNode joinNode, Rule.Context context) {
        Set<Symbol> outputSymbols = ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols());
        Expression filter = correlatedJoinNode.getFilter();

        Assignments.Builder assignments = Assignments.builder();
        assignments.putIdentities(Sets.intersection(
                ImmutableSet.copyOf(correlatedJoinNode.getSubquery().getOutputSymbols()),
                outputSymbols));
        for (Symbol inputSymbol : Sets.intersection(ImmutableSet.copyOf(correlatedJoinNode.getInput().getOutputSymbols()), outputSymbols)) {
            assignments.put(inputSymbol, new IfExpression(filter, inputSymbol.toSymbolReference(), new NullLiteral()));
        }
        return new ProjectNode(context.getIdAllocator().getNextId(), joinNode, assignments.build());
    }

    /**
     * Builds a join between the correlated join's input (left side) and subquery (right side).
     * The join reuses the parent's plan node id, has no equi-join criteria, and outputs every
     * symbol of both sides. A {@code TRUE} filter is omitted rather than stored.
     *
     * @param parent the correlated join being replaced
     * @param type the join type to emit
     * @param filter the join filter; {@code TRUE} means "no filter"
     * @return the new join node
     */
    private static JoinNode rewriteToJoin(CorrelatedJoinNode parent, JoinNode.Type type, Expression filter) {
        Optional<Expression> joinFilter = filter.equals(BooleanLiteral.TRUE_LITERAL) ? Optional.empty() : Optional.of(filter);
        return new JoinNode(
                parent.getId(),
                type,
                parent.getInput(),
                parent.getSubquery(),
                ImmutableList.of(),
                parent.getInput().getOutputSymbols(),
                parent.getSubquery().getOutputSymbols(),
                joinFilter,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of(),
                Optional.empty());
    }
}

