package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SetOperationNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FrameBound;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.SymbolReference;
import io.trino.sql.tree.WindowFrame;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.class */
/**
 * Rewrites a set operation other than UNION into a plan made of a {@link UnionNode} over the
 * original sources plus per-source boolean "marker" columns. Both public entry points reject
 * {@link UnionNode}; the intended callers are presumably the INTERSECT/EXCEPT rules.
 *
 * <p>For source {@code i}, marker {@code i} is projected as {@code TRUE} and every other marker as
 * {@code CAST(NULL AS boolean)}. Counting the non-null values of marker {@code i} per group
 * therefore yields how many rows source {@code i} contributed to that group.
 */
public class SetOperationNodeTranslator {
    private static final String MARKER = "marker";
    private final SymbolAllocator symbolAllocator;
    private final PlanNodeIdAllocator idAllocator;
    private final ResolvedFunction countFunction;
    private final ResolvedFunction rowNumberFunction;

    /**
     * Result of a translation: the rewritten plan, the per-source count symbols (in source order),
     * and, only for the ALL variant, the row-number symbol.
     */
    public static class TranslationResult {
        private final PlanNode planNode;
        private final List<Symbol> countSymbols;
        private final Optional<Symbol> rowNumberSymbol;

        /**
         * Creates a result without a row-number symbol.
         *
         * @param planNode the rewritten plan
         * @param countSymbols the per-source count symbols (defensively copied)
         * @throws NullPointerException if any argument is null
         */
        public TranslationResult(PlanNode planNode, List<Symbol> countSymbols) {
            this(planNode, countSymbols, Optional.empty());
        }

        /**
         * Creates a result.
         *
         * @param planNode the rewritten plan
         * @param countSymbols the per-source count symbols (defensively copied)
         * @param rowNumberSymbol the row-number symbol, if the plan produces one
         * @throws NullPointerException if any argument is null
         */
        public TranslationResult(PlanNode planNode, List<Symbol> countSymbols, Optional<Symbol> rowNumberSymbol) {
            this.planNode = Objects.requireNonNull(planNode, "planNode is null");
            this.countSymbols = ImmutableList.copyOf(Objects.requireNonNull(countSymbols, "countSymbols is null"));
            this.rowNumberSymbol = Objects.requireNonNull(rowNumberSymbol, "rowNumberSymbol is null");
        }

        /** Returns the rewritten plan. */
        public PlanNode getPlanNode() {
            return planNode;
        }

        /** Returns the per-source count symbols; entry {@code i} counts rows from source {@code i}. */
        public List<Symbol> getCountSymbols() {
            return countSymbols;
        }

        /**
         * Returns the row-number symbol.
         *
         * @throws IllegalStateException if this result was created without a row-number symbol
         */
        public Symbol getRowNumberSymbol() {
            Preconditions.checkState(rowNumberSymbol.isPresent(), "rowNumberSymbol is empty");
            return rowNumberSymbol.get();
        }
    }

    /**
     * Creates a translator and resolves the {@code count(boolean)} and {@code row_number()}
     * functions it needs.
     *
     * @throws NullPointerException if any argument is null
     */
    public SetOperationNodeTranslator(Metadata metadata, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
        Objects.requireNonNull(metadata, "metadata is null");
        this.countFunction = metadata.resolveFunction(QualifiedName.of("count"), TypeSignatureProvider.fromTypes(BooleanType.BOOLEAN));
        this.rowNumberFunction = metadata.resolveFunction(QualifiedName.of("row_number"), ImmutableList.of());
    }

    /**
     * Builds a plan for the DISTINCT variant. The plan is a single-step aggregation grouped by the
     * operation's output symbols and computes one {@code count(marker_i)} per source.
     *
     * @throws IllegalArgumentException if {@code setOperationNode} is a {@link UnionNode}
     */
    public TranslationResult makeSetContainmentPlanForDistinct(SetOperationNode setOperationNode) {
        Preconditions.checkArgument(!(setOperationNode instanceof UnionNode), "Cannot simplify a UnionNode");
        List<Symbol> markers = allocateSymbols(setOperationNode.getSources().size(), MARKER, BooleanType.BOOLEAN);
        UnionNode union = unionWithMarkers(setOperationNode, markers);
        List<Symbol> countSymbols = allocateSymbols(markers.size(), "count", BigintType.BIGINT);
        AggregationNode aggregation = computeCounts(union, setOperationNode.getOutputSymbols(), markers, countSymbols);
        return new TranslationResult(aggregation, countSymbols);
    }

