package io.trino.sql.planner.optimizations;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.cost.StatsAndCosts;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.Metadata;
import io.trino.metadata.TableExecuteHandle;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.BeginTableExecuteResult;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.DeleteNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.MarkDistinctNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.plan.StatisticsWriterNode;
import io.trino.sql.planner.plan.TableExecuteNode;
import io.trino.sql.planner.plan.TableFinishNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.TableWriterNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.UpdateNode;
import io.trino.sql.planner.planprinter.PlanPrinter;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/* loaded from: input_file:io/trino/sql/planner/optimizations/BeginTableWrite.class */
/**
 * Plan optimizer that begins table writes through {@link Metadata}.
 *
 * <p>For each {@link TableFinishNode}, a writer target is derived from the node's source and handed to the
 * matching {@code Metadata} call ({@code beginCreateTable}, {@code beginInsert}, {@code beginDelete},
 * {@code beginUpdate}, {@code beginRefreshMaterializedView} or {@code beginTableExecute}). The target built
 * from that call's result is passed down through the rewrite context and installed on the
 * {@link TableWriterNode}, {@link DeleteNode}, {@link UpdateNode} or {@link TableExecuteNode} beneath the
 * finish node. For delete, update and table-execute nodes, the single update-target {@link TableScanNode} in
 * their source is additionally rebuilt with the handle carried by the new target.
 *
 * <p>{@link StatisticsWriterNode}s are rebuilt with a handle wrapping the result of
 * {@code Metadata.beginStatisticsCollection}.
 */
public class BeginTableWrite implements PlanOptimizer {
    private final Metadata metadata;

    public BeginTableWrite(Metadata metadata) {
        this.metadata = metadata;
    }

    /**
     * Rewrites {@code planNode}, beginning every table write it contains.
     *
     * <p>If the rewrite fails with a {@link RuntimeException}, a textual rendering of the input plan is
     * attached to that exception as a suppressed exception (when rendering itself succeeds), and the
     * original exception is rethrown unchanged.
     */
    @Override
    public PlanNode optimize(PlanNode planNode, Session session, TypeProvider typeProvider, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator, WarningCollector warningCollector) {
        try {
            return SimplePlanRewriter.rewriteWith(new Rewriter(session), planNode, Optional.empty());
        } catch (RuntimeException e) {
            try {
                String planText = PlanPrinter.textLogicalPlan(planNode, typeProvider, metadata, StatsAndCosts.empty(), session, 4, false);
                e.addSuppressed(new Exception("Current plan:\n" + planText));
            } catch (RuntimeException ignored) {
                // Rendering the plan is diagnostic only; it must never replace or mask the original failure.
            }
            throw e;
        }
    }

    /**
     * Returns the writer target carried by the rewrite context.
     *
     * @throws IllegalStateException if the context holds no target
     */
    private static TableWriterNode.WriterTarget getContextTarget(SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
        return context.get().orElseThrow(() -> new IllegalStateException("WriterTarget not present"));
    }

    /**
     * Rewriter whose context holds the writer target created for the nearest enclosing
     * {@link TableFinishNode}; the context starts out as {@link Optional#empty()}.
     */
    private class Rewriter extends SimplePlanRewriter<Optional<TableWriterNode.WriterTarget>> {
        private final Session session;

        public Rewriter(Session session) {
            this.session = session;
        }

        /** Rebuilds the writer node so it uses the target supplied through the context. */
        @Override
        public PlanNode visitTableWriter(TableWriterNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            PlanNode source = context.rewrite(node.getSource(), context.get());
            return new TableWriterNode(
                    node.getId(),
                    source,
                    BeginTableWrite.getContextTarget(context),
                    node.getRowCountSymbol(),
                    node.getFragmentSymbol(),
                    node.getColumns(),
                    node.getColumnNames(),
                    node.getNotNullColumnSymbols(),
                    node.getPartitioningScheme(),
                    node.getPreferredPartitioningScheme(),
                    node.getStatisticsAggregation(),
                    node.getStatisticsAggregationDescriptor());
        }

