package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.ScalarStatsCalculator;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.Assignment;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ProjectionApplicationResult;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.predicate.TupleDomain;
import io.trino.sql.planner.ConnectorExpressionTranslator;
import io.trino.sql.planner.LiteralEncoder;
import io.trino.sql.planner.PartialTranslator;
import io.trino.sql.planner.ReferenceAwareExpressionNodeInliner;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.NodeRef;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.class */
/**
 * Optimizer rule that offers the connector-translatable parts of a projection sitting directly
 * above a table scan to the connector via {@link Metadata#applyProjection}.
 *
 * <p>When the connector accepts, the table scan is replaced by one over the handle the connector
 * returned, with freshly allocated output symbols. The original projection is kept on top, but
 * each pushed-down sub-expression is replaced by the expression the connector returned for it.
 *
 * <p>The rule only runs when connector pushdown is allowed for the session.
 */
public class PushProjectionIntoTableScan implements Rule<ProjectNode> {
    private static final Capture<TableScanNode> TABLE_SCAN = Capture.newCapture();
    private static final Pattern<ProjectNode> PATTERN = Patterns.project()
            .with(Patterns.source().matching(Patterns.tableScan().capturedAs(TABLE_SCAN)));

    private final Metadata metadata;
    private final TypeAnalyzer typeAnalyzer;
    private final ScalarStatsCalculator scalarStatsCalculator;

