package io.trino.execution.scheduler;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Sets;
import com.google.common.collect.UnmodifiableIterator;
import com.google.common.util.concurrent.AbstractFuture;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.concurrent.MoreFutures;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.connector.CatalogHandle;
import io.trino.execution.ForQueryExecution;
import io.trino.execution.QueryManagerConfig;
import io.trino.execution.TableExecuteContextManager;
import io.trino.metadata.InternalNode;
import io.trino.metadata.InternalNodeManager;
import io.trino.metadata.Split;
import io.trino.spi.HostAddress;
import io.trino.spi.QueryId;
import io.trino.spi.SplitWeight;
import io.trino.spi.exchange.ExchangeSourceHandle;
import io.trino.split.SplitSource;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.SplitSourceFactory;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.sql.planner.plan.TableWriterNode;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.function.LongConsumer;
import java.util.function.Supplier;
import javax.annotation.concurrent.GuardedBy;
import javax.inject.Inject;

/* loaded from: input_file:io/trino/execution/scheduler/StageTaskSourceFactory.class */
public class StageTaskSourceFactory implements TaskSourceFactory {
    // Class-wide logger; the nested HashDistributionTaskSource also uses it to report split-source close failures.
    private static final Logger log = Logger.get(StageTaskSourceFactory.class);
    // Collaborators and settings held by the factory. NOTE(review): the constructor and the code that
    // uses these fields are outside this excerpt — presumably they are handed to the nested task
    // sources' create(...) methods; confirm against the rest of the file.
    private final SplitSourceFactory splitSourceFactory;
    private final TableExecuteContextManager tableExecuteContextManager;
    private final int splitBatchSize;
    private final Executor executor;
    private final InternalNodeManager nodeManager;

    /**
     * Task source for a fragment that reads only from exchanges (no table scans).
     *
     * <p>All tasks are produced by the first call to {@link #getMoreTasks()}: partitioned exchange
     * source handles are packed, in iteration order, into tasks whose accumulated data size stays
     * within the target partition size, and every task additionally receives all replicated
     * exchange source handles.
     *
     * <p>The {@code finished} flag is accessed without synchronization.
     */
    public static class ArbitraryDistributionTaskSource implements TaskSource {
        private final Multimap<PlanNodeId, ExchangeSourceHandle> partitionedExchangeSourceHandles;
        private final Multimap<PlanNodeId, ExchangeSourceHandle> replicatedExchangeSourceHandles;
        private final long targetPartitionSizeInBytes;
        private final DataSize taskMemory;
        private boolean finished;

        /**
         * Builds a task source for the given fragment.
         *
         * @param session session used to look up the default task memory
         * @param planFragment fragment to schedule; must not have partitioned sources
         * @param multimap exchange source handles keyed by the producing fragment
         * @param dataSize target amount of partitioned exchange data per task
         * @return a new task source
         * @throws IllegalArgumentException if the fragment has partitioned sources
         */
        public static ArbitraryDistributionTaskSource create(Session session, PlanFragment planFragment, Multimap<PlanFragmentId, ExchangeSourceHandle> multimap, DataSize dataSize) {
            Preconditions.checkArgument(planFragment.getPartitionedSources().isEmpty(), "no partitioned sources (table scans) expected, got: %s", planFragment.getPartitionedSources());
            return new ArbitraryDistributionTaskSource(
                    StageTaskSourceFactory.getPartitionedExchangeSourceHandles(planFragment, multimap),
                    StageTaskSourceFactory.getReplicatedExchangeSourceHandles(planFragment, multimap),
                    dataSize,
                    SystemSessionProperties.getFaultTolerantExecutionDefaultTaskMemory(session));
        }

        /**
         * @param multimap partitioned exchange source handles keyed by plan node
         * @param multimap2 replicated exchange source handles keyed by plan node
         * @param dataSize target amount of partitioned exchange data per task
         * @param dataSize2 memory requirement attached to every produced task
         * @throws NullPointerException if any argument is null
         */
        @VisibleForTesting
        ArbitraryDistributionTaskSource(Multimap<PlanNodeId, ExchangeSourceHandle> multimap, Multimap<PlanNodeId, ExchangeSourceHandle> multimap2, DataSize dataSize, DataSize dataSize2) {
            this.partitionedExchangeSourceHandles = ImmutableListMultimap.copyOf(Objects.requireNonNull(multimap, "partitionedExchangeSourceHandles is null"));
            this.replicatedExchangeSourceHandles = ImmutableListMultimap.copyOf(Objects.requireNonNull(multimap2, "replicatedExchangeSourceHandles is null"));
            this.taskMemory = Objects.requireNonNull(dataSize2, "taskMemory is null");
            // Fail with a descriptive message instead of a bare NullPointerException from toBytes().
            this.targetPartitionSizeInBytes = Objects.requireNonNull(dataSize, "targetPartitionSize is null").toBytes();
        }

        /**
         * Returns every task on the first call and an empty list afterwards.
         *
         * <p>A task is closed as soon as adding the next handle would push its accumulated data
         * size over the target; a handle that alone exceeds the target still gets a task of its own
         * because an empty accumulator is never flushed.
         */
        @Override
        public ListenableFuture<List<TaskDescriptor>> getMoreTasks() {
            if (finished) {
                return Futures.immediateFuture(ImmutableList.of());
            }
            NodeRequirements nodeRequirements = new NodeRequirements(Optional.empty(), ImmutableSet.of(), taskMemory);

            ImmutableList.Builder<TaskDescriptor> tasks = ImmutableList.builder();
            int nextPartitionId = 0;

            ImmutableListMultimap.Builder<PlanNodeId, ExchangeSourceHandle> assignedHandles = ImmutableListMultimap.builder();
            long assignedDataSize = 0;
            // Tracked separately from the size so that zero-sized handles are still emitted.
            boolean hasAssignedHandles = false;

            for (Map.Entry<PlanNodeId, ExchangeSourceHandle> entry : partitionedExchangeSourceHandles.entries()) {
                ExchangeSourceHandle handle = entry.getValue();
                long handleDataSize = handle.getDataSizeInBytes();

                if (assignedDataSize != 0 && assignedDataSize + handleDataSize > targetPartitionSizeInBytes) {
                    assignedHandles.putAll(replicatedExchangeSourceHandles);
                    tasks.add(new TaskDescriptor(nextPartitionId++, ImmutableListMultimap.of(), assignedHandles.build(), nodeRequirements));
                    assignedHandles = ImmutableListMultimap.builder();
                    assignedDataSize = 0;
                    hasAssignedHandles = false;
                }

                assignedHandles.put(entry.getKey(), handle);
                assignedDataSize += handleDataSize;
                hasAssignedHandles = true;
            }

            // Emit whatever is left in the accumulator.
            if (hasAssignedHandles) {
                assignedHandles.putAll(replicatedExchangeSourceHandles);
                tasks.add(new TaskDescriptor(nextPartitionId, ImmutableListMultimap.of(), assignedHandles.build(), nodeRequirements));
            }

            finished = true;
            return Futures.immediateFuture(tasks.build());
        }

        @Override
        public boolean isFinished() {
            return finished;
        }

        /** Nothing to release: this source holds no closeable resources. */
        @Override
        public void close() {
        }
    }

