package io.trino.execution.scheduler;

import com.google.common.base.Preconditions;
import com.google.common.base.Ticker;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.UnmodifiableIterator;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.concurrent.MoreFutures;
import io.airlift.concurrent.SetThreadName;
import io.airlift.log.Logger;
import io.airlift.stats.TimeStat;
import io.airlift.units.Duration;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.exchange.ExchangeInput;
import io.trino.exchange.SpoolingExchangeInput;
import io.trino.execution.BasicStageStats;
import io.trino.execution.NodeTaskMap;
import io.trino.execution.QueryState;
import io.trino.execution.QueryStateMachine;
import io.trino.execution.RemoteTaskFactory;
import io.trino.execution.SqlStage;
import io.trino.execution.StageId;
import io.trino.execution.StageInfo;
import io.trino.execution.TaskId;
import io.trino.failuredetector.FailureDetector;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Metadata;
import io.trino.operator.RetryPolicy;
import io.trino.server.DynamicFilterService;
import io.trino.spi.exchange.Exchange;
import io.trino.spi.exchange.ExchangeContext;
import io.trino.spi.exchange.ExchangeId;
import io.trino.spi.exchange.ExchangeManager;
import io.trino.spi.exchange.ExchangeSourceHandle;
import io.trino.sql.planner.NodePartitioningManager;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.SubPlan;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.plan.PlanFragmentId;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.IntStream;
import javax.annotation.concurrent.GuardedBy;

/* loaded from: input_file:io/trino/execution/scheduler/FaultTolerantQueryScheduler.class */
public class FaultTolerantQueryScheduler implements QueryScheduler {
    private static final Logger log = Logger.get(FaultTolerantQueryScheduler.class);
    private final QueryStateMachine queryStateMachine;
    private final ExecutorService queryExecutor;
    private final SplitSchedulerStats schedulerStats;
    private final FailureDetector failureDetector;
    private final TaskSourceFactory taskSourceFactory;
    private final TaskDescriptorStorage taskDescriptorStorage;
    private final ExchangeManager exchangeManager;
    private final NodePartitioningManager nodePartitioningManager;
    private final int taskRetryAttemptsOverall;
    private final int taskRetryAttemptsPerTask;
    private final int maxTasksWaitingForNodePerStage;
    private final ScheduledExecutorService scheduledExecutorService;
    private final NodeAllocatorService nodeAllocatorService;
    private final PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory;
    private final TaskExecutionStats taskExecutionStats;
    private final DynamicFilterService dynamicFilterService;
    private final StageManager stageManager;

    @GuardedBy("this")
    private boolean started;

    @GuardedBy("this")
    private Scheduler scheduler;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/trino/execution/scheduler/FaultTolerantQueryScheduler$BucketToPartition.class */
    /**
     * Immutable holder pairing an optional bucket-to-partition array with an
     * optional {@link BucketNodeMap}. Either component may be absent.
     */
    public static class BucketToPartition {
        private final Optional<int[]> bucketToPartitionMap;
        private final Optional<BucketNodeMap> bucketNodeMap;

        /**
         * @param optional  bucket-to-partition array holder; must not be null (may be empty)
         * @param optional2 bucket node map holder; must not be null (may be empty)
         * @throws NullPointerException if either argument is null
         */
        private BucketToPartition(Optional<int[]> optional, Optional<BucketNodeMap> optional2) {
            // requireNonNull is generic, so no cast is needed to keep the declared type
            this.bucketToPartitionMap = Objects.requireNonNull(optional, "bucketToPartitionMap is null");
            this.bucketNodeMap = Objects.requireNonNull(optional2, "bucketNodeMap is null");
        }

        /** Returns the bucket-to-partition array holder exactly as supplied at construction. */
        public Optional<int[]> getBucketToPartitionMap() {
            return bucketToPartitionMap;
        }

