package io.trino.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import io.airlift.units.DataSize;
import io.trino.SystemSessionProperties;
import io.trino.cost.CostCalculatorWithEstimatedExchanges;
import io.trino.cost.CostComparator;
import io.trino.cost.StatsProvider;
import io.trino.cost.TaskCountEstimator;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.analyzer.FeaturesConfig;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.tree.Unnest;
import io.trino.util.MorePredicates;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/DetermineJoinDistributionType.class */
public class DetermineJoinDistributionType implements Rule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join().matching(joinNode -> {
        return joinNode.getDistributionType().isEmpty();
    });
    private final CostComparator costComparator;
    private final TaskCountEstimator taskCountEstimator;

    public DetermineJoinDistributionType(CostComparator costComparator, TaskCountEstimator taskCountEstimator) {
        this.costComparator = (CostComparator) Objects.requireNonNull(costComparator, "costComparator is null");
        this.taskCountEstimator = (TaskCountEstimator) Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    @Override // io.trino.sql.planner.iterative.Rule
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    @Override // io.trino.sql.planner.iterative.Rule
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        FeaturesConfig.JoinDistributionType joinDistributionType = SystemSessionProperties.getJoinDistributionType(context.getSession());
        return joinDistributionType == FeaturesConfig.JoinDistributionType.AUTOMATIC ? Rule.Result.ofPlanNode(getCostBasedJoin(joinNode, context)) : Rule.Result.ofPlanNode(getSyntacticOrderJoin(joinNode, context, joinDistributionType));
    }

    public static boolean canReplicate(JoinNode joinNode, Rule.Context context) {
        if (!SystemSessionProperties.getJoinDistributionType(context.getSession()).canReplicate()) {
            return false;
        }
        DataSize joinMaxBroadcastTableSize = SystemSessionProperties.getJoinMaxBroadcastTableSize(context.getSession());
        PlanNode right = joinNode.getRight();
        return context.getStatsProvider().getStats(right).getOutputSizeInBytes(right.getOutputSymbols(), context.getSymbolAllocator().getTypes()) <= ((double) joinMaxBroadcastTableSize.toBytes()) || getSourceTablesSizeInBytes(right, context) <= ((double) joinMaxBroadcastTableSize.toBytes());
    }

    public static double getSourceTablesSizeInBytes(PlanNode planNode, Rule.Context context) {
        return getSourceTablesSizeInBytes(planNode, context.getLookup(), context.getStatsProvider(), context.getSymbolAllocator().getTypes());
    }

    @VisibleForTesting
    static double getSourceTablesSizeInBytes(PlanNode planNode, Lookup lookup, StatsProvider statsProvider, TypeProvider typeProvider) {
        if (PlanNodeSearcher.searchFrom(planNode, lookup).where(MorePredicates.isInstanceOfAny(JoinNode.class, Unnest.class)).matches()) {
            return Double.NaN;
        }
        return PlanNodeSearcher.searchFrom(planNode, lookup).where(MorePredicates.isInstanceOfAny(TableScanNode.class, ValuesNode.class)).findAll().stream().mapToDouble(planNode2 -> {
            return statsProvider.getStats(planNode2).getOutputSizeInBytes(planNode2.getOutputSymbols(), typeProvider);
        }).sum();
    }

    private PlanNode getCostBasedJoin(JoinNode joinNode, Rule.Context context) {
        ArrayList arrayList = new ArrayList();
        addJoinsWithDifferentDistributions(joinNode, arrayList, context);
        addJoinsWithDifferentDistributions(joinNode.flipChildren(), arrayList, context);
        return (arrayList.stream().anyMatch(planNodeWithCost -> {
            return planNodeWithCost.getCost().hasUnknownComponents();
        }) || arrayList.isEmpty()) ? getSizeBasedJoin(joinNode, context) : ((PlanNodeWithCost) this.costComparator.forSession(context.getSession()).onResultOf((v0) -> {
            return v0.getCost();
        }).min(arrayList)).getPlanNode();
    }

    private JoinNode getSizeBasedJoin(JoinNode joinNode, Rule.Context context) {
        DataSize joinMaxBroadcastTableSize = SystemSessionProperties.getJoinMaxBroadcastTableSize(context.getSession());
        boolean z = getSourceTablesSizeInBytes(joinNode.getRight(), context) <= ((double) joinMaxBroadcastTableSize.toBytes());
        if (z && !mustPartition(joinNode)) {
            return joinNode.withDistributionType(JoinNode.DistributionType.REPLICATED);
        }
        boolean z2 = getSourceTablesSizeInBytes(joinNode.getLeft(), context) <= ((double) joinMaxBroadcastTableSize.toBytes());
        return (!z2 || mustPartition(joinNode.flipChildren())) ? z ? joinNode.withDistributionType(JoinNode.DistributionType.PARTITIONED) : z2 ? joinNode.flipChildren().withDistributionType(JoinNode.DistributionType.PARTITIONED) : getSyntacticOrderJoin(joinNode, context, FeaturesConfig.JoinDistributionType.AUTOMATIC) : joinNode.flipChildren().withDistributionType(JoinNode.DistributionType.REPLICATED);
    }

    private void addJoinsWithDifferentDistributions(JoinNode joinNode, List<PlanNodeWithCost> list, Rule.Context context) {
        if (!mustPartition(joinNode) && canReplicate(joinNode, context)) {
            list.add(getJoinNodeWithCost(context, joinNode.withDistributionType(JoinNode.DistributionType.REPLICATED)));
        }
        if (mustReplicate(joinNode, context)) {
            return;
        }
        list.add(getJoinNodeWithCost(context, joinNode.withDistributionType(JoinNode.DistributionType.PARTITIONED)));
    }

    private JoinNode getSyntacticOrderJoin(JoinNode joinNode, Rule.Context context, FeaturesConfig.JoinDistributionType joinDistributionType) {
        if (mustPartition(joinNode)) {
            return joinNode.withDistributionType(JoinNode.DistributionType.PARTITIONED);
        }
        if (!mustReplicate(joinNode, context) && joinDistributionType.canPartition()) {
            return joinNode.withDistributionType(JoinNode.DistributionType.PARTITIONED);
        }
        return joinNode.withDistributionType(JoinNode.DistributionType.REPLICATED);
    }

    private static boolean mustPartition(JoinNode joinNode) {
        JoinNode.Type type = joinNode.getType();
        return type == JoinNode.Type.RIGHT || type == JoinNode.Type.FULL;
    }

    private static boolean mustReplicate(JoinNode joinNode, Rule.Context context) {
        JoinNode.Type type = joinNode.getType();
        if (joinNode.getCriteria().isEmpty() && (type == JoinNode.Type.INNER || type == JoinNode.Type.LEFT)) {
            return true;
        }
        return QueryCardinalityUtil.isAtMostScalar(joinNode.getRight(), context.getLookup());
    }

    private PlanNodeWithCost getJoinNodeWithCost(Rule.Context context, JoinNode joinNode) {
        return new PlanNodeWithCost(CostCalculatorWithEstimatedExchanges.calculateJoinCostWithoutOutput(joinNode.getLeft(), joinNode.getRight(), context.getStatsProvider(), context.getSymbolAllocator().getTypes(), joinNode.getDistributionType().get() == JoinNode.DistributionType.REPLICATED, this.taskCountEstimator.estimateSourceDistributedTaskCount(context.getSession())).toPlanCost(), joinNode);
    }
}