    /* loaded from: input_file:io/trino/execution/scheduler/StageTaskSourceFactory$HashDistributionTaskSource.class */
    public static class HashDistributionTaskSource implements TaskSource {
        // Split sources keyed by plan node; loaded in getMoreTasks() and closed in close().
        private final Map<PlanNodeId, SplitSource> splitSources;
        // Exchange handles that getMoreTasks() groups by the handle's partition id.
        private final Multimap<PlanNodeId, ExchangeSourceHandle> partitionedExchangeSourceHandles;
        // Exchange handles that postprocessTasks() attaches to every task it emits.
        private final Multimap<PlanNodeId, ExchangeSourceHandle> replicatedExchangeSourceHandles;
        // Passed through to each SplitLoadingFuture together with getSplitTimeRecorder.
        private final int splitBatchSize;
        private final LongConsumer getSplitTimeRecorder;
        // Lookup table indexed by bucket number that yields the partition id (see getPartitionForBucket).
        private final int[] bucketToPartitionMap;
        // Supplies the bucket and the assigned node of each split; must be present when splitSources is non-empty.
        private final Optional<BucketNodeMap> bucketNodeMap;
        // Memory and catalog components of the NodeRequirements built for each task.
        private final DataSize taskMemory;
        private final Optional<CatalogHandle> catalogRequirement;
        // Thresholds used by postprocessTasks() when merging tasks with equal node requirements.
        private final long targetPartitionSourceSizeInBytes;
        private final long targetPartitionSplitWeight;
        // Runs the task-building transformation and is handed to each SplitLoadingFuture.
        private final Executor executor;

        // Set by the first getMoreTasks() call that starts split loading; never reset afterwards.
        @GuardedBy("this")
        private ListenableFuture<List<LoadedSplits>> loadedSplitsFuture;

        // Becomes true once the task list has been built from the loaded splits.
        @GuardedBy("this")
        private boolean finished;

        // Becomes true on the first close() call.
        @GuardedBy("this")
        private boolean closed;

        /**
         * Builds a task source for the given fragment.
         *
         * @param session session used to create split sources and to look up the default task memory
         * @param planFragment fragment to schedule
         * @param splitSourceFactory creates the split sources of the fragment
         * @param multimap exchange source handles keyed by the producing fragment
         * @param i split batch size
         * @param longConsumer recorder handed to the split loading futures
         * @param iArr bucket-to-partition lookup table
         * @param optional bucket node map; required when the fragment has partitioned sources
         * @param j target split weight per task
         * @param dataSize target exchange data size per task
         * @param z when true and the fragment contains a table writer, the target exchange data
         *          size is replaced by zero bytes
         * @param executor executor used for split loading and task construction
         * @return a new task source
         * @throws IllegalArgumentException if the fragment has partitioned sources but no bucket node map
         */
        public static HashDistributionTaskSource create(Session session, PlanFragment planFragment, SplitSourceFactory splitSourceFactory, Multimap<PlanFragmentId, ExchangeSourceHandle> multimap, int i, LongConsumer longConsumer, int[] iArr, Optional<BucketNodeMap> optional, long j, DataSize dataSize, boolean z, Executor executor) {
            Preconditions.checkArgument(optional.isPresent() || planFragment.getPartitionedSources().isEmpty(), "bucketNodeMap is expected to be set when the fragment reads partitioned sources (tables)");

            // Locals are computed in the same order in which the constructor arguments were evaluated.
            Map<PlanNodeId, SplitSource> splitSources = splitSourceFactory.createSplitSources(session, planFragment);
            Multimap<PlanNodeId, ExchangeSourceHandle> partitionedHandles = StageTaskSourceFactory.getPartitionedExchangeSourceHandles(planFragment, multimap);
            Multimap<PlanNodeId, ExchangeSourceHandle> replicatedHandles = StageTaskSourceFactory.getReplicatedExchangeSourceHandles(planFragment, multimap);
            Optional<CatalogHandle> catalogRequirement = planFragment.getPartitioning().getCatalogHandle();

            DataSize targetPartitionSourceSize;
            if (z && isWriteFragment(planFragment)) {
                targetPartitionSourceSize = DataSize.of(0L, DataSize.Unit.BYTE);
            }
            else {
                targetPartitionSourceSize = dataSize;
            }
            DataSize taskMemory = SystemSessionProperties.getFaultTolerantExecutionDefaultTaskMemory(session);

            return new HashDistributionTaskSource(
                    splitSources,
                    partitionedHandles,
                    replicatedHandles,
                    i,
                    longConsumer,
                    iArr,
                    optional,
                    catalogRequirement,
                    j,
                    targetPartitionSourceSize,
                    taskMemory,
                    executor);
        }

        /**
         * Tells whether the fragment's plan contains a {@link TableWriterNode}.
         *
         * <p>The plan is walked depth-first from the root through {@code getSources()}, stopping at
         * the first table writer found. This assumes {@link PlanVisitor} routes node types without a
         * dedicated override to {@code visitPlan} — TODO confirm.
         */
        private static boolean isWriteFragment(PlanFragment planFragment) {
            PlanVisitor<Boolean, Void> writerDetector = new PlanVisitor<Boolean, Void>() {
                @Override
                protected Boolean visitPlan(PlanNode planNode, Void context) {
                    // anyMatch short-circuits on the first source that reports a writer.
                    return planNode.getSources().stream().anyMatch(source -> source.accept(this, context));
                }

                @Override
                public Boolean visitTableWriter(TableWriterNode tableWriterNode, Void context) {
                    return true;
                }
            };
            return planFragment.getRoot().accept(writerDetector, null);
        }

        /**
         * Creates a task source from already-resolved inputs.
         *
         * @param map split sources keyed by plan node
         * @param multimap partitioned exchange source handles keyed by plan node
         * @param multimap2 replicated exchange source handles keyed by plan node
         * @param i split batch size
         * @param longConsumer recorder handed to the split loading futures
         * @param iArr bucket-to-partition lookup table (stored by reference, not copied)
         * @param optional bucket node map; must be present when {@code map} is non-empty
         * @param optional2 catalog requirement placed into each task's node requirements
         * @param j target split weight per task
         * @param dataSize target exchange data size per task
         * @param dataSize2 memory requirement placed into each task's node requirements
         * @param executor executor used for split loading and task construction
         * @throws NullPointerException if any reference argument is null
         * @throws IllegalArgumentException if split sources are given without a bucket node map
         */
        @VisibleForTesting
        HashDistributionTaskSource(Map<PlanNodeId, SplitSource> map, Multimap<PlanNodeId, ExchangeSourceHandle> multimap, Multimap<PlanNodeId, ExchangeSourceHandle> multimap2, int i, LongConsumer longConsumer, int[] iArr, Optional<BucketNodeMap> optional, Optional<CatalogHandle> optional2, long j, DataSize dataSize, DataSize dataSize2, Executor executor) {
            this.splitSources = ImmutableMap.copyOf(Objects.requireNonNull(map, "splitSources is null"));
            this.partitionedExchangeSourceHandles = ImmutableListMultimap.copyOf(Objects.requireNonNull(multimap, "partitionedExchangeSourceHandles is null"));
            this.replicatedExchangeSourceHandles = ImmutableListMultimap.copyOf(Objects.requireNonNull(multimap2, "replicatedExchangeSourceHandles is null"));
            this.splitBatchSize = i;
            this.getSplitTimeRecorder = Objects.requireNonNull(longConsumer, "getSplitTimeRecorder is null");
            this.bucketToPartitionMap = Objects.requireNonNull(iArr, "bucketToPartitionMap is null");
            this.bucketNodeMap = Objects.requireNonNull(optional, "bucketNodeMap is null");
            this.taskMemory = Objects.requireNonNull(dataSize2, "taskMemory is null");
            Preconditions.checkArgument(optional.isPresent() || map.isEmpty(), "bucketNodeMap is expected to be set when the fragment reads partitioned sources (tables)");
            this.catalogRequirement = Objects.requireNonNull(optional2, "catalogRequirement is null");
            // Fail with a descriptive message instead of a bare NullPointerException from toBytes().
            this.targetPartitionSourceSizeInBytes = Objects.requireNonNull(dataSize, "targetPartitionSourceSize is null").toBytes();
            this.targetPartitionSplitWeight = j;
            this.executor = Objects.requireNonNull(executor, "executor is null");
        }