        /** Returns the bucket node map holder exactly as supplied at construction. */
        public Optional<BucketNodeMap> getBucketNodeMap() {
            return bucketNodeMap;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/trino/execution/scheduler/FaultTolerantQueryScheduler$Scheduler.class */
    public static class Scheduler {
        private final QueryStateMachine queryStateMachine;
        private final List<FaultTolerantStageScheduler> schedulers;
        private final StageManager stageManager;
        private final SplitSchedulerStats schedulerStats;
        private final NodeAllocator nodeAllocator;

        private Scheduler(QueryStateMachine queryStateMachine, List<FaultTolerantStageScheduler> list, StageManager stageManager, SplitSchedulerStats splitSchedulerStats, NodeAllocator nodeAllocator) {
            this.queryStateMachine = (QueryStateMachine) Objects.requireNonNull(queryStateMachine, "queryStateMachine is null");
            this.stageManager = (StageManager) Objects.requireNonNull(stageManager, "stageManager is null");
            this.schedulers = ImmutableList.copyOf((Collection) Objects.requireNonNull(list, "schedulers is null"));
            this.schedulerStats = (SplitSchedulerStats) Objects.requireNonNull(splitSchedulerStats, "schedulerStats is null");
            this.nodeAllocator = (NodeAllocator) Objects.requireNonNull(nodeAllocator, "nodeAllocator is null");
        }

        public void schedule() {
            if (this.schedulers.isEmpty()) {
                this.queryStateMachine.transitionToFinishing();
                return;
            }
            this.queryStateMachine.transitionToRunning();
            try {
                SetThreadName setThreadName = new SetThreadName("Query-%s", new Object[]{this.queryStateMachine.getQueryId()});
                try {
                    ArrayList arrayList = new ArrayList();
                    while (!FaultTolerantQueryScheduler.isFinishingOrDone(this.queryStateMachine)) {
                        arrayList.clear();
                        boolean z = false;
                        boolean z2 = true;
                        for (FaultTolerantStageScheduler faultTolerantStageScheduler : this.schedulers) {
                            if (faultTolerantStageScheduler.isFinished()) {
                                this.stageManager.get(faultTolerantStageScheduler.getStageId()).finish();
                            } else {
                                z2 = false;
                                ListenableFuture<Void> isBlocked = faultTolerantStageScheduler.isBlocked();
                                if (isBlocked.isDone()) {
                                    try {
                                        faultTolerantStageScheduler.schedule();
                                        ListenableFuture<Void> isBlocked2 = faultTolerantStageScheduler.isBlocked();
                                        if (isBlocked2.isDone()) {
                                            z = true;
                                        } else {
                                            arrayList.add(isBlocked2);
                                        }
                                    } catch (Throwable th) {
                                        fail(th, Optional.of(faultTolerantStageScheduler.getStageId()));
                                        setThreadName.close();
                                        return;
                                    }
                                } else {
                                    arrayList.add(isBlocked);
                                }
                            }
                        }
                        if (z2) {
                            this.queryStateMachine.transitionToFinishing();
                            setThreadName.close();
                            return;
                        }
                        if (!z) {
                            Verify.verify(!arrayList.isEmpty(), "blockedStages is not expected to be empty here", new Object[0]);
                            TimeStat.BlockTimer time = this.schedulerStats.getSleepTime().time();
                            try {
                                try {
                                    MoreFutures.tryGetFutureValue(MoreFutures.whenAnyComplete(arrayList), 1, TimeUnit.SECONDS);
                                } catch (Throwable th2) {
                                    if (time != null) {
                                        try {
                                            time.close();
                                        } catch (Throwable th3) {
                                            th2.addSuppressed(th3);
                                        }
                                    }
                                    throw th2;
                                }
                            } catch (CancellationException e) {
                                FaultTolerantQueryScheduler.log.debug("Scheduling has been cancelled for query %s. Query state: %s", new Object[]{this.queryStateMachine.getQueryId(), this.queryStateMachine.getQueryState()});
                            }
                            if (time != null) {
                                time.close();
                            }
                        }
                    }
                    setThreadName.close();
                } finally {
                }
            } catch (Throwable th4) {
                fail(th4, Optional.empty());
            }
        }

        public void cancel() {
            this.schedulers.forEach((v0) -> {
                v0.cancel();
            });
            closeNodeAllocator();
        }

        public void abort() {
            this.schedulers.forEach((v0) -> {
                v0.abort();
            });
            closeNodeAllocator();
        }

        private void fail(Throwable th, Optional<StageId> optional) {
            abort();
            this.stageManager.getStagesInTopologicalOrder().forEach(sqlStage -> {
                if (optional.isPresent() && ((StageId) optional.get()).equals(sqlStage.getStageId())) {
                    sqlStage.fail(th);
                } else {
                    sqlStage.abort();
                }
            });
            this.queryStateMachine.transitionToFailed(th);
        }

        private void closeNodeAllocator() {
            try {
                this.nodeAllocator.close();
            } catch (Throwable th) {
                FaultTolerantQueryScheduler.log.warn(th, "Error closing node allocator for query: %s", new Object[]{this.queryStateMachine.getQueryId()});
            }
        }
    }

    /**
     * Creates a query scheduler for a query whose session retry policy is {@code TASK}.
     *
     * <p>All object collaborators stored in fields are null-checked. The remaining arguments
     * ({@code metadata}, {@code remoteTaskFactory}, {@code nodeTaskMap}, {@code subPlan},
     * {@code z}) are only forwarded to {@code StageManager.create}.
     *
     * @param i  stored as {@code taskRetryAttemptsOverall}
     * @param i2 stored as {@code taskRetryAttemptsPerTask}
     * @param i3 stored as {@code maxTasksWaitingForNodePerStage}
     * @param z  passed through to {@code StageManager.create}; its meaning is not visible here
     * @throws NullPointerException if a null-checked collaborator is null
     * @throws com.google.common.base.VerifyException if the session retry policy is not {@code TASK}
     */
    public FaultTolerantQueryScheduler(QueryStateMachine queryStateMachine, ExecutorService executorService, SplitSchedulerStats splitSchedulerStats, FailureDetector failureDetector, TaskSourceFactory taskSourceFactory, TaskDescriptorStorage taskDescriptorStorage, ExchangeManager exchangeManager, NodePartitioningManager nodePartitioningManager, int i, int i2, int i3, ScheduledExecutorService scheduledExecutorService, NodeAllocatorService nodeAllocatorService, PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory, TaskExecutionStats taskExecutionStats, DynamicFilterService dynamicFilterService, Metadata metadata, RemoteTaskFactory remoteTaskFactory, NodeTaskMap nodeTaskMap, SubPlan subPlan, boolean z) {
        this.queryStateMachine = Objects.requireNonNull(queryStateMachine, "queryStateMachine is null");

        // this scheduler only supports task-level retries
        RetryPolicy policy = SystemSessionProperties.getRetryPolicy(queryStateMachine.getSession());
        Verify.verify(policy == RetryPolicy.TASK, "unexpected retry policy: %s", policy);

        this.queryExecutor = Objects.requireNonNull(executorService, "queryExecutor is null");
        this.schedulerStats = Objects.requireNonNull(splitSchedulerStats, "schedulerStats is null");
        this.failureDetector = Objects.requireNonNull(failureDetector, "failureDetector is null");
        this.taskSourceFactory = Objects.requireNonNull(taskSourceFactory, "taskSourceFactory is null");
        this.taskDescriptorStorage = Objects.requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null");
        this.exchangeManager = Objects.requireNonNull(exchangeManager, "exchangeManager is null");
        this.nodePartitioningManager = Objects.requireNonNull(nodePartitioningManager, "nodePartitioningManager is null");

        this.taskRetryAttemptsOverall = i;
        this.taskRetryAttemptsPerTask = i2;
        this.maxTasksWaitingForNodePerStage = i3;

        this.scheduledExecutorService = Objects.requireNonNull(scheduledExecutorService, "scheduledExecutorService is null");
        this.nodeAllocatorService = Objects.requireNonNull(nodeAllocatorService, "nodeAllocatorService is null");
        this.partitionMemoryEstimatorFactory = Objects.requireNonNull(partitionMemoryEstimatorFactory, "partitionMemoryEstimatorFactory is null");
        this.taskExecutionStats = Objects.requireNonNull(taskExecutionStats, "taskExecutionStats is null");
        this.dynamicFilterService = Objects.requireNonNull(dynamicFilterService, "dynamicFilterService is null");

        this.stageManager = StageManager.create(queryStateMachine, metadata, remoteTaskFactory, nodeTaskMap, executorService, splitSchedulerStats, subPlan, z);
    }

    /**
     * Starts scheduling. Only the first call has any effect, and nothing is started if the
     * query is already done.
     *
     * <p>Registers a state listener that, once the query is done, detaches the active
     * {@link Scheduler}, cancels it (FINISHED) or aborts it (FAILED), finishes or aborts the
     * stages accordingly, and publishes the final stage info. Then creates the scheduler and
     * submits its loop to the query executor.
     */
    @Override // io.trino.execution.scheduler.QueryScheduler
    public synchronized void start() {
        if (started) {
            return;
        }
        started = true;
        if (queryStateMachine.isDone()) {
            return;
        }

        queryStateMachine.addStateChangeListener(state -> {
            if (!state.isDone()) {
                return;
            }
            // take ownership of the scheduler under the lock, act on it outside the lock
            Scheduler detached;
            synchronized (this) {
                detached = scheduler;
                scheduler = null;
            }
            if (state == QueryState.FINISHED) {
                if (detached != null) {
                    detached.cancel();
                }
                stageManager.finish();
            } else if (state == QueryState.FAILED) {
                if (detached != null) {
                    detached.abort();
                }
                stageManager.abort();
            }
            queryStateMachine.updateQueryInfo(Optional.ofNullable(getStageInfo()));
        });

        Scheduler created = createScheduler();
        scheduler = created;
        queryExecutor.submit(created::schedule);
    }

    /**
     * Builds the {@link Scheduler} for this query: one exchange and one
     * {@link FaultTolerantStageScheduler} per stage, sharing a single node allocator and a
     * single overall retry budget.
     *
     * <p>Also initializes task descriptor storage for the query and registers a listener that
     * destroys it once the query is done. If anything fails after the node allocator has been
     * obtained, every stage scheduler created so far is aborted, the node allocator and all
     * created exchanges are closed, and the original failure is rethrown with any cleanup
     * failures attached as suppressed exceptions.
     *
     * @return the scheduler wrapping all stage schedulers
     * @throws IllegalArgumentException if {@code taskRetryAttemptsOverall} is negative
     */
    private Scheduler createScheduler() {
        taskDescriptorStorage.initialize(queryStateMachine.getQueryId());
        queryStateMachine.addStateChangeListener(queryState -> {
            if (queryState.isDone()) {
                taskDescriptorStorage.destroy(queryStateMachine.getQueryId());
            }
        });

        Session session = queryStateMachine.getSession();
        int partitionCount = SystemSessionProperties.getFaultTolerantExecutionPartitionCount(session);
        Function<PartitioningHandle, BucketToPartition> bucketToPartitionCache = createBucketToPartitionCache(nodePartitioningManager, session, partitionCount);

        ImmutableList.Builder<FaultTolerantStageScheduler> stageSchedulers = ImmutableList.builder();
        Map<PlanFragmentId, Exchange> exchanges = new HashMap<>();
        NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(session);
        try {
            // Reversed so that a stage's children are visited (and their exchanges registered)
            // before the stage itself; the verify on source exchanges below relies on this.
            List<SqlStage> stagesInReverseTopologicalOrder = Lists.reverse(stageManager.getStagesInTopologicalOrder());

            Preconditions.checkArgument(taskRetryAttemptsOverall >= 0, "taskRetryAttemptsOverall must be greater than or equal to 0: %s", taskRetryAttemptsOverall);
            // single retry budget shared by all stage schedulers
            AtomicInteger remainingTaskRetryAttemptsOverall = new AtomicInteger(taskRetryAttemptsOverall);

            List<Exchange> coordinatorConsumedExchanges = new ArrayList<>();
            for (SqlStage stage : stagesInReverseTopologicalOrder) {
                PlanFragment fragment = stage.getFragment();
                boolean outputStage = stageManager.getOutputStage().getStageId().equals(stage.getStageId());

                ExchangeContext exchangeContext = new ExchangeContext(session.getQueryId(), new ExchangeId("external-exchange-" + stage.getStageId().getId()));
                Exchange exchange = exchangeManager.createExchange(exchangeContext, partitionCount, outputStage);
                exchanges.put(fragment.getId(), exchange);
                if (outputStage) {
                    coordinatorConsumedExchanges.add(exchange);
                }

                ImmutableMap.Builder<PlanFragmentId, Exchange> sourceExchanges = ImmutableMap.builder();
                for (SqlStage childStage : stageManager.getChildren(fragment.getId())) {
                    PlanFragmentId childFragmentId = childStage.getFragment().getId();
                    Exchange sourceExchange = exchanges.get(childFragmentId);
                    Verify.verify(sourceExchange != null, "exchange not found for fragment: %s", childFragmentId);
                    sourceExchanges.put(childFragmentId, sourceExchange);
                }

                BucketToPartition inputBucketToPartition = bucketToPartitionCache.apply(fragment.getPartitioning());
                stageSchedulers.add(new FaultTolerantStageScheduler(
                        session,
                        stage,
                        failureDetector,
                        taskSourceFactory,
                        nodeAllocator,
                        taskDescriptorStorage,
                        partitionMemoryEstimatorFactory.createPartitionMemoryEstimator(),
                        taskExecutionStats,
                        // completes the given future after the given delay
                        (settableFuture, duration) -> scheduledExecutorService.schedule(() -> settableFuture.set(null), duration.toMillis(), TimeUnit.MILLISECONDS),
                        Ticker.systemTicker(),
                        exchange,
                        bucketToPartitionCache.apply(fragment.getPartitioningScheme().getPartitioning().getHandle()).getBucketToPartitionMap(),
                        sourceExchanges.buildOrThrow(),
                        inputBucketToPartition.getBucketToPartitionMap(),
                        inputBucketToPartition.getBucketNodeMap(),
                        remainingTaskRetryAttemptsOverall,
                        taskRetryAttemptsPerTask,
                        maxTasksWaitingForNodePerStage,
                        dynamicFilterService));
            }

            if (!stagesInReverseTopologicalOrder.isEmpty()) {
                Verify.verify(!coordinatorConsumedExchanges.isEmpty(), "coordinatorConsumedExchanges is empty");
                List<ListenableFuture<List<ExchangeSourceHandle>>> sourceHandleFutures = coordinatorConsumedExchanges.stream()
                        .map(Exchange::getSourceHandles)
                        .map(Exchanges::getAllSourceHandles)
                        .collect(ImmutableList.toImmutableList());
                ListenableFuture<List<List<ExchangeSourceHandle>>> allSourceHandles = Futures.allAsList(sourceHandleFutures);
                // once all source handles are known, publish them as the query result inputs
                MoreFutures.addSuccessCallback(allSourceHandles, handleLists -> {
                    List<ExchangeSourceHandle> handles = handleLists.stream()
                            .flatMap(List::stream)
                            .collect(ImmutableList.toImmutableList());
                    ImmutableList.Builder<ExchangeInput> inputs = ImmutableList.builder();
                    if (!handles.isEmpty()) {
                        inputs.add(new SpoolingExchangeInput(handles));
                    }
                    queryStateMachine.updateInputsForQueryResults(inputs.build(), true);
                });
                MoreFutures.addExceptionCallback(allSourceHandles, queryStateMachine::transitionToFailed);
            }

            return new Scheduler(queryStateMachine, stageSchedulers.build(), stageManager, schedulerStats, nodeAllocator);
        } catch (Throwable th) {
            // best-effort cleanup; cleanup failures are attached to the original failure
            for (FaultTolerantStageScheduler stageScheduler : stageSchedulers.build()) {
                try {
                    stageScheduler.abort();
                } catch (Throwable cleanupFailure) {
                    if (th != cleanupFailure) {
                        th.addSuppressed(cleanupFailure);
                    }
                }
            }
            try {
                nodeAllocator.close();
            } catch (Throwable cleanupFailure) {
                if (th != cleanupFailure) {
                    th.addSuppressed(cleanupFailure);
                }
            }
            for (Exchange exchange : exchanges.values()) {
                try {
                    exchange.close();
                } catch (Throwable cleanupFailure) {
                    if (th != cleanupFailure) {
                        th.addSuppressed(cleanupFailure);
                    }
                }
            }
            throw th;
        }
    }

    /**
     * Always rejects the request: cancelling an individual stage is not supported by this
     * scheduler.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // io.trino.execution.scheduler.QueryScheduler
    public void cancelStage(StageId stageId) {
        String reason = "partial cancel is not supported in fault tolerant mode";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Forwards a task failure request to the stage manager.
     *
     * @param taskId task to fail
     * @param th     failure cause handed to the stage manager
     */
    @Override // io.trino.execution.scheduler.QueryScheduler
    public void failTask(TaskId taskId, Throwable th) {
        stageManager.failTaskRemotely(taskId, th);
    }

    /** Returns the basic stage statistics reported by the stage manager. */
    @Override // io.trino.execution.scheduler.QueryScheduler
    public BasicStageStats getBasicStageStats() {
        BasicStageStats stats = stageManager.getBasicStageStats();
        return stats;
    }

    /** Returns the stage info reported by the stage manager. */
    @Override // io.trino.execution.scheduler.QueryScheduler
    public StageInfo getStageInfo() {
        StageInfo info = stageManager.getStageInfo();
        return info;
    }

    /** Returns the user memory reservation reported by the stage manager. */
    @Override // io.trino.execution.scheduler.QueryScheduler
    public long getUserMemoryReservation() {
        long reservation = stageManager.getUserMemoryReservation();
        return reservation;
    }

    /** Returns the total memory reservation reported by the stage manager. */
    @Override // io.trino.execution.scheduler.QueryScheduler
    public long getTotalMemoryReservation() {
        long reservation = stageManager.getTotalMemoryReservation();
        return reservation;
    }

    /** Returns the total CPU time reported by the stage manager. */
    @Override // io.trino.execution.scheduler.QueryScheduler
    public Duration getTotalCpuTime() {
        Duration cpuTime = stageManager.getTotalCpuTime();
        return cpuTime;
    }

    /**
     * Returns a memoizing lookup from partitioning handle to {@link BucketToPartition}.
     *
     * <p>The first request for a handle computes the mapping via
     * {@link #createBucketToPartitionMap}; later requests for an equal handle return the
     * cached instance. The backing map is a plain {@link HashMap} with no synchronization.
     *
     * @param nodePartitioningManager forwarded to {@code createBucketToPartitionMap}
     * @param session                 forwarded to {@code createBucketToPartitionMap}
     * @param i                       partition count forwarded to {@code createBucketToPartitionMap}
     */
    private static Function<PartitioningHandle, BucketToPartition> createBucketToPartitionCache(NodePartitioningManager nodePartitioningManager, Session session, int i) {
        Map<PartitioningHandle, BucketToPartition> cache = new HashMap<>();
        // the inner lambda parameter must have a different name than the outer one:
        // a lambda parameter may not redeclare a local name from its enclosing scope
        return partitioningHandle -> cache.computeIfAbsent(
                partitioningHandle,
                key -> createBucketToPartitionMap(session, i, key, nodePartitioningManager));
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Computes the bucket-to-partition mapping for a partitioning handle.
     *
     * <ul>
     *   <li>{@code FIXED_HASH_DISTRIBUTION}: identity mapping over {@code [0, i)}, no bucket node map.</li>
     *   <li>Handle without a catalog handle: both components empty.</li>
     *   <li>Otherwise: the bucket node map is obtained from the node partitioning manager and
     *       all buckets assigned to the same node receive the same partition id; ids are
     *       handed out in order of first appearance, starting at 0.</li>
     * </ul>
     *
     * @param session                 session passed to the node partitioning manager
     * @param i                       number of partitions used for the fixed hash distribution
     * @param partitioningHandle      handle to compute the mapping for
     * @param nodePartitioningManager source of the bucket node map
     */
    public static BucketToPartition createBucketToPartitionMap(Session session, int i, PartitioningHandle partitioningHandle, NodePartitioningManager nodePartitioningManager) {
        if (partitioningHandle.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION)) {
            int[] identity = IntStream.range(0, i).toArray();
            return new BucketToPartition(Optional.of(identity), Optional.empty());
        }
        if (!partitioningHandle.getCatalogHandle().isPresent()) {
            return new BucketToPartition(Optional.empty(), Optional.empty());
        }

        BucketNodeMap bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle);
        int bucketCount = bucketNodeMap.getBucketCount();
        int[] bucketToPartition = new int[bucketCount];
        HashMap<InternalNode, Integer> partitionByNode = new HashMap<>();
        for (int bucket = 0; bucket < bucketCount; bucket++) {
            InternalNode node = bucketNodeMap.getAssignedNode(bucket);
            Integer partitionId = partitionByNode.get(node);
            if (partitionId == null) {
                // ids are dense and sequential, so the next id equals the number of nodes seen so far
                partitionId = partitionByNode.size();
                partitionByNode.put(node, partitionId);
            }
            bucketToPartition[bucket] = partitionId;
        }
        return new BucketToPartition(Optional.of(bucketToPartition), Optional.of(bucketNodeMap));
    }

    /**
     * Tells whether the query has reached the FINISHING state or any done state.
     *
     * @param queryStateMachine state machine whose current state is inspected once
     */
    private static boolean isFinishingOrDone(QueryStateMachine queryStateMachine) {
        QueryState current = queryStateMachine.getQueryState();
        if (current == QueryState.FINISHING) {
            return true;
        }
        return current.isDone();
    }
}