    /**
     * Builds a plan for the ALL variant. Per-source counts and a {@code row_number()} are computed as
     * window functions partitioned by the operation's output symbols. The marker columns are then
     * projected away, leaving the outputs, the counts and the row number.
     *
     * @throws IllegalArgumentException if {@code setOperationNode} is a {@link UnionNode}
     */
    public TranslationResult makeSetContainmentPlanForAll(SetOperationNode setOperationNode) {
        Preconditions.checkArgument(!(setOperationNode instanceof UnionNode), "Cannot simplify a UnionNode");
        List<Symbol> markers = allocateSymbols(setOperationNode.getSources().size(), MARKER, BooleanType.BOOLEAN);
        UnionNode union = unionWithMarkers(setOperationNode, markers);
        List<Symbol> outputSymbols = setOperationNode.getOutputSymbols();
        List<Symbol> countSymbols = allocateSymbols(markers.size(), "count", BigintType.BIGINT);
        Symbol rowNumberSymbol = symbolAllocator.newSymbol("row_number", BigintType.BIGINT);
        // Arguments are evaluated left to right, so the project's id is allocated before the
        // window node's id (allocated inside appendCounts), matching the historical allocation order.
        ProjectNode project = new ProjectNode(
                idAllocator.getNextId(),
                appendCounts(union, outputSymbols, markers, countSymbols, rowNumberSymbol),
                Assignments.identity(ImmutableList.<Symbol>builder()
                        .addAll(outputSymbols)
                        .addAll(countSymbols)
                        .add(rowNumberSymbol)
                        .build()));
        return new TranslationResult(project, countSymbols, Optional.of(rowNumberSymbol));
    }

    /** Projects the marker columns onto every source of {@code node} and unions the results. */
    private UnionNode unionWithMarkers(SetOperationNode node, List<Symbol> markers) {
        List<PlanNode> markedSources = appendMarkers(markers, node.getSources(), node);
        List<Symbol> unionOutputs = ImmutableList.copyOf(Iterables.concat(node.getOutputSymbols(), markers));
        return union(markedSources, unionOutputs);
    }

    /** Allocates {@code count} fresh symbols of {@code type}, each named from {@code nameHint}. */
    private List<Symbol> allocateSymbols(int count, String nameHint, Type type) {
        ImmutableList.Builder<Symbol> symbols = ImmutableList.builder();
        for (int i = 0; i < count; i++) {
            symbols.add(symbolAllocator.newSymbol(nameHint, type));
        }
        return symbols.build();
    }

    /** Wraps each source in a projection that adds the marker columns; source {@code i} sets marker {@code i}. */
    private List<PlanNode> appendMarkers(List<Symbol> markers, List<PlanNode> sources, SetOperationNode node) {
        ImmutableList.Builder<PlanNode> markedSources = ImmutableList.builder();
        for (int sourceIndex = 0; sourceIndex < sources.size(); sourceIndex++) {
            markedSources.add(appendMarkers(idAllocator, symbolAllocator, sources.get(sourceIndex), sourceIndex, markers, node.sourceSymbolMap(sourceIndex)));
        }
        return markedSources.build();
    }