        /**
         * Starts loading splits from every split source and returns a future holding all tasks.
         *
         * <p>Once loading completes, splits are grouped by partition (bucket mapped through the
         * bucket-to-partition table), partitioned exchange handles are grouped by their partition
         * id, one task is built per partition, and the result is passed through
         * {@code postprocessTasks}. Returns an empty list when already finished or closed.
         *
         * @throws IllegalStateException if called again after split loading has been started
         */
        @Override
        public synchronized ListenableFuture<List<TaskDescriptor>> getMoreTasks() {
            if (finished || closed) {
                return Futures.immediateFuture(ImmutableList.of());
            }
            Preconditions.checkState(loadedSplitsFuture == null, "getMoreTasks called again while splits are being loaded");

            // One loading future per split source, started immediately.
            List<ListenableFuture<LoadedSplits>> loadingFutures = splitSources.entrySet().stream()
                    .map(source -> {
                        SplitLoadingFuture loading = new SplitLoadingFuture(source.getKey(), source.getValue(), splitBatchSize, getSplitTimeRecorder, executor);
                        loading.load();
                        return loading;
                    })
                    .collect(ImmutableList.toImmutableList());
            loadedSplitsFuture = Futures.allAsList(loadingFutures);

            return Futures.transform(loadedSplitsFuture, allLoadedSplits -> {
                synchronized (this) {
                    Map<Integer, ListMultimap<PlanNodeId, Split>> splitsByPartition = new HashMap<>();
                    HashMultimap<Integer, HostAddress> hostsByPartition = HashMultimap.create();

                    for (LoadedSplits loadedSplits : allLoadedSplits) {
                        BucketNodeMap nodeMap = bucketNodeMap.orElseThrow(() -> new VerifyException("bucket to node map is expected to be present"));
                        for (Split split : loadedSplits.getSplits()) {
                            int partition = getPartitionForBucket(nodeMap.getBucket(split));

                            // Pin the partition to the node assigned to the split's bucket.
                            HostAddress assignedHost = nodeMap.getAssignedNode(split).getHostAndPort();
                            Set<HostAddress> hostRequirement = hostsByPartition.get(partition);
                            if (hostRequirement.isEmpty()) {
                                hostRequirement.add(assignedHost);
                            }
                            else {
                                Preconditions.checkState(hostRequirement.contains(assignedHost), "Unable to satisfy host requirement for partition %s. Existing requirement %s; Current split requirement: %s;", partition, hostRequirement, assignedHost);
                                hostRequirement.removeIf(host -> !host.equals(assignedHost));
                            }

                            // Splits readable only from specific hosts narrow the requirement further.
                            if (!split.isRemotelyAccessible()) {
                                Set<HostAddress> splitHosts = ImmutableSet.copyOf(split.getAddresses());
                                Verify.verify(!splitHosts.isEmpty(), "split is not remotely accessible but the list of addresses is empty: %s", split);
                                Set<HostAddress> currentRequirement = hostsByPartition.get(partition);
                                // NOTE(review): the step above appears to always leave at least one host,
                                // so the empty case looks defensive rather than reachable.
                                if (currentRequirement.isEmpty()) {
                                    currentRequirement.addAll(splitHosts);
                                }
                                else {
                                    Set<HostAddress> commonHosts = Sets.intersection(splitHosts, currentRequirement);
                                    Preconditions.checkState(!commonHosts.isEmpty(), "Unable to satisfy host requirement for partition %s. Existing requirement %s; Current split requirement: %s;", partition, currentRequirement, splitHosts);
                                    // Copy before replacing: the intersection is a live view of the set being replaced.
                                    hostsByPartition.replaceValues(partition, ImmutableSet.copyOf(commonHosts));
                                }
                            }

                            splitsByPartition.computeIfAbsent(partition, ignored -> ArrayListMultimap.create())
                                    .put(loadedSplits.getPlanNodeId(), split);
                        }
                    }

                    Map<Integer, Multimap<PlanNodeId, ExchangeSourceHandle>> handlesByPartition = new HashMap<>();
                    for (Map.Entry<PlanNodeId, ExchangeSourceHandle> entry : partitionedExchangeSourceHandles.entries()) {
                        ExchangeSourceHandle handle = entry.getValue();
                        handlesByPartition.computeIfAbsent(handle.getPartitionId(), ignored -> ArrayListMultimap.create())
                                .put(entry.getKey(), handle);
                    }

                    int nextTaskPartitionId = 0;
                    ImmutableList.Builder<TaskDescriptor> partitionTasks = ImmutableList.builder();
                    for (Integer partition : Sets.union(splitsByPartition.keySet(), handlesByPartition.keySet())) {
                        ListMultimap<PlanNodeId, Split> partitionSplits = splitsByPartition.getOrDefault(partition, ImmutableListMultimap.of());
                        // Replicated handles are not added here; postprocessTasks attaches them.
                        ListMultimap<PlanNodeId, ExchangeSourceHandle> partitionHandles = ImmutableListMultimap.<PlanNodeId, ExchangeSourceHandle>builder()
                                .putAll(handlesByPartition.getOrDefault(partition, ImmutableMultimap.of()))
                                .build();
                        NodeRequirements requirements = new NodeRequirements(catalogRequirement, hostsByPartition.get(partition), taskMemory);
                        partitionTasks.add(new TaskDescriptor(nextTaskPartitionId++, partitionSplits, partitionHandles, requirements));
                    }

                    List<TaskDescriptor> tasks = postprocessTasks(partitionTasks.build());
                    finished = true;
                    return tasks;
                }
            }, executor);
        }

