package io.trino.cost;

import io.trino.Session;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.matching.Pattern;
import io.trino.spi.connector.SortOrder;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.TopNNode;
import java.util.Optional;

/**
 * Stats rule for {@link TopNNode}.
 *
 * <p>Caps the output row count at the node's count and re-estimates the nulls fraction of the
 * first ORDER BY symbol, depending on whether that symbol sorts NULLS FIRST or NULLS LAST.
 * Only {@link TopNNode.Step#SINGLE} nodes are estimated; other steps yield no estimate.
 */
public class TopNStatsRule extends SimpleStatsRule<TopNNode> {
    private static final Pattern<TopNNode> PATTERN = Patterns.topN();

    public TopNStatsRule(StatsNormalizer statsNormalizer) {
        super(statsNormalizer);
    }

    @Override
    public Pattern<TopNNode> getPattern() {
        return PATTERN;
    }

    /**
     * Derives the statistics of a TopN node from the statistics of its source.
     *
     * @return empty for any step other than {@link TopNNode.Step#SINGLE}; the source statistics
     *     unchanged when the source row count is known to be at most the node's count; otherwise
     *     the source statistics with the row count set to the node's count and, when that count is
     *     non-zero, the nulls fraction of the first ORDER BY symbol re-estimated
     */
    @Override
    public Optional<PlanNodeStatsEstimate> doCalculate(TopNNode topNNode, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider typeProvider, TableStatsProvider tableStatsProvider) {
        // Source stats are requested before the step check, matching the original call order.
        PlanNodeStatsEstimate sourceStats = statsProvider.getStats(topNNode.getSource());
        double sourceRowCount = sourceStats.getOutputRowCount();

        boolean isSingleStep = topNNode.getStep() == TopNNode.Step.SINGLE;
        if (!isSingleStep) {
            return Optional.empty();
        }

        long limit = topNNode.getCount();
        // A NaN (unknown) source row count fails this comparison, so it falls through and the
        // output is estimated as exactly `limit` rows.
        if (sourceRowCount <= limit) {
            return Optional.of(sourceStats);
        }

        PlanNodeStatsEstimate limitedStats = PlanNodeStatsEstimate.buildFrom(sourceStats)
                .setOutputRowCount(limit)
                .build();
        if (limit == 0) {
            // Skip the nulls-fraction refinement, which divides by the limit.
            return Optional.of(limitedStats);
        }

        // NOTE(review): assumes the ordering scheme always has at least one ORDER BY symbol.
        Symbol leadingSortSymbol = topNNode.getOrderingScheme().getOrderBy().get(0);
        SortOrder leadingSortOrder = topNNode.getOrderingScheme().getOrdering(leadingSortSymbol);

        PlanNodeStatsEstimate refinedStats = limitedStats.mapSymbolColumnStatistics(leadingSortSymbol, symbolStats -> {
            SymbolStatsEstimate.Builder adjusted = SymbolStatsEstimate.buildFrom(symbolStats);
            adjusted.setNullsFraction(nullsFractionWithinLimit(
                    leadingSortOrder.isNullsFirst(),
                    sourceRowCount,
                    symbolStats.getNullsFraction(),
                    limit));
            return adjusted.build();
        });
        return Optional.of(refinedStats);
    }

    /**
     * Estimates the nulls fraction of the sort symbol among the first {@code limit} sorted rows.
     *
     * <p>With NULLS FIRST, nulls come before every value. The kept rows are therefore all null
     * when there are more nulls than {@code limit}; otherwise every null is kept. With NULLS LAST,
     * nulls only occupy the slots left after all non-null values. So none are kept when the
     * non-null values alone exceed {@code limit}. A NaN input yields a NaN result, because both
     * comparisons are false for NaN.
     *
     * @param nullsFirst whether nulls sort before non-null values
     * @param sourceRowCount row count of the TopN source
     * @param sourceNullsFraction nulls fraction of the sort symbol in the source
     * @param limit number of rows kept; must be non-zero since it is used as a divisor
     * @return the estimated nulls fraction within the kept rows
     */
    private static double nullsFractionWithinLimit(boolean nullsFirst, double sourceRowCount, double sourceNullsFraction, long limit) {
        double sourceNullCount = sourceRowCount * sourceNullsFraction;
        if (nullsFirst) {
            return sourceNullCount > limit ? 1.0d : sourceNullCount / limit;
        }
        double sourceNonNullCount = sourceRowCount - sourceNullCount;
        return sourceNonNullCount > limit ? 0.0d : (limit - sourceNonNullCount) / limit;
    }
}
