package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.CachingStatsProvider;
import io.trino.cost.StatsCalculator;
import io.trino.cost.StatsProvider;
import io.trino.cost.TableStatsProvider;
import io.trino.execution.warnings.WarningCollector;
import io.trino.operator.RetryPolicy;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.MergeWriterNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.plan.TableExecuteNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.TableWriterNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.ToDoubleFunction;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/optimizations/DeterminePartitionCount.class */
/**
 * Plan optimizer that estimates a partition count from plan statistics and
 * stamps it onto the partitioning scheme of remote exchanges that use a
 * {@link SystemPartitioningHandle}.
 *
 * <p>The plan is returned unchanged when:
 * <ul>
 *   <li>it contains a table-execute, table-writer or merge-writer node,</li>
 *   <li>the session retry policy is {@link RetryPolicy#TASK},</li>
 *   <li>either "minimum input per task" session value is zero,</li>
 *   <li>an input-multiplying node (see {@link #isInputMultiplyingPlanNode}) is present,</li>
 *   <li>any required statistic is {@code NaN}, or</li>
 *   <li>the estimated count is not below the session's maximum hash partition count.</li>
 * </ul>
 */
public class DeterminePartitionCount implements PlanOptimizer {
    private static final Logger log = Logger.get(DeterminePartitionCount.class);

    /** Node types whose presence anywhere in the plan disables this optimizer. */
    private static final List<Class<? extends PlanNode>> INSERT_NODES = ImmutableList.of(TableExecuteNode.class, TableWriterNode.class, MergeWriterNode.class);

    private final StatsCalculator statsCalculator;

    /**
     * @param statsCalculator calculator used to derive plan node statistics; must not be null
     * @throws NullPointerException if {@code statsCalculator} is null
     */
    public DeterminePartitionCount(StatsCalculator statsCalculator) {
        this.statsCalculator = Objects.requireNonNull(statsCalculator, "statsCalculator is null");
    }

    /**
     * Returns {@code planNode} with the estimated partition count applied to its
     * qualifying remote exchanges, or {@code planNode} itself when no estimate
     * can or should be applied.
     *
     * <p>Any {@link RuntimeException} raised while estimating or rewriting is logged
     * as a warning and the original plan is returned, so this optimizer never fails
     * planning on its own.
     *
     * @throws NullPointerException if {@code planNode}, {@code session},
     *         {@code typeProvider} or {@code tableStatsProvider} is null
     */
    @Override // io.trino.sql.planner.optimizations.PlanOptimizer
    public PlanNode optimize(PlanNode planNode, Session session, TypeProvider typeProvider, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) {
        Objects.requireNonNull(planNode, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(typeProvider, "types is null");
        Objects.requireNonNull(tableStatsProvider, "tableStatsProvider is null");

        boolean containsInsertNode = PlanNodeSearcher.searchFrom(planNode).whereIsInstanceOfAny(INSERT_NODES).matches();
        if (containsInsertNode || SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK) {
            return planNode;
        }

        try {
            return determinePartitionCount(planNode, session, typeProvider, tableStatsProvider)
                    .map(partitionCount -> SimplePlanRewriter.rewriteWith(new Rewriter(partitionCount), planNode))
                    .orElse(planNode);
        } catch (RuntimeException e) {
            // Best effort: a failed estimate must not fail the query, keep the plan as-is.
            log.warn(e, "Error occurred when determining hash partition count for query %s", session.getQueryId());
            return planNode;
        }
    }

    /**
     * Estimates the partition count for {@code planNode}.
     *
     * <p>The estimate is the largest of the size-based count, the row-based count
     * and the session's minimum hash partition count.
     *
     * @return the estimate, or empty when estimation is disabled, not applicable,
     *         lacks statistics, or the estimate reaches the session's maximum hash
     *         partition count
     */
    private Optional<Integer> determinePartitionCount(PlanNode planNode, Session session, TypeProvider typeProvider, TableStatsProvider tableStatsProvider) {
        long minInputSizePerTask = SystemSessionProperties.getMinInputSizePerTask(session).toBytes();
        long minInputRowsPerTask = SystemSessionProperties.getMinInputRowsPerTask(session);
        // A zero threshold would make the per-task division meaningless; treat it as "disabled".
        if (minInputSizePerTask == 0 || minInputRowsPerTask == 0) {
            return Optional.empty();
        }
        if (isInputMultiplyingPlanNodePresent(planNode)) {
            return Optional.empty();
        }

        StatsProvider statsProvider = new CachingStatsProvider(this.statsCalculator, session, typeProvider, tableStatsProvider);
        long queryMaxMemoryPerNode = SystemSessionProperties.getQueryMaxMemoryPerNode(session).toBytes();

        Optional<Integer> countBasedOnSize = getPartitionCountBasedOnOutputSize(planNode, statsProvider, typeProvider, minInputSizePerTask, queryMaxMemoryPerNode);
        Optional<Integer> countBasedOnRows = getPartitionCountBasedOnRows(planNode, statsProvider, minInputRowsPerTask);
        if (countBasedOnSize.isEmpty() || countBasedOnRows.isEmpty()) {
            return Optional.empty();
        }

        int partitionCount = Math.max(
                Math.max(countBasedOnSize.get(), countBasedOnRows.get()),
                SystemSessionProperties.getMinHashPartitionCount(session));
        if (partitionCount >= SystemSessionProperties.getMaxHashPartitionCount(session)) {
            return Optional.empty();
        }

        log.debug("Estimated remote exchange partition count for query %s is %s", session.getQueryId(), partitionCount);
        return Optional.of(partitionCount);
    }

    /**
     * Computes a partition count from estimated output size in bytes.
     *
     * <p>The driving size is the larger of (a) the summed output size of all source
     * nodes and (b) the largest output size among expanding nodes. The result is the
     * larger of {@code size / minInputSizePerTask} (at least 1) and
     * {@code (size * 2) / queryMaxMemoryPerNode}.
     *
     * @param minInputSizePerTask   minimum input bytes per task; callers pass a non-zero value
     * @param queryMaxMemoryPerNode per-node query memory limit in bytes
     * @return the count, or empty if either size estimate is {@code NaN}
     */
    private static Optional<Integer> getPartitionCountBasedOnOutputSize(PlanNode planNode, StatsProvider statsProvider, TypeProvider typeProvider, long minInputSizePerTask, long queryMaxMemoryPerNode) {
        ToDoubleFunction<PlanNode> outputSizeInBytes =
                node -> statsProvider.getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), typeProvider);