        /**
         * Merges tasks that share the same node requirements into larger tasks.
         *
         * <p>Within each group, tasks are accumulated until adding the next one would exceed the
         * target split weight or (counting the replicated exchange data as well) the target source
         * size; the accumulated inputs are then emitted as one task. Every emitted task also
         * receives all replicated exchange source handles, and partition ids are reassigned
         * sequentially starting at zero.
         *
         * @param list tasks to merge
         * @return the merged tasks
         */
        private List<TaskDescriptor> postprocessTasks(List<TaskDescriptor> list) {
            ListMultimap<NodeRequirements, TaskDescriptor> tasksByRequirements = groupCompatibleTasks(list);
            ImmutableList.Builder<TaskDescriptor> mergedTasks = ImmutableList.builder();
            long replicatedSize = replicatedExchangeSourceHandles.values().stream()
                    .mapToLong(ExchangeSourceHandle::getDataSizeInBytes)
                    .sum();
            int nextPartitionId = 0;

            for (Map.Entry<NodeRequirements, Collection<TaskDescriptor>> group : tasksByRequirements.asMap().entrySet()) {
                NodeRequirements groupRequirements = group.getKey();

                ImmutableListMultimap.Builder<PlanNodeId, Split> pendingSplits = ImmutableListMultimap.builder();
                ImmutableListMultimap.Builder<PlanNodeId, ExchangeSourceHandle> pendingHandles = ImmutableListMultimap.builder();
                long pendingSplitWeight = 0;
                long pendingHandleSize = 0;

                for (TaskDescriptor task : group.getValue()) {
                    ListMultimap<PlanNodeId, Split> taskSplits = task.getSplits();
                    ListMultimap<PlanNodeId, ExchangeSourceHandle> taskHandles = task.getExchangeSourceHandles();
                    long taskSplitWeight = taskSplits.values().stream()
                            .mapToLong(split -> split.getSplitWeight().getRawValue())
                            .sum();
                    long taskHandleSize = taskHandles.values().stream()
                            .mapToLong(ExchangeSourceHandle::getDataSizeInBytes)
                            .sum();

                    // Never flush an empty accumulator, so an oversized task still ends up on its own.
                    boolean hasPendingWork = pendingSplitWeight > 0 || pendingHandleSize > 0;
                    boolean exceedsSplitWeight = pendingSplitWeight + taskSplitWeight > targetPartitionSplitWeight;
                    boolean exceedsSourceSize = pendingHandleSize + taskHandleSize + replicatedSize > targetPartitionSourceSizeInBytes;
                    if (hasPendingWork && (exceedsSplitWeight || exceedsSourceSize)) {
                        pendingHandles.putAll(replicatedExchangeSourceHandles);
                        mergedTasks.add(new TaskDescriptor(nextPartitionId++, pendingSplits.build(), pendingHandles.build(), groupRequirements));
                        pendingSplits = ImmutableListMultimap.builder();
                        pendingHandles = ImmutableListMultimap.builder();
                        pendingSplitWeight = 0;
                        pendingHandleSize = 0;
                    }

                    pendingSplits.putAll(taskSplits);
                    pendingHandles.putAll(taskHandles);
                    pendingSplitWeight += taskSplitWeight;
                    pendingHandleSize += taskHandleSize;
                }

                // Emit the remainder of the group, if any.
                ImmutableListMultimap<PlanNodeId, Split> remainingSplits = pendingSplits.build();
                ImmutableListMultimap<PlanNodeId, ExchangeSourceHandle> remainingHandles = pendingHandles.build();
                if (remainingSplits.isEmpty() && remainingHandles.isEmpty()) {
                    continue;
                }
                ListMultimap<PlanNodeId, ExchangeSourceHandle> handlesWithReplicated = ImmutableListMultimap.<PlanNodeId, ExchangeSourceHandle>builder()
                        .putAll(remainingHandles)
                        .putAll(replicatedExchangeSourceHandles)
                        .build();
                mergedTasks.add(new TaskDescriptor(nextPartitionId++, remainingSplits, handlesWithReplicated, groupRequirements));
            }
            return mergedTasks.build();
        }

        /**
         * Indexes the tasks by their node requirements, so tasks with equal requirements share a key.
         *
         * @param list tasks to group
         * @return the tasks keyed by {@link TaskDescriptor#getNodeRequirements()}
         */
        private ListMultimap<NodeRequirements, TaskDescriptor> groupCompatibleTasks(List<TaskDescriptor> list) {
            return Multimaps.index(list, TaskDescriptor::getNodeRequirements);
        }

        /**
         * Looks up the partition id for a bucket.
         *
         * @param i bucket number, used as an index into the bucket-to-partition table
         * @return the partition id stored for that bucket
         */
        private int getPartitionForBucket(int i) {
            return bucketToPartitionMap[i];
        }

        /** Returns true once the task list has been built by {@code getMoreTasks()}. */
        @Override
        public synchronized boolean isFinished() {
            return finished;
        }

        /**
         * Closes every split source; calls after the first are no-ops.
         *
         * <p>A runtime failure while closing one source is logged and does not prevent the
         * remaining sources from being closed.
         */
        @Override
        public synchronized void close() {
            if (!closed) {
                closed = true;
                for (SplitSource splitSource : splitSources.values()) {
                    try {
                        splitSource.close();
                    }
                    catch (RuntimeException e) {
                        StageTaskSourceFactory.log.error(e, "Error closing split source");
                    }
                }
            }
        }
    }

    /**
     * Immutable pairing of a plan node id with the splits loaded for it.
     */
    public static class LoadedSplits {
        private final PlanNodeId planNodeId;
        private final List<Split> splits;

        /**
         * @param planNodeId plan node the splits belong to
         * @param list loaded splits; copied into an immutable list
         * @throws NullPointerException if either argument is null
         */
        private LoadedSplits(PlanNodeId planNodeId, List<Split> list) {
            Objects.requireNonNull(planNodeId, "planNodeId is null");
            Objects.requireNonNull(list, "splits is null");
            this.planNodeId = planNodeId;
            this.splits = ImmutableList.copyOf(list);
        }

        /** Returns the plan node the splits belong to. */
        public PlanNodeId getPlanNodeId() {
            return planNodeId;
        }

        /** Returns the immutable list of loaded splits. */
        public List<Split> getSplits() {
            return splits;
        }
    }

    /**
     * Task source that produces exactly one task reading all given exchange source handles.
     *
     * <p>When {@code coordinatorOnly} is set, the task is pinned to the current node, which is
     * verified to be a coordinator. The {@code finished} flag is accessed without synchronization.
     */
    public static class SingleDistributionTaskSource implements TaskSource {
        private final ListMultimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles;
        private final DataSize taskMemory;
        private final InternalNodeManager nodeManager;
        private final boolean coordinatorOnly;
        private boolean finished;

        /**
         * Builds a task source for the given fragment.
         *
         * @param session session used to look up the task memory
         * @param planFragment fragment to schedule; must not have partitioned sources
         * @param multimap exchange source handles keyed by the producing fragment
         * @param internalNodeManager used to resolve the current node for coordinator-only tasks
         * @param z whether the task must run on the coordinator; also selects the coordinator
         *          task memory setting instead of the regular one
         * @return a new task source
         * @throws IllegalArgumentException if the fragment has partitioned sources
         */
        public static SingleDistributionTaskSource create(Session session, PlanFragment planFragment, Multimap<PlanFragmentId, ExchangeSourceHandle> multimap, InternalNodeManager internalNodeManager, boolean z) {
            Preconditions.checkArgument(planFragment.getPartitionedSources().isEmpty(), "no partitioned sources (table scans) expected, got: %s", planFragment.getPartitionedSources());
            return new SingleDistributionTaskSource(
                    StageTaskSourceFactory.getInputsForRemoteSources(planFragment.getRemoteSourceNodes(), multimap),
                    z ? SystemSessionProperties.getFaultTolerantExecutionDefaultCoordinatorTaskMemory(session) : SystemSessionProperties.getFaultTolerantExecutionDefaultTaskMemory(session),
                    internalNodeManager,
                    z);
        }