        /** Rebuilds the delete node and points its update-target table scan at the begun delete handle. */
        @Override
        public PlanNode visitDelete(DeleteNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            TableWriterNode.DeleteTarget target = (TableWriterNode.DeleteTarget) BeginTableWrite.getContextTarget(context);
            PlanNode source = rewriteModifyTableScan(node.getSource(), target.getHandleOrElseThrow());
            return new DeleteNode(node.getId(), source, target, node.getRowId(), node.getOutputSymbols());
        }

        /** Rebuilds the update node and points its update-target table scan at the begun update handle. */
        @Override
        public PlanNode visitUpdate(UpdateNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            TableWriterNode.UpdateTarget target = (TableWriterNode.UpdateTarget) BeginTableWrite.getContextTarget(context);
            PlanNode source = rewriteModifyTableScan(node.getSource(), target.getHandleOrElseThrow());
            return new UpdateNode(
                    node.getId(),
                    source,
                    target,
                    node.getRowId(),
                    node.getColumnValueAndRowIdSymbols(),
                    node.getOutputSymbols());
        }

        /** Rebuilds the table-execute node and points its update-target table scan at the begun source handle. */
        @Override
        public PlanNode visitTableExecute(TableExecuteNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            TableWriterNode.TableExecuteTarget target = (TableWriterNode.TableExecuteTarget) BeginTableWrite.getContextTarget(context);
            PlanNode source = rewriteModifyTableScan(node.getSource(), target.getSourceHandle().orElseThrow());
            return new TableExecuteNode(
                    node.getId(),
                    source,
                    target,
                    node.getRowCountSymbol(),
                    node.getFragmentSymbol(),
                    node.getColumns(),
                    node.getColumnNames(),
                    node.getPartitioningScheme(),
                    node.getPreferredPartitioningScheme());
        }

        /** Rebuilds the statistics writer with a handle wrapping {@code Metadata.beginStatisticsCollection}. */
        @Override
        public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            PlanNode source = context.rewrite(node.getSource(), context.get());
            StatisticsWriterNode.WriteStatisticsReference reference = (StatisticsWriterNode.WriteStatisticsReference) node.getTarget();
            return new StatisticsWriterNode(
                    node.getId(),
                    source,
                    new StatisticsWriterNode.WriteStatisticsHandle(metadata.beginStatisticsCollection(session, reference.getHandle())),
                    node.getRowCountSymbol(),
                    node.isRowCountEnabled(),
                    node.getDescriptor());
        }

        /**
         * Begins the write below this finish node and passes the resulting target down to the writer
         * subtree through the rewrite context.
         */
        @Override
        public PlanNode visitTableFinish(TableFinishNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            PlanNode source = node.getSource();
            TableWriterNode.WriterTarget target = createWriterTarget(getWriterTarget(source));
            return new TableFinishNode(
                    node.getId(),
                    context.rewrite(source, Optional.of(target)),
                    target,
                    node.getRowCountSymbol(),
                    node.getStatisticsAggregation(),
                    node.getStatisticsAggregationDescriptor());
        }