    /**
     * Builds a projection over {@code source}. The projection re-exposes the set operation's
     * inputs under fresh symbols. It then appends one boolean column per marker: {@code TRUE} at
     * {@code markerIndex} and {@code CAST(NULL AS boolean)} elsewhere.
     */
    private static PlanNode appendMarkers(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, PlanNode source, int markerIndex, List<Symbol> markers, Map<Symbol, SymbolReference> projections) {
        Assignments.Builder assignments = Assignments.builder();
        // NOTE(review): union() pairs columns by position, so this relies on `projections` iterating
        // in the same order as the set operation's output symbols — confirm in SetOperationNode.
        for (Map.Entry<Symbol, SymbolReference> entry : projections.entrySet()) {
            Symbol output = entry.getKey();
            assignments.put(symbolAllocator.newSymbol(output.getName(), symbolAllocator.getTypes().get(output)), entry.getValue());
        }
        for (int i = 0; i < markers.size(); i++) {
            Symbol markerOutput = symbolAllocator.newSymbol(markers.get(i).getName(), BooleanType.BOOLEAN);
            Expression markerValue = (i == markerIndex)
                    ? BooleanLiteral.TRUE_LITERAL
                    : new Cast(new NullLiteral(), TypeSignatureTranslator.toSqlType(BooleanType.BOOLEAN));
            assignments.put(markerOutput, markerValue);
        }
        return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
    }

    /** Unions {@code sources}, mapping the i-th output of each source to {@code outputs.get(i)}. */
    private UnionNode union(List<PlanNode> sources, List<Symbol> outputs) {
        ImmutableListMultimap.Builder<Symbol, Symbol> outputsToInputs = ImmutableListMultimap.builder();
        for (PlanNode source : sources) {
            // Fetch once per source; the original re-read it twice on every iteration.
            List<Symbol> sourceOutputs = source.getOutputSymbols();
            for (int i = 0; i < sourceOutputs.size(); i++) {
                outputsToInputs.put(outputs.get(i), sourceOutputs.get(i));
            }
        }
        return new UnionNode(idAllocator.getNextId(), sources, outputsToInputs.build(), outputs);
    }

    /**
     * Builds a single-step aggregation over {@code unionNode}, grouped by {@code groupingKeys}.
     * Output {@code countOutputs.get(i)} is {@code count(markers.get(i))}.
     */
    private AggregationNode computeCounts(UnionNode unionNode, List<Symbol> groupingKeys, List<Symbol> markers, List<Symbol> countOutputs) {
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        for (int i = 0; i < markers.size(); i++) {
            aggregations.put(countOutputs.get(i), new AggregationNode.Aggregation(
                    countFunction,
                    ImmutableList.of(markers.get(i).toSymbolReference()),
                    false,
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty()));
        }
        return new AggregationNode(
                idAllocator.getNextId(),
                unionNode,
                aggregations.build(),
                AggregationNode.singleGroupingSet(groupingKeys),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());
    }

    /**
     * Builds a window node over {@code unionNode}, partitioned by {@code partitionBy} with no
     * ordering. It computes {@code count(markers.get(i))} into {@code countOutputs.get(i)} and
     * {@code row_number()} into {@code rowNumberSymbol}. Because no ordering is supplied, the row
     * numbering within a partition is not tied to any particular order.
     */
    private WindowNode appendCounts(UnionNode unionNode, List<Symbol> partitionBy, List<Symbol> markers, List<Symbol> countOutputs, Symbol rowNumberSymbol) {
        ImmutableMap.Builder<Symbol, WindowNode.Function> functions = ImmutableMap.builder();
        // ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: each count spans the whole partition.
        WindowNode.Frame frame = new WindowNode.Frame(
                WindowFrame.Type.ROWS,
                FrameBound.Type.UNBOUNDED_PRECEDING,
                Optional.empty(),
                Optional.empty(),
                FrameBound.Type.UNBOUNDED_FOLLOWING,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
        for (int i = 0; i < markers.size(); i++) {
            functions.put(countOutputs.get(i), new WindowNode.Function(countFunction, ImmutableList.of(markers.get(i).toSymbolReference()), frame, false));
        }
        functions.put(rowNumberSymbol, new WindowNode.Function(rowNumberFunction, ImmutableList.of(), frame, false));
        return new WindowNode(
                idAllocator.getNextId(),
                unionNode,
                new WindowNode.Specification(partitionBy, Optional.empty()),
                functions.build(),
                Optional.empty(),
                ImmutableSet.of(),
                0);
    }
}