        /**
         * @param listMultimap exchange source handles keyed by plan node
         * @param dataSize memory requirement attached to the task
         * @param internalNodeManager used to resolve the current node for coordinator-only tasks
         * @param z whether the task must run on the coordinator
         * @throws NullPointerException if any reference argument is null
         */
        @VisibleForTesting
        SingleDistributionTaskSource(ListMultimap<PlanNodeId, ExchangeSourceHandle> listMultimap, DataSize dataSize, InternalNodeManager internalNodeManager, boolean z) {
            this.exchangeSourceHandles = ImmutableListMultimap.copyOf(Objects.requireNonNull(listMultimap, "exchangeSourceHandles is null"));
            this.taskMemory = Objects.requireNonNull(dataSize, "taskMemory is null");
            // Message aligned with the "<name> is null" wording used by every other check in this file.
            this.nodeManager = Objects.requireNonNull(internalNodeManager, "nodeManager is null");
            this.coordinatorOnly = z;
        }

        /**
         * Returns the single task on the first call and an empty list afterwards.
         */
        @Override
        public ListenableFuture<List<TaskDescriptor>> getMoreTasks() {
            if (finished) {
                return Futures.immediateFuture(ImmutableList.of());
            }

            // No host restriction unless the task has to run on the coordinator.
            Set<HostAddress> hostRequirement = ImmutableSet.of();
            if (coordinatorOnly) {
                InternalNode currentNode = nodeManager.getCurrentNode();
                Verify.verify(currentNode.isCoordinator(), "current node is expected to be a coordinator");
                hostRequirement = ImmutableSet.of(currentNode.getHostAndPort());
            }

            NodeRequirements nodeRequirements = new NodeRequirements(Optional.empty(), hostRequirement, taskMemory);
            List<TaskDescriptor> tasks = ImmutableList.of(new TaskDescriptor(0, ImmutableListMultimap.of(), exchangeSourceHandles, nodeRequirements));
            finished = true;
            return Futures.immediateFuture(tasks);
        }

        @Override
        public boolean isFinished() {
            return finished;
        }

        /** Nothing to release: this source holds no closeable resources. */
        @Override
        public void close() {
        }
    }

    /* loaded from: input_file:io/trino/execution/scheduler/StageTaskSourceFactory$SourceDistributionTaskSource.class */
    public static class SourceDistributionTaskSource implements TaskSource {
        // NOTE(review): queryId and tableExecuteContextManager are only stored in the visible part of
        // this class; their use is presumably in the last-batch handling further down — confirm.
        private final QueryId queryId;
        // Id of the fragment's single partitioned source (see create()).
        private final PlanNodeId partitionedSourceNodeId;
        private final TableExecuteContextManager tableExecuteContextManager;
        // Source that getMoreTasks() polls for one batch of splits per call.
        private final SplitSource splitSource;
        private final ListMultimap<PlanNodeId, ExchangeSourceHandle> replicatedExchangeSourceHandles;
        // Maximum number of splits requested from splitSource per getNextBatch call.
        private final int splitBatchSize;
        // Receives the System.nanoTime() value captured when a batch was requested, once that batch succeeds.
        private final LongConsumer getSplitTimeRecorder;
        private final Optional<CatalogHandle> catalogRequirement;
        // Partition sizing limits; the constructor enforces min >= 0, target > 0 and max >= max(1, min).
        private final int minPartitionSplitCount;
        private final long targetPartitionSplitWeight;
        private final int maxPartitionSplitCount;
        private final DataSize taskMemory;
        // Runs the transformation that turns a split batch into tasks.
        private final Executor executor;

        // NOTE(review): not referenced in the visible part of the class; presumably the id given to the next task.
        @GuardedBy("this")
        private int currentPartitionId;

        @GuardedBy("this")
        private boolean finished;

        @GuardedBy("this")
        private boolean closed;

        // Identity-based buffer of splits that report isRemotelyAccessible().
        @GuardedBy("this")
        private final Set<Split> remotelyAccessibleSplitBuffer = Sets.newIdentityHashSet();

        // Splits that are not remotely accessible, buffered once under each of their addresses.
        @GuardedBy("this")
        private final Map<HostAddress, Set<Split>> locallyAccessibleSplitBuffer = new HashMap();

        // Starts out as an already-completed future so the isDone() check in the first getMoreTasks() call passes.
        @GuardedBy("this")
        private ListenableFuture<SplitSource.SplitBatch> currentSplitBatchFuture = Futures.immediateFuture((Object) null);

        /**
         * Builds a task source for a fragment with exactly one partitioned source.
         *
         * @param session session providing the query id and the default task memory
         * @param planFragment fragment to schedule; all of its remote sources must be replicated
         * @param splitSourceFactory creates the split source of the partitioned source
         * @param multimap exchange source handles keyed by the producing fragment
         * @param tableExecuteContextManager passed through to the task source
         * @param i split batch size
         * @param longConsumer recorder of split loading time
         * @param i2 minimum number of splits per partition
         * @param j target split weight per partition
         * @param i3 maximum number of splits per partition
         * @param executor executor used to build tasks from split batches
         * @return a new task source
         * @throws IllegalArgumentException if the fragment does not have exactly one partitioned
         *         source or has a remote source that is not replicated
         */
        public static SourceDistributionTaskSource create(Session session, PlanFragment planFragment, SplitSourceFactory splitSourceFactory, Multimap<PlanFragmentId, ExchangeSourceHandle> multimap, TableExecuteContextManager tableExecuteContextManager, int i, LongConsumer longConsumer, int i2, long j, int i3, Executor executor) {
            Preconditions.checkArgument(planFragment.getPartitionedSources().size() == 1, "single partitioned source is expected, got: %s", planFragment.getPartitionedSources());
            List<RemoteSourceNode> remoteSourceNodes = planFragment.getRemoteSourceNodes();
            boolean onlyReplicated = remoteSourceNodes.stream()
                    .noneMatch(remoteSource -> remoteSource.getExchangeType() != ExchangeNode.Type.REPLICATE);
            Preconditions.checkArgument(onlyReplicated, "only replicated exchanges are expected in source distributed stage, got: %s", remoteSourceNodes);

            PlanNodeId partitionedSourceId = Iterables.getOnlyElement(planFragment.getPartitionedSources());
            SplitSource splitSource = splitSourceFactory.createSplitSources(session, planFragment).get(partitionedSourceId);

            // Locals are computed in the same order in which the constructor arguments were evaluated.
            QueryId queryId = session.getQueryId();
            ListMultimap<PlanNodeId, ExchangeSourceHandle> replicatedHandles = StageTaskSourceFactory.getReplicatedExchangeSourceHandles(planFragment, multimap);
            // Catalogs whose handle type is internal impose no catalog requirement on the tasks.
            Optional<CatalogHandle> catalogRequirement = Optional.of(splitSource.getCatalogHandle())
                    .filter(handle -> !handle.getType().isInternal());
            DataSize taskMemory = SystemSessionProperties.getFaultTolerantExecutionDefaultTaskMemory(session);

            return new SourceDistributionTaskSource(
                    queryId,
                    partitionedSourceId,
                    tableExecuteContextManager,
                    splitSource,
                    replicatedHandles,
                    i,
                    longConsumer,
                    catalogRequirement,
                    i2,
                    j,
                    i3,
                    taskMemory,
                    executor);
        }