        double sourceNodesOutputSize = getSourceNodesOutputStats(planNode, outputSizeInBytes);
        double expandingNodesMaxOutputSize = getExpandingNodesMaxOutputStats(planNode, outputSizeInBytes);
        if (Double.isNaN(sourceNodesOutputSize) || Double.isNaN(expandingNodesMaxOutputSize)) {
            return Optional.empty();
        }

        double drivingSize = Math.max(sourceNodesOutputSize, expandingNodesMaxOutputSize);
        int countBasedOnSize = getPartitionCount(drivingSize, minInputSizePerTask);
        // The double-to-int cast saturates, so an oversized quotient becomes Integer.MAX_VALUE.
        int countBasedOnMemory = (int) ((drivingSize * 2.0d) / queryMaxMemoryPerNode);
        return Optional.of(Math.max(countBasedOnSize, countBasedOnMemory));
    }

    /**
     * Computes a partition count from estimated output row counts, using the larger
     * of the summed source-node row count and the largest expanding-node row count.
     *
     * @param minInputRowsPerTask minimum input rows per task; callers pass a non-zero value
     * @return the count (at least 1), or empty if either row estimate is {@code NaN}
     */
    private static Optional<Integer> getPartitionCountBasedOnRows(PlanNode planNode, StatsProvider statsProvider, long minInputRowsPerTask) {
        ToDoubleFunction<PlanNode> outputRowCount = node -> statsProvider.getStats(node).getOutputRowCount();

        double sourceNodesRowCount = getSourceNodesOutputStats(planNode, outputRowCount);
        double expandingNodesMaxRowCount = getExpandingNodesMaxOutputStats(planNode, outputRowCount);
        if (Double.isNaN(sourceNodesRowCount) || Double.isNaN(expandingNodesMaxRowCount)) {
            return Optional.empty();
        }
        return Optional.of(getPartitionCount(Math.max(sourceNodesRowCount, expandingNodesMaxRowCount), minInputRowsPerTask));
    }

    /** Returns {@code outputStats / minInputStatsPerTask} truncated to an int, but never less than 1. */
    private static int getPartitionCount(double outputStats, long minInputStatsPerTask) {
        return Math.max((int) (outputStats / minInputStatsPerTask), 1);
    }

    /** Returns true if any node in the plan satisfies {@link #isInputMultiplyingPlanNode}. */
    private static boolean isInputMultiplyingPlanNodePresent(PlanNode planNode) {
        return PlanNodeSearcher.searchFrom(planNode).where(DeterminePartitionCount::isInputMultiplyingPlanNode).matches();
    }

    /**
     * Returns true for nodes that disable the estimate:
     * <ul>
     *   <li>any {@link UnnestNode};</li>
     *   <li>a cross join where neither side is at most scalar;</li>
     *   <li>a non-cross join with more than one join criterion.</li>
     * </ul>
     */
    private static boolean isInputMultiplyingPlanNode(PlanNode planNode) {
        if (planNode instanceof UnnestNode) {
            return true;
        }
        if (!(planNode instanceof JoinNode)) {
            return false;
        }

        JoinNode joinNode = (JoinNode) planNode;
        if (joinNode.isCrossJoin()) {
            // A cross join against an at-most-scalar side does not count as input-multiplying.
            return !QueryCardinalityUtil.isAtMostScalar(joinNode.getRight())
                    && !QueryCardinalityUtil.isAtMostScalar(joinNode.getLeft());
        }
        return joinNode.getCriteria().size() > 1;
    }

    /**
     * Returns the maximum of {@code statsMapper} over all expanding nodes in the plan,
     * or 0 when the plan has none.
     */
    private static double getExpandingNodesMaxOutputStats(PlanNode planNode, ToDoubleFunction<PlanNode> statsMapper) {
        return PlanNodeSearcher.searchFrom(planNode)
                .where(DeterminePartitionCount::isExpandingPlanNode)
                .findAll()
                .stream()
                .mapToDouble(statsMapper)
                .max()
                .orElse(0.0d);
    }

    /** Joins, unions and exchanges with more than one source are treated as expanding nodes. */
    private static boolean isExpandingPlanNode(PlanNode planNode) {
        if (planNode instanceof JoinNode || planNode instanceof UnionNode) {
            return true;
        }
        return planNode instanceof ExchangeNode && planNode.getSources().size() > 1;
    }

    /** Returns the sum of {@code statsMapper} over all table-scan and values nodes in the plan. */
    private static double getSourceNodesOutputStats(PlanNode planNode, ToDoubleFunction<PlanNode> statsMapper) {
        return PlanNodeSearcher.searchFrom(planNode)
                .whereIsInstanceOfAny(TableScanNode.class, ValuesNode.class)
                .findAll()
                .stream()
                .mapToDouble(statsMapper)
                .sum();
    }

    /**
     * Rewriter that sets a fixed partition count on every remote exchange whose
     * partitioning uses a {@link SystemPartitioningHandle}.
     */
    public static class Rewriter extends SimplePlanRewriter<Void> {
        private final int partitionCount;

        private Rewriter(int partitionCount) {
            this.partitionCount = partitionCount;
        }

        /**
         * Replaces a qualifying exchange with a copy carrying {@code partitionCount}
         * and rewritten sources. Any other exchange is kept as-is, but its sources
         * are still visited.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitExchange(ExchangeNode exchangeNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            PartitioningHandle handle = exchangeNode.getPartitioningScheme().getPartitioning().getHandle();
            boolean remoteSystemPartitioned = exchangeNode.getScope() == ExchangeNode.Scope.REMOTE
                    && handle.getConnectorHandle() instanceof SystemPartitioningHandle;
            if (!remoteSystemPartitioned) {
                // Keep descending: a local or connector-partitioned exchange may sit above
                // remote exchanges that must receive the same partition count as the rest
                // of the plan. Returning the node directly would silently skip them.
                return rewriteContext.defaultRewrite(exchangeNode);
            }

            List<PlanNode> rewrittenSources = exchangeNode.getSources().stream()
                    .map(rewriteContext::rewrite)
                    .collect(ImmutableList.toImmutableList());
            return new ExchangeNode(
                    exchangeNode.getId(),
                    exchangeNode.getType(),
                    exchangeNode.getScope(),
                    exchangeNode.getPartitioningScheme().withPartitionCount(this.partitionCount),
                    rewrittenSources,
                    exchangeNode.getInputs(),
                    exchangeNode.getOrderingScheme());
        }
    }
}