        /**
         * Derives, from the subtree below a {@link TableFinishNode}, the writer target that
         * {@link #createWriterTarget} will begin.
         *
         * <p>A table writer contributes its own target. Delete and update nodes contribute a copy of their
         * target whose handle is that of the update-target table scan found in their source. A table-execute
         * node contributes a copy of its target whose source handle is the single update-target scan in its
         * source. Exchanges and unions must resolve to exactly one distinct target across all of their
         * sources; otherwise {@code Iterables.getOnlyElement} fails.
         *
         * @throws IllegalArgumentException for any other node type
         */
        public TableWriterNode.WriterTarget getWriterTarget(PlanNode node) {
            if (node instanceof TableWriterNode) {
                return ((TableWriterNode) node).getTarget();
            }
            if (node instanceof DeleteNode) {
                DeleteNode delete = (DeleteNode) node;
                TableHandle handle = findTableScanHandleForDeleteOrUpdate(delete.getSource());
                return new TableWriterNode.DeleteTarget(Optional.of(handle), delete.getTarget().getSchemaTableName());
            }
            if (node instanceof UpdateNode) {
                UpdateNode update = (UpdateNode) node;
                TableWriterNode.UpdateTarget target = update.getTarget();
                TableHandle handle = findTableScanHandleForDeleteOrUpdate(update.getSource());
                return new TableWriterNode.UpdateTarget(
                        Optional.of(handle),
                        target.getSchemaTableName(),
                        target.getUpdatedColumns(),
                        target.getUpdatedColumnHandles());
            }
            if (node instanceof TableExecuteNode) {
                TableExecuteNode execute = (TableExecuteNode) node;
                TableWriterNode.TableExecuteTarget target = execute.getTarget();
                return new TableWriterNode.TableExecuteTarget(
                        target.getExecuteHandle(),
                        findTableScanHandleForTableExecute(execute.getSource()),
                        target.getSchemaTableName());
            }
            if (node instanceof ExchangeNode || node instanceof UnionNode) {
                // Every branch must agree on a single target; duplicates collapse in the set.
                Set<TableWriterNode.WriterTarget> targets = node.getSources().stream()
                        .map(this::getWriterTarget)
                        .collect(Collectors.toSet());
                return Iterables.getOnlyElement(targets);
            }
            throw new IllegalArgumentException("Invalid child for TableFinishNode: " + node.getClass().getSimpleName());
        }

        /**
         * Calls the {@link Metadata} method that begins the write described by {@code target} and returns
         * the target to install on the plan.
         *
         * @throws IllegalArgumentException for target types not handled here
         */
        private TableWriterNode.WriterTarget createWriterTarget(TableWriterNode.WriterTarget target) {
            if (target instanceof TableWriterNode.CreateReference) {
                TableWriterNode.CreateReference create = (TableWriterNode.CreateReference) target;
                return new TableWriterNode.CreateTarget(
                        metadata.beginCreateTable(session, create.getCatalog(), create.getTableMetadata(), create.getLayout()),
                        create.getTableMetadata().getTable());
            }
            if (target instanceof TableWriterNode.InsertReference) {
                TableWriterNode.InsertReference insert = (TableWriterNode.InsertReference) target;
                return new TableWriterNode.InsertTarget(
                        metadata.beginInsert(session, insert.getHandle(), insert.getColumns()),
                        metadata.getTableMetadata(session, insert.getHandle()).getTable());
            }
            if (target instanceof TableWriterNode.DeleteTarget) {
                TableWriterNode.DeleteTarget delete = (TableWriterNode.DeleteTarget) target;
                return new TableWriterNode.DeleteTarget(
                        Optional.of(metadata.beginDelete(session, delete.getHandleOrElseThrow())),
                        delete.getSchemaTableName());
            }
            if (target instanceof TableWriterNode.UpdateTarget) {
                TableWriterNode.UpdateTarget update = (TableWriterNode.UpdateTarget) target;
                return new TableWriterNode.UpdateTarget(
                        Optional.of(metadata.beginUpdate(session, update.getHandleOrElseThrow(), update.getUpdatedColumnHandles())),
                        update.getSchemaTableName(),
                        update.getUpdatedColumns(),
                        update.getUpdatedColumnHandles());
            }
            if (target instanceof TableWriterNode.RefreshMaterializedViewReference) {
                TableWriterNode.RefreshMaterializedViewReference refresh = (TableWriterNode.RefreshMaterializedViewReference) target;
                return new TableWriterNode.RefreshMaterializedViewTarget(
                        refresh.getStorageTableHandle(),
                        metadata.beginRefreshMaterializedView(session, refresh.getStorageTableHandle(), refresh.getSourceTableHandles()),
                        metadata.getTableMetadata(session, refresh.getStorageTableHandle()).getTable(),
                        refresh.getSourceTableHandles());
            }
            if (target instanceof TableWriterNode.TableExecuteTarget) {
                TableWriterNode.TableExecuteTarget execute = (TableWriterNode.TableExecuteTarget) target;
                BeginTableExecuteResult<TableExecuteHandle, TableHandle> result =
                        metadata.beginTableExecute(session, execute.getExecuteHandle(), execute.getMandatorySourceHandle());
                return new TableWriterNode.TableExecuteTarget(
                        result.getTableExecuteHandle(),
                        Optional.of(result.getSourceHandle()),
                        execute.getSchemaTableName());
            }
            throw new IllegalArgumentException("Unhandled target type: " + target.getClass().getSimpleName());
        }