        /**
         * Creates a task source from already-resolved inputs.
         *
         * @param queryId id of the query being scheduled
         * @param planNodeId id of the partitioned source node
         * @param tableExecuteContextManager table execute context manager
         * @param splitSource source of splits for the partitioned source node
         * @param listMultimap replicated exchange source handles keyed by plan node
         * @param i split batch size
         * @param longConsumer recorder of split loading time
         * @param optional catalog requirement placed into each task's node requirements
         * @param i2 minimum number of splits per partition; must be non-negative
         * @param j target split weight per partition; must be positive
         * @param i3 maximum number of splits per partition; must be positive and not below {@code i2}
         * @param dataSize memory requirement placed into each task's node requirements
         * @param executor executor used to build tasks from split batches
         * @throws NullPointerException if any reference argument is null
         * @throws IllegalArgumentException if a numeric limit is out of range
         */
        @VisibleForTesting
        SourceDistributionTaskSource(QueryId queryId, PlanNodeId planNodeId, TableExecuteContextManager tableExecuteContextManager, SplitSource splitSource, ListMultimap<PlanNodeId, ExchangeSourceHandle> listMultimap, int i, LongConsumer longConsumer, Optional<CatalogHandle> optional, int i2, long j, int i3, DataSize dataSize, Executor executor) {
            this.queryId = Objects.requireNonNull(queryId, "queryId is null");
            this.partitionedSourceNodeId = Objects.requireNonNull(planNodeId, "partitionedSourceNodeId is null");
            this.tableExecuteContextManager = Objects.requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null");
            this.splitSource = Objects.requireNonNull(splitSource, "splitSource is null");
            this.replicatedExchangeSourceHandles = ImmutableListMultimap.copyOf(Objects.requireNonNull(listMultimap, "replicatedExchangeSourceHandles is null"));
            this.splitBatchSize = i;
            this.getSplitTimeRecorder = Objects.requireNonNull(longConsumer, "getSplitTimeRecorder is null");
            this.catalogRequirement = Objects.requireNonNull(optional, "catalogRequirement is null");
            // The message names the value actually validated (it previously said "targetPartitionSplitCount").
            Preconditions.checkArgument(j > 0, "targetPartitionSplitWeight must be greater than 0: %s", j);
            this.targetPartitionSplitWeight = j;
            Preconditions.checkArgument(i2 >= 0, "minPartitionSplitCount must be greater than or equal to 0: %s", i2);
            this.minPartitionSplitCount = i2;
            Preconditions.checkArgument(i3 > 0, "maxPartitionSplitCount must be greater than 0: %s", i3);
            Preconditions.checkArgument(i3 >= i2, "maxPartitionSplitCount(%s) must be greater than or equal to minPartitionSplitCount(%s)", i3, i2);
            this.maxPartitionSplitCount = i3;
            this.taskMemory = Objects.requireNonNull(dataSize, "taskMemory is null");
            this.executor = Objects.requireNonNull(executor, "executor is null");
        }

        /**
         * Requests the next batch of splits and converts the buffered splits into task descriptors.
         *
         * <p>Once this source is finished or closed an already-completed future holding an empty list is
         * returned. Otherwise the returned future completes, on {@code executor}, with the tasks that became
         * ready after the new batch was buffered. When the batch is reported as the last one, every remaining
         * buffered split is flushed into a task, the split source is closed and this source is marked finished.
         *
         * @return future with the task descriptors produced from the requested batch
         * @throws IllegalStateException if the previously requested batch of splits has not completed yet
         */
        @Override
        public synchronized ListenableFuture<List<TaskDescriptor>> getMoreTasks() {
            if (this.finished || this.closed) {
                return Futures.immediateFuture(ImmutableList.of());
            }
            Preconditions.checkState(this.currentSplitBatchFuture.isDone(), "getMoreTasks called again before the previous batch of splits was ready");
            this.currentSplitBatchFuture = this.splitSource.getNextBatch(this.splitBatchSize);

            // Timestamp taken right after the batch was requested; handed to the recorder once the batch succeeds.
            long requestedAtNanos = System.nanoTime();
            MoreFutures.addSuccessCallback(this.currentSplitBatchFuture, () -> this.getSplitTimeRecorder.accept(requestedAtNanos));

            return Futures.transform(this.currentSplitBatchFuture, batch -> {
                synchronized (this) {
                    batch.getSplits().forEach(this::bufferSplit);

                    ImmutableList.Builder<TaskDescriptor> readyTasks = ImmutableList.builder();
                    boolean lastBatch = batch.isLastBatch();

                    // Splits readable from anywhere carry no host requirement and share no buffer with other sets.
                    readyTasks.addAll(getReadyTasks(
                            this.remotelyAccessibleSplitBuffer,
                            ImmutableList.of(),
                            new NodeRequirements(this.catalogRequirement, ImmutableSet.of(), this.taskMemory),
                            lastBatch));

                    // A host-local split sits in the buffer of every host it lists, so the buffers of the
                    // other hosts are passed along to have the chosen splits removed from them as well.
                    for (HostAddress host : this.locallyAccessibleSplitBuffer.keySet()) {
                        Set<Split> hostSplits = this.locallyAccessibleSplitBuffer.get(host);
                        List<Set<Split>> otherHostSplits = this.locallyAccessibleSplitBuffer.entrySet().stream()
                                .filter(entry -> !entry.getKey().equals(host))
                                .map(entry -> entry.getValue())
                                .collect(ImmutableList.toImmutableList());
                        NodeRequirements hostRequirements = new NodeRequirements(this.catalogRequirement, ImmutableSet.of(host), this.taskMemory);
                        readyTasks.addAll(getReadyTasks(hostSplits, otherHostSplits, hostRequirements, lastBatch));
                    }

                    List<TaskDescriptor> tasks = readyTasks.build();
                    if (lastBatch) {
                        finishSplitSource();
                    }
                    return tasks;
                }
            }, this.executor);
        }

        /**
         * Buffers a single split: remotely accessible splits go to the shared buffer, all others are added
         * to the identity-based buffer of each address they report.
         */
        private synchronized void bufferSplit(Split split) {
            if (split.isRemotelyAccessible()) {
                this.remotelyAccessibleSplitBuffer.add(split);
                return;
            }
            List<HostAddress> addresses = split.getAddresses();
            Preconditions.checkArgument(!addresses.isEmpty(), "split is not remotely accessible but the list of addresses is empty");
            for (HostAddress address : addresses) {
                this.locallyAccessibleSplitBuffer.computeIfAbsent(address, ignored -> Sets.newIdentityHashSet()).add(split);
            }
        }

        /**
         * Publishes the table-execute splits info (when the split source provides one), closes the split
         * source on a best-effort basis and marks this task source as finished.
         */
        private synchronized void finishSplitSource() {
            this.splitSource.getTableExecuteSplitsInfo().ifPresent(splitsInfo ->
                    this.tableExecuteContextManager.getTableExecuteContextForQuery(this.queryId).setSplitsInfo(splitsInfo));
            try {
                this.splitSource.close();
            } catch (RuntimeException e) {
                // A failure to close must not fail task creation; it is only logged.
                StageTaskSourceFactory.log.error(e, "Error closing split source");
            }
            this.finished = true;
        }