    /**
     * @param metadata used to call {@code applyProjection} and to read table properties
     * @param typeAnalyzer used when extracting translatable sub-expressions
     * @param scalarStatsCalculator used to estimate statistics of pushed-down projections
     * @throws NullPointerException if any argument is {@code null}
     */
    public PushProjectionIntoTableScan(Metadata metadata, TypeAnalyzer typeAnalyzer, ScalarStatsCalculator scalarStatsCalculator) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
        this.scalarStatsCalculator = Objects.requireNonNull(scalarStatsCalculator, "scalarStatsCalculator is null");
    }

    @Override
    public Pattern<ProjectNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isAllowPushdownIntoConnectors(session);
    }

    /**
     * Attempts to push the translatable sub-expressions of {@code projectNode} into the captured table scan.
     *
     * @return {@link Rule.Result#empty()} if the connector declines the projection, otherwise a new
     *         project node over a rewritten table scan
     * @throws IllegalStateException if the connector returns a different number of projections than it was given
     */
    @Override
    public Rule.Result apply(ProjectNode projectNode, Captures captures, Rule.Context context) {
        TableScanNode tableScan = captures.get(TABLE_SCAN);
        Session session = context.getSession();

        // Collect every sub-expression of the projection that has a ConnectorExpression equivalent.
        // The same expression node may be reported more than once; the first translation wins.
        Map<NodeRef<Expression>, ConnectorExpression> partialTranslations = projectNode.getAssignments().getMap().values().stream()
                .flatMap(expression -> PartialTranslator.extractPartialTranslations(
                        expression,
                        session,
                        typeAnalyzer,
                        context.getSymbolAllocator().getTypes()).entrySet().stream())
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue, (first, ignored) -> first));

        // Keys and values are snapshotted in the same iteration order, so index i pairs them up.
        List<NodeRef<Expression>> partialProjectionNodes = ImmutableList.copyOf(partialTranslations.keySet());
        List<ConnectorExpression> partialProjections = ImmutableList.copyOf(partialTranslations.values());

        // The connector refers to scan columns by variable name, which is the scan symbol's name.
        Map<String, Symbol> inputVariableMappings = tableScan.getAssignments().keySet().stream()
                .collect(ImmutableMap.toImmutableMap(Symbol::getName, Function.identity()));
        Map<String, ColumnHandle> inputAssignments = inputVariableMappings.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> tableScan.getAssignments().get(entry.getValue())));

        Optional<ProjectionApplicationResult<TableHandle>> result =
                metadata.applyProjection(session, tableScan.getTable(), partialProjections, inputAssignments);
        if (result.isEmpty()) {
            return Rule.Result.empty();
        }
        ProjectionApplicationResult<TableHandle> applied = result.get();

        List<ConnectorExpression> newPartialProjections = applied.getProjections();
        Preconditions.checkState(
                newPartialProjections.size() == partialProjections.size(),
                "Mismatch between input and output projections from the connector: expected %s but got %s",
                partialProjections.size(),
                newPartialProjections.size());

        // Allocate a new symbol for every column the connector's new handle exposes.
        List<Symbol> newScanOutputs = new ArrayList<>();
        Map<Symbol, ColumnHandle> newScanAssignments = new HashMap<>();
        Map<String, Symbol> outputVariableMappings = new HashMap<>();
        for (Assignment assignment : applied.getAssignments()) {
            Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());
            newScanOutputs.add(symbol);
            newScanAssignments.put(symbol, assignment.getColumn());
            outputVariableMappings.put(assignment.getVariable(), symbol);
        }

        // One encoder is shared by every translation below instead of creating one per expression.
        LiteralEncoder literalEncoder = new LiteralEncoder(metadata);

        // Translate each returned connector projection back into an Expression over the new scan symbols,
        // keyed by the original sub-expression node it replaces.
        ImmutableMap.Builder<NodeRef<Expression>, Expression> rewrittenNodesBuilder = ImmutableMap.builder();
        for (int i = 0; i < partialProjectionNodes.size(); i++) {
            Expression rewritten = ConnectorExpressionTranslator.translate(newPartialProjections.get(i), outputVariableMappings, literalEncoder);
            rewrittenNodesBuilder.put(partialProjectionNodes.get(i), rewritten);
        }
        Map<NodeRef<Expression>, Expression> rewrittenNodes = rewrittenNodesBuilder.build();

        // Stitch the rewritten sub-expressions into the original projection assignments.
        Assignments.Builder newProjectAssignments = Assignments.builder();
        for (Map.Entry<Symbol, Expression> entry : projectNode.getAssignments().entrySet()) {
            newProjectAssignments.put(entry.getKey(), ReferenceAwareExpressionNodeInliner.replaceExpression(entry.getValue(), rewrittenNodes));
        }

        Optional<PlanNodeStatsEstimate> newStatistics = tableScan.getStatistics()
                .map(statistics -> estimateStatistics(
                        statistics,
                        partialProjections,
                        newPartialProjections,
                        inputVariableMappings,
                        outputVariableMappings,
                        literalEncoder,
                        context));

        verifyTablePartitioning(context, tableScan, applied.getHandle());

        // The enforced constraint is reset to all() rather than carried over from the original scan.
        // NOTE(review): presumably because the new handle's column handles may differ — confirm.
        TableScanNode newTableScan = new TableScanNode(
                tableScan.getId(),
                applied.getHandle(),
                newScanOutputs,
                newScanAssignments,
                TupleDomain.all(),
                newStatistics,
                tableScan.isUpdateTarget(),
                tableScan.getUseConnectorNodePartitioning());
        return Rule.Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), newTableScan, newProjectAssignments.build()));
    }

    /**
     * Builds statistics for the rewritten table scan from the original scan's statistics.
     *
     * <p>The output row count is carried over unchanged. A projection gets column statistics only
     * when the connector reduced it to a plain {@link Variable}. Those statistics are computed by
     * evaluating the original (input) projection against the original scan's statistics.
     */
    private PlanNodeStatsEstimate estimateStatistics(
            PlanNodeStatsEstimate statistics,
            List<ConnectorExpression> inputProjections,
            List<ConnectorExpression> outputProjections,
            Map<String, Symbol> inputVariableMappings,
            Map<String, Symbol> outputVariableMappings,
            LiteralEncoder literalEncoder,
            Rule.Context context) {
        PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder();
        builder.setOutputRowCount(statistics.getOutputRowCount());
        for (int i = 0; i < inputProjections.size(); i++) {
            ConnectorExpression output = outputProjections.get(i);
            if (!(output instanceof Variable)) {
                continue;
            }
            Symbol outputSymbol = outputVariableMappings.get(((Variable) output).getName());
            Expression inputExpression = ConnectorExpressionTranslator.translate(inputProjections.get(i), inputVariableMappings, literalEncoder);
            builder.addSymbolStatistics(
                    outputSymbol,
                    scalarStatsCalculator.calculate(inputExpression, statistics, context.getSession(), context.getSymbolAllocator().getTypes()));
        }
        return builder.build();
    }

    /**
     * When the scan opted into connector node partitioning, checks that the connector's new handle
     * reports the same table partitioning as the original handle.
     *
     * @throws com.google.common.base.VerifyException if the partitioning changed
     */
    private void verifyTablePartitioning(Rule.Context context, TableScanNode tableScan, TableHandle newTable) {
        if (tableScan.getUseConnectorNodePartitioning().isEmpty()) {
            return;
        }
        Session session = context.getSession();
        Verify.verify(
                metadata.getTableProperties(session, newTable).getTablePartitioning()
                        .equals(metadata.getTableProperties(session, tableScan.getTable()).getTablePartitioning()),
                "Partitioning must not change after projection is pushed down");
    }
}