        /**
         * Walks down the supported single-source path (filters, projections, semi-join sources, join left
         * sides, unique-id and mark-distinct nodes) to the table scan feeding a delete or update, and returns
         * that scan's handle.
         *
         * @throws IllegalArgumentException if the scan is not an update target, or an unsupported node is met
         */
        private TableHandle findTableScanHandleForDeleteOrUpdate(PlanNode node) {
            if (node instanceof TableScanNode) {
                TableScanNode scan = (TableScanNode) node;
                Preconditions.checkArgument(scan.isUpdateTarget(), "TableScanNode should be an updatable target");
                return scan.getTable();
            }
            if (node instanceof FilterNode) {
                return findTableScanHandleForDeleteOrUpdate(((FilterNode) node).getSource());
            }
            if (node instanceof ProjectNode) {
                return findTableScanHandleForDeleteOrUpdate(((ProjectNode) node).getSource());
            }
            if (node instanceof SemiJoinNode) {
                return findTableScanHandleForDeleteOrUpdate(((SemiJoinNode) node).getSource());
            }
            if (node instanceof JoinNode) {
                return findTableScanHandleForDeleteOrUpdate(((JoinNode) node).getLeft());
            }
            if (node instanceof AssignUniqueId) {
                return findTableScanHandleForDeleteOrUpdate(((AssignUniqueId) node).getSource());
            }
            if (node instanceof MarkDistinctNode) {
                return findTableScanHandleForDeleteOrUpdate(((MarkDistinctNode) node).getSource());
            }
            throw new IllegalArgumentException("Invalid descendant for DeleteNode or UpdateNode: " + node.getClass().getName());
        }

        /**
         * Returns the handle of the single update-target table scan anywhere under {@code node}.
         *
         * @throws IllegalArgumentException unless exactly one such scan exists
         */
        private Optional<TableHandle> findTableScanHandleForTableExecute(PlanNode node) {
            List<PlanNode> updateTargets = PlanNodeSearcher.searchFrom(node)
                    .where(candidate -> candidate instanceof TableScanNode && ((TableScanNode) candidate).isUpdateTarget())
                    .findAll();
            if (updateTargets.size() == 1) {
                return Optional.of(((TableScanNode) updateTargets.get(0)).getTable());
            }
            throw new IllegalArgumentException("Expected to find exactly one update target TableScanNode in plan but found: " + updateTargets);
        }

        /**
         * Returns a copy of {@code node} in which the single update-target table scan uses {@code handle};
         * other table scans are left untouched.
         *
         * <p>Verification fails unless exactly one update-target scan was rewritten.
         */
        private PlanNode rewriteModifyTableScan(PlanNode node, TableHandle handle) {
            AtomicInteger updateTargetCount = new AtomicInteger();
            PlanNode rewritten = SimplePlanRewriter.rewriteWith(new SimplePlanRewriter<Void>() {
                @Override
                public PlanNode visitTableScan(TableScanNode scan, SimplePlanRewriter.RewriteContext<Void> context) {
                    if (!scan.isUpdateTarget()) {
                        return scan;
                    }
                    updateTargetCount.incrementAndGet();
                    return new TableScanNode(
                            scan.getId(),
                            handle,
                            scan.getOutputSymbols(),
                            scan.getAssignments(),
                            scan.getEnforcedConstraint(),
                            scan.getStatistics(),
                            scan.isUpdateTarget(),
                            scan.getUseConnectorNodePartitioning());
                }
            }, node, null);
            Verify.verify(updateTargetCount.get() == 1, "Expected to find exactly one update target TableScanNode but found %s", updateTargetCount.get());
            return rewritten;
        }
    }
}