        /**
         * Builds as many tasks as the buffered splits currently allow.
         *
         * @param set buffer to draw splits from; splits placed in a task are removed from it
         * @param list other buffers that may hold the same splits; assigned splits are removed from them too
         * @param nodeRequirements requirements attached to every task created here
         * @param z when {@code true}, splits left over after the regular tasks are flushed into one final task
         * @return the created task descriptors, possibly empty
         */
        private List<TaskDescriptor> getReadyTasks(Set<Split> set, List<Set<Split>> list, NodeRequirements nodeRequirements, boolean z) {
            ImmutableList.Builder<TaskDescriptor> tasks = ImmutableList.builder();

            // Keep carving tasks out of the buffer until it can no longer fill one.
            for (Optional<TaskDescriptor> next = getReadyTask(set, list, nodeRequirements);
                    next.isPresent();
                    next = getReadyTask(set, list, nodeRequirements)) {
                tasks.add(next.get());
            }

            if (!z || set.isEmpty()) {
                return tasks.build();
            }

            // Flush the remainder: build the task first (it copies the splits), then drop them everywhere.
            tasks.add(buildTaskDescriptor(set, nodeRequirements));
            for (Set<Split> otherSplits : list) {
                otherSplits.removeAll(set);
            }
            set.clear();
            return tasks.build();
        }

        /**
         * Tries to assemble a single task from the head of the given split buffer.
         *
         * <p>Splits are taken in the buffer's iteration order. A task is emitted as soon as at least
         * {@code minPartitionSplitCount} splits were collected and either their accumulated raw weight
         * reached {@code targetPartitionSplitWeight} or {@code maxPartitionSplitCount} splits were collected.
         * The chosen splits are removed from {@code set} and from every buffer in {@code list}.
         *
         * @param set buffer to draw splits from
         * @param list other buffers that may contain the same splits
         * @param nodeRequirements requirements attached to the created task
         * @return the task, or empty when the buffer does not yet hold enough splits
         */
        private Optional<TaskDescriptor> getReadyTask(Set<Split> set, List<Set<Split>> list, NodeRequirements nodeRequirements) {
            ImmutableList.Builder<Split> chosenSplitsBuilder = ImmutableList.builder();
            int splitCount = 0;
            // Accumulated as long: the raw weight and the target weight are wider than int, and narrowing
            // the running sum to int would silently truncate it before the comparison below.
            long totalSplitWeight = 0;
            for (Split split : set) {
                totalSplitWeight += split.getSplitWeight().getRawValue();
                splitCount++;
                chosenSplitsBuilder.add(split);

                boolean enoughSplits = splitCount >= this.minPartitionSplitCount;
                boolean taskIsFull = totalSplitWeight >= this.targetPartitionSplitWeight || splitCount >= this.maxPartitionSplitCount;
                if (enoughSplits && taskIsFull) {
                    ImmutableList<Split> chosenSplits = chosenSplitsBuilder.build();
                    for (Set<Split> otherSplits : list) {
                        chosenSplits.forEach(otherSplits::remove);
                    }
                    // Mutating the set being iterated is safe here only because we return immediately afterwards.
                    chosenSplits.forEach(set::remove);
                    return Optional.of(buildTaskDescriptor(chosenSplits, nodeRequirements));
                }
            }
            return Optional.empty();
        }

        /**
         * Creates a task descriptor for the given splits and consumes the next partition id.
         *
         * @param collection splits assigned to the task; all are registered under the partitioned source node id
         * @param nodeRequirements requirements attached to the task
         * @return descriptor carrying the splits plus the replicated exchange source handles
         */
        private synchronized TaskDescriptor buildTaskDescriptor(Collection<Split> collection, NodeRequirements nodeRequirements) {
            // The id is consumed before the splits are copied, exactly one id per call.
            int partitionId = this.currentPartitionId++;
            ImmutableListMultimap.Builder<PlanNodeId, Split> taskSplits = ImmutableListMultimap.builder();
            taskSplits.putAll(this.partitionedSourceNodeId, collection);
            return new TaskDescriptor(partitionId, taskSplits.build(), this.replicatedExchangeSourceHandles, nodeRequirements);
        }

        /**
         * Returns the current value of the {@code finished} flag, read under this object's lock.
         */
        @Override
        public synchronized boolean isFinished() {
            return finished;
        }

        /**
         * Marks this task source as closed and closes the underlying split source.
         * Calls after the first one are no-ops.
         */
        @Override
        public synchronized void close() {
            if (!this.closed) {
                this.closed = true;
                this.splitSource.close();
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/trino/execution/scheduler/StageTaskSourceFactory$SplitLoadingFuture.class */
    public static class SplitLoadingFuture extends AbstractFuture<LoadedSplits> {
        private final PlanNodeId planNodeId;
        private final SplitSource splitSource;
        private final int splitBatchSize;
        private final LongConsumer getSplitTimeRecorder;
        private final Executor executor;

        @GuardedBy("this")
        private final List<Split> loadedSplits = new ArrayList();

        @GuardedBy("this")
        private ListenableFuture<SplitSource.SplitBatch> currentSplitBatch = Futures.immediateFuture((Object) null);

        SplitLoadingFuture(PlanNodeId planNodeId, SplitSource splitSource, int i, LongConsumer longConsumer, Executor executor) {
            this.planNodeId = (PlanNodeId) Objects.requireNonNull(planNodeId, "planNodeId is null");
            this.splitSource = (SplitSource) Objects.requireNonNull(splitSource, "splitSource is null");
            this.splitBatchSize = i;
            this.getSplitTimeRecorder = (LongConsumer) Objects.requireNonNull(longConsumer, "getSplitTimeRecorder is null");
            this.executor = (Executor) Objects.requireNonNull(executor, "executor is null");
        }

        public synchronized void load() {
            if (this.currentSplitBatch == null) {
                Preconditions.checkState(isCancelled(), "SplitLoadingFuture should be in cancelled state");
                return;
            }
            Preconditions.checkState(this.currentSplitBatch.isDone(), "next batch of splits requested before previous batch is done");
            this.currentSplitBatch = this.splitSource.getNextBatch(this.splitBatchSize);
            final long nanoTime = System.nanoTime();
            Futures.addCallback(this.currentSplitBatch, new FutureCallback<SplitSource.SplitBatch>() { // from class: io.trino.execution.scheduler.StageTaskSourceFactory.SplitLoadingFuture.1
                public void onSuccess(SplitSource.SplitBatch splitBatch) {
                    SplitLoadingFuture.this.getSplitTimeRecorder.accept(nanoTime);
                    synchronized (SplitLoadingFuture.this) {
                        SplitLoadingFuture.this.loadedSplits.addAll(splitBatch.getSplits());
                        if (splitBatch.isLastBatch()) {
                            SplitLoadingFuture.this.set(new LoadedSplits(SplitLoadingFuture.this.planNodeId, SplitLoadingFuture.this.loadedSplits));
                            try {
                                SplitLoadingFuture.this.splitSource.close();
                            } catch (RuntimeException e) {
                                StageTaskSourceFactory.log.error(e, "Error closing split source");
                            }
                        } else {
                            SplitLoadingFuture.this.load();
                        }
                    }
                }

                public void onFailure(Throwable th) {
                    SplitLoadingFuture.this.setException(th);
                }
            }, this.executor);
        }

        protected synchronized void interruptTask() {
            if (this.currentSplitBatch != null) {
                this.currentSplitBatch.cancel(true);
                this.currentSplitBatch = null;
            }
        }
    }

    /**
     * Injection constructor; takes the split batch size from
     * {@code QueryManagerConfig.getScheduleSplitBatchSize()} and delegates to the explicit constructor.
     */
    @Inject
    public StageTaskSourceFactory(
            SplitSourceFactory splitSourceFactory,
            TableExecuteContextManager tableExecuteContextManager,
            QueryManagerConfig queryManagerConfig,
            @ForQueryExecution ExecutorService executorService,
            InternalNodeManager internalNodeManager) {
        this(
                splitSourceFactory,
                tableExecuteContextManager,
                queryManagerConfig.getScheduleSplitBatchSize(),
                executorService,
                internalNodeManager);
    }

    /**
     * Creates a factory with an explicit split batch size.
     *
     * @param splitSourceFactory must not be null
     * @param tableExecuteContextManager must not be null
     * @param i split batch size stored as {@code splitBatchSize}
     * @param executorService must not be null; stored as this factory's executor
     * @param internalNodeManager must not be null
     * @throws NullPointerException if any reference argument is null
     */
    public StageTaskSourceFactory(SplitSourceFactory splitSourceFactory, TableExecuteContextManager tableExecuteContextManager, int i, ExecutorService executorService, InternalNodeManager internalNodeManager) {
        // Validate up front, in declaration order, then assign.
        Objects.requireNonNull(splitSourceFactory, "splitSourceFactory is null");
        Objects.requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null");
        Objects.requireNonNull(executorService, "executor is null");
        Objects.requireNonNull(internalNodeManager, "nodeManager is null");

        this.splitSourceFactory = splitSourceFactory;
        this.tableExecuteContextManager = tableExecuteContextManager;
        this.splitBatchSize = i;
        this.executor = executorService;
        this.nodeManager = internalNodeManager;
    }

    /**
     * Creates the task source that matches the partitioning of the given fragment.
     *
     * <ul>
     * <li>single / coordinator distribution: {@code SingleDistributionTaskSource}</li>
     * <li>fixed arbitrary / scaled writer distribution: {@code ArbitraryDistributionTaskSource}</li>
     * <li>fixed hash distribution, or any partitioning with a catalog handle: {@code HashDistributionTaskSource}</li>
     * <li>source distribution: {@code SourceDistributionTaskSource}</li>
     * </ul>
     *
     * @param optional bucket-to-partition map; required for hash distributed stages
     * @param optional2 bucket-to-node map, forwarded to hash distributed stages
     * @throws IllegalArgumentException if the partitioning is not one of the above, or if {@code optional}
     *         is empty for a hash distributed stage
     */
    @Override
    public TaskSource create(Session session, PlanFragment planFragment, Multimap<PlanFragmentId, ExchangeSourceHandle> multimap, LongConsumer longConsumer, Optional<int[]> optional, Optional<BucketNodeMap> optional2) {
        PartitioningHandle handle = planFragment.getPartitioning();

        boolean coordinatorOnly = handle.equals(SystemPartitioningHandle.COORDINATOR_DISTRIBUTION);
        if (coordinatorOnly || handle.equals(SystemPartitioningHandle.SINGLE_DISTRIBUTION)) {
            return SingleDistributionTaskSource.create(session, planFragment, multimap, this.nodeManager, coordinatorOnly);
        }

        if (handle.equals(SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION) || handle.equals(SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION)) {
            return ArbitraryDistributionTaskSource.create(
                    session,
                    planFragment,
                    multimap,
                    SystemSessionProperties.getFaultTolerantExecutionTargetTaskInputSize(session));
        }

        if (handle.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION) || handle.getCatalogHandle().isPresent()) {
            return HashDistributionTaskSource.create(
                    session,
                    planFragment,
                    this.splitSourceFactory,
                    multimap,
                    this.splitBatchSize,
                    longConsumer,
                    optional.orElseThrow(() -> new IllegalArgumentException("bucketToPartitionMap is expected to be present for hash distributed stages")),
                    optional2,
                    // target split count scaled by the raw weight of a standard split
                    SystemSessionProperties.getFaultTolerantExecutionTargetTaskSplitCount(session) * SplitWeight.standard().getRawValue(),
                    SystemSessionProperties.getFaultTolerantExecutionTargetTaskInputSize(session),
                    SystemSessionProperties.getFaultTolerantPreserveInputPartitionsInWriteStage(session),
                    this.executor);
        }

        if (handle.equals(SystemPartitioningHandle.SOURCE_DISTRIBUTION)) {
            return SourceDistributionTaskSource.create(
                    session,
                    planFragment,
                    this.splitSourceFactory,
                    multimap,
                    this.tableExecuteContextManager,
                    this.splitBatchSize,
                    longConsumer,
                    SystemSessionProperties.getFaultTolerantExecutionMinTaskSplitCount(session),
                    // target split count scaled by the raw weight of a standard split
                    SystemSessionProperties.getFaultTolerantExecutionTargetTaskSplitCount(session) * SplitWeight.standard().getRawValue(),
                    SystemSessionProperties.getFaultTolerantExecutionMaxTaskSplitCount(session),
                    this.executor);
        }

        throw new IllegalArgumentException("Unexpected partitioning: " + handle);
    }

    /**
     * Collects the exchange source handles feeding the fragment's remote sources whose exchange type is
     * {@code REPLICATE}, keyed by remote source node id.
     */
    private static ListMultimap<PlanNodeId, ExchangeSourceHandle> getReplicatedExchangeSourceHandles(PlanFragment planFragment, Multimap<PlanFragmentId, ExchangeSourceHandle> multimap) {
        ImmutableList.Builder<RemoteSourceNode> replicatedSources = ImmutableList.builder();
        for (RemoteSourceNode source : planFragment.getRemoteSourceNodes()) {
            if (source.getExchangeType() == ExchangeNode.Type.REPLICATE) {
                replicatedSources.add(source);
            }
        }
        return getInputsForRemoteSources(replicatedSources.build(), multimap);
    }

    /**
     * Collects the exchange source handles feeding the fragment's remote sources whose exchange type is
     * anything other than {@code REPLICATE}, keyed by remote source node id.
     */
    private static ListMultimap<PlanNodeId, ExchangeSourceHandle> getPartitionedExchangeSourceHandles(PlanFragment planFragment, Multimap<PlanFragmentId, ExchangeSourceHandle> multimap) {
        ImmutableList.Builder<RemoteSourceNode> partitionedSources = ImmutableList.builder();
        for (RemoteSourceNode source : planFragment.getRemoteSourceNodes()) {
            if (source.getExchangeType() != ExchangeNode.Type.REPLICATE) {
                partitionedSources.add(source);
            }
        }
        return getInputsForRemoteSources(partitionedSources.build(), multimap);
    }

    /**
     * Maps every given remote source node to the exchange source handles of all its source fragments.
     *
     * @param list remote source nodes to resolve
     * @param multimap exchange source handles keyed by the fragment that produces them
     * @return handles keyed by remote source node id, in encounter order
     */
    private static ListMultimap<PlanNodeId, ExchangeSourceHandle> getInputsForRemoteSources(List<RemoteSourceNode> list, Multimap<PlanFragmentId, ExchangeSourceHandle> multimap) {
        ImmutableListMultimap.Builder<PlanNodeId, ExchangeSourceHandle> inputs = ImmutableListMultimap.builder();
        list.forEach(source -> source.getSourceFragmentIds().forEach(fragmentId -> {
            PlanNodeId sourceId = source.getId();
            // NOTE(review): Guava documents Multimap.get as returning an empty collection for absent keys,
            // so this null check looks purely defensive and a fragment without handles contributes nothing.
            Collection<ExchangeSourceHandle> handles = Objects.requireNonNull(
                    multimap.get(fragmentId),
                    () -> "exchange source handle is missing for fragment: " + fragmentId);
            inputs.putAll(sourceId, handles);
        }));
        return inputs.build();
    }
}
