package io.trino.execution.scheduler;

import com.google.common.base.Preconditions;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import io.airlift.http.client.HttpUriBuilder;
import io.airlift.log.Logger;
import io.trino.execution.ExecutionFailureInfo;
import io.trino.execution.Lifespan;
import io.trino.execution.RemoteTask;
import io.trino.execution.SqlStage;
import io.trino.execution.StageId;
import io.trino.execution.StateMachine;
import io.trino.execution.TaskId;
import io.trino.execution.TaskState;
import io.trino.execution.TaskStatus;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.execution.scheduler.StageExecution;
import io.trino.failuredetector.FailureDetector;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Split;
import io.trino.operator.ExchangeOperator;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.split.RemoteSplit;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.util.Failures;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Stream;
import javax.annotation.concurrent.GuardedBy;
import org.joda.time.DateTime;

/* loaded from: input_file:io/trino/execution/scheduler/PipelinedStageExecution.class */
public class PipelinedStageExecution implements StageExecution {
    private static final Logger log = Logger.get(PipelinedStageExecution.class);

    // Collaborators and configuration handed in through the private constructor.
    private final PipelinedStageStateMachine stateMachine;
    private final SqlStage stage;
    private final Map<PlanFragmentId, OutputBufferManager> outputBufferManagers;
    private final TaskLifecycleListener taskLifecycleListener;
    private final FailureDetector failureDetector;
    private final Executor executor;
    private final Optional<int[]> bucketToPartition;
    // Source fragment id -> the remote source node of this stage that reads from it.
    private final Map<PlanFragmentId, RemoteSourceNode> exchangeSources;
    private final int attempt;

    // Tasks created by scheduleTask, keyed by partition id.
    private final Map<Integer, RemoteTask> tasks = new ConcurrentHashMap<>();

    // Ids of every task created by scheduleTask.
    @GuardedBy("this")
    private final Set<TaskId> allTasks = new HashSet<>();

    // Ids of tasks that have reported TaskState.FINISHED.
    @GuardedBy("this")
    private final Set<TaskId> finishedTasks = new HashSet<>();

    // Ids of tasks that have reported TaskState.FLUSHING and have not finished yet.
    @GuardedBy("this")
    private final Set<TaskId> flushingTasks = new HashSet<>();

    // Tasks of the source fragments, registered through sourceTaskCreated.
    @GuardedBy("this")
    private final Multimap<PlanFragmentId, RemoteTask> sourceTasks = HashMultimap.create();

    // Source fragments for which noMoreSourceTasks has been received.
    @GuardedBy("this")
    private final Set<PlanFragmentId> completeSourceFragments = new HashSet<>();

    // Plan nodes for which noMoreSplits has been (and, for new tasks, will be) signalled.
    @GuardedBy("this")
    private final Set<PlanNodeId> completeSources = new HashSet<>();

    // Driver groups already reported to the completed-lifespan listeners.
    private final Set<Lifespan> completedDriverGroups = new HashSet<>();
    private final ListenerManager<Set<Lifespan>> completedLifespansChangeListeners = new ListenerManager<>();

    /* loaded from: input_file:io/trino/execution/scheduler/PipelinedStageExecution$ListenerManager.class */
    /**
     * Keeps a list of listeners and hands a payload to each of them on a caller-supplied executor.
     * Once {@link #invoke} has been called, further registration is rejected.
     *
     * @param <T> type of the payload delivered to the listeners
     */
    private static class ListenerManager<T> {
        private final List<Consumer<T>> listeners = new ArrayList<>();
        // Flipped by the first invoke(); addListener fails afterwards.
        private boolean frozen;

        private ListenerManager() {
        }

        /**
         * Registers a listener.
         *
         * @param consumer listener to add
         * @throws IllegalStateException if {@link #invoke} has already been called
         */
        public synchronized void addListener(Consumer<T> consumer) {
            Preconditions.checkState(!frozen, "Listeners have been invoked");
            listeners.add(consumer);
        }

        /**
         * Freezes the listener list and submits one job per registered listener to {@code executor}.
         *
         * @param t payload passed to every listener
         * @param executor executor on which each listener is run
         */
        public synchronized void invoke(T t, Executor executor) {
            frozen = true;
            for (Consumer<T> listener : listeners) {
                executor.execute(() -> listener.accept(t));
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/trino/execution/scheduler/PipelinedStageExecution$PipelinedStageStateMachine.class */
    /**
     * Wrapper around a {@link StateMachine} over {@link StageExecution.State} that exposes one method
     * per transition, and additionally records the first failure cause and the time at which
     * scheduling was first declared complete.
     */
    public static class PipelinedStageStateMachine {
        // Every state whose isDone() is true; passed to the StateMachine as its terminal states.
        private static final Set<StageExecution.State> TERMINAL_STAGE_STATES = Stream.of(StageExecution.State.values())
                .filter(StageExecution.State::isDone)
                .collect(ImmutableSet.toImmutableSet());

        private final StageId stageId;
        private final StateMachine<StageExecution.State> state;
        // Set once, by the first call to transitionToScheduled().
        private final AtomicReference<DateTime> schedulingComplete = new AtomicReference<>();
        // Set once, by the first call to transitionToFailed().
        private final AtomicReference<ExecutionFailureInfo> failureCause = new AtomicReference<>();

        /**
         * Creates the state machine in the {@code PLANNED} state and registers a debug-logging listener.
         *
         * @param stageId id of the stage, used in the machine name and log messages; must not be null
         * @param executor executor handed to the underlying {@link StateMachine}
         */
        private PipelinedStageStateMachine(StageId stageId, Executor executor) {
            this.stageId = Objects.requireNonNull(stageId, "stageId is null");
            this.state = new StateMachine<>("Pipelined stage execution " + stageId, executor, StageExecution.State.PLANNED, TERMINAL_STAGE_STATES);
            this.state.addStateChangeListener(newState -> log.debug("Pipelined stage execution %s is %s", stageId, newState));
        }

        /** Returns the current state held by the underlying state machine. */
        public StageExecution.State getState() {
            return state.get();
        }

        /**
         * Attempts the {@code PLANNED -> SCHEDULING} transition.
         *
         * @return the result of the underlying compare-and-set
         */
        public boolean transitionToScheduling() {
            return state.compareAndSet(StageExecution.State.PLANNED, StageExecution.State.SCHEDULING);
        }

        /**
         * Attempts to move to {@code SCHEDULING_SPLITS}; permitted from {@code PLANNED} or {@code SCHEDULING}.
         *
         * @return the result of the underlying conditional state update
         */
        public boolean transitionToSchedulingSplits() {
            return state.setIf(
                    StageExecution.State.SCHEDULING_SPLITS,
                    currentState -> currentState == StageExecution.State.PLANNED
                            || currentState == StageExecution.State.SCHEDULING);
        }

        /**
         * Records the scheduling-complete timestamp (first call only) and attempts to move to
         * {@code SCHEDULED}; permitted from {@code PLANNED}, {@code SCHEDULING} or {@code SCHEDULING_SPLITS}.
         *
         * @return the result of the underlying conditional state update
         */
        public boolean transitionToScheduled() {
            // The timestamp is recorded regardless of whether the transition below is accepted.
            schedulingComplete.compareAndSet(null, DateTime.now());
            return state.setIf(
                    StageExecution.State.SCHEDULED,
                    currentState -> currentState == StageExecution.State.PLANNED
                            || currentState == StageExecution.State.SCHEDULING
                            || currentState == StageExecution.State.SCHEDULING_SPLITS);
        }

        /**
         * Attempts to move to {@code RUNNING}; rejected when already {@code RUNNING}, when {@code FLUSHING},
         * or when the current state is done.
         *
         * @return the result of the underlying conditional state update
         */
        public boolean transitionToRunning() {
            return state.setIf(
                    StageExecution.State.RUNNING,
                    currentState -> currentState != StageExecution.State.RUNNING
                            && currentState != StageExecution.State.FLUSHING
                            && !currentState.isDone());
        }

        /**
         * Attempts to move to {@code FLUSHING}; rejected when already {@code FLUSHING} or when the
         * current state is done.
         *
         * @return the result of the underlying conditional state update
         */
        public boolean transitionToFlushing() {
            return state.setIf(
                    StageExecution.State.FLUSHING,
                    currentState -> currentState != StageExecution.State.FLUSHING && !currentState.isDone());
        }

        /**
         * Attempts to move to {@code FINISHED} from any state that is not done.
         *
         * @return the result of the underlying conditional state update
         */
        public boolean transitionToFinished() {
            return state.setIf(StageExecution.State.FINISHED, currentState -> !currentState.isDone());
        }

        /**
         * Attempts to move to {@code CANCELED} from any state that is not done.
         *
         * @return the result of the underlying conditional state update
         */
        public boolean transitionToCanceled() {
            return state.setIf(StageExecution.State.CANCELED, currentState -> !currentState.isDone());
        }

        /**
         * Attempts to move to {@code ABORTED} from any state that is not done.
         *
         * @return the result of the underlying conditional state update
         */
        public boolean transitionToAborted() {
            return state.setIf(StageExecution.State.ABORTED, currentState -> !currentState.isDone());
        }

        /**
         * Records {@code th} as the failure cause (first call only) and attempts to move to
         * {@code FAILED} from any state that is not done. The outcome is logged at error level
         * when the update reports success and at debug level otherwise.
         *
         * @param th the failure; must not be null
         * @return the result of the underlying conditional state update
         */
        public boolean transitionToFailed(Throwable th) {
            Objects.requireNonNull(th, "throwable is null");
            // The cause is captured even if the state update below is rejected.
            failureCause.compareAndSet(null, Failures.toFailure(th));
            boolean failed = state.setIf(StageExecution.State.FAILED, currentState -> !currentState.isDone());
            if (failed) {
                log.error(th, "Pipelined stage execution for stage %s failed", stageId);
            } else {
                log.debug(th, "Failure in pipelined stage execution for stage %s after finished", stageId);
            }
            return failed;
        }

        /** Returns the first recorded failure cause, or empty if {@link #transitionToFailed} was never called. */
        public Optional<ExecutionFailureInfo> getFailureCause() {
            return Optional.ofNullable(failureCause.get());
        }

        /**
         * Registers a listener on the underlying state machine.
         *
         * @param stateChangeListener listener to register
         */
        public void addStateChangeListener(StateMachine.StateChangeListener<StageExecution.State> stateChangeListener) {
            state.addStateChangeListener(stateChangeListener);
        }
    }

    /**
     * Builds a {@link PipelinedStageExecution} for {@code sqlStage} and wires up its state-change listener.
     * The exchange-source map is derived from the stage fragment: every source fragment id of every
     * remote source node is mapped to that node.
     *
     * @param sqlStage stage to execute
     * @param map output buffer managers keyed by fragment id
     * @param taskLifecycleListener listener notified about tasks of this stage
     * @param failureDetector used to classify failures that carry a remote host
     * @param executor executor for state-machine and listener callbacks
     * @param optional optional bucket-to-partition mapping passed to created tasks
     * @param i attempt number
     * @return the initialized execution
     */
    public static PipelinedStageExecution createPipelinedStageExecution(SqlStage sqlStage, Map<PlanFragmentId, OutputBufferManager> map, TaskLifecycleListener taskLifecycleListener, FailureDetector failureDetector, Executor executor, Optional<int[]> optional, int i) {
        PipelinedStageStateMachine stateMachine = new PipelinedStageStateMachine(sqlStage.getStageId(), executor);

        ImmutableMap.Builder<PlanFragmentId, RemoteSourceNode> exchangeSources = ImmutableMap.builder();
        for (RemoteSourceNode remoteSourceNode : sqlStage.getFragment().getRemoteSourceNodes()) {
            for (PlanFragmentId sourceFragmentId : remoteSourceNode.getSourceFragmentIds()) {
                exchangeSources.put(sourceFragmentId, remoteSourceNode);
            }
        }

        PipelinedStageExecution execution = new PipelinedStageExecution(
                stateMachine,
                sqlStage,
                map,
                taskLifecycleListener,
                failureDetector,
                executor,
                optional,
                exchangeSources.buildOrThrow(),
                i);
        // Listener registration happens after construction so that "this" does not escape the constructor.
        execution.initialize();
        return execution;
    }

    /**
     * Stores the collaborators after null-checking them; both maps are defensively copied.
     * Instances are obtained through {@code createPipelinedStageExecution}.
     *
     * @param pipelinedStageStateMachine state machine of this execution
     * @param sqlStage stage being executed
     * @param map output buffer managers keyed by fragment id
     * @param taskLifecycleListener listener notified about tasks of this stage
     * @param failureDetector used to classify failures that carry a remote host
     * @param executor executor for listener callbacks
     * @param optional optional bucket-to-partition mapping
     * @param map2 source fragment id to remote source node
     * @param i attempt number
     */
    private PipelinedStageExecution(PipelinedStageStateMachine pipelinedStageStateMachine, SqlStage sqlStage, Map<PlanFragmentId, OutputBufferManager> map, TaskLifecycleListener taskLifecycleListener, FailureDetector failureDetector, Executor executor, Optional<int[]> optional, Map<PlanFragmentId, RemoteSourceNode> map2, int i) {
        this.stateMachine = Objects.requireNonNull(pipelinedStageStateMachine, "stateMachine is null");
        this.stage = Objects.requireNonNull(sqlStage, "stage is null");
        this.outputBufferManagers = ImmutableMap.copyOf(Objects.requireNonNull(map, "outputBufferManagers is null"));
        this.taskLifecycleListener = Objects.requireNonNull(taskLifecycleListener, "taskLifecycleListener is null");
        this.failureDetector = Objects.requireNonNull(failureDetector, "failureDetector is null");
        this.executor = Objects.requireNonNull(executor, "executor is null");
        this.bucketToPartition = Objects.requireNonNull(optional, "bucketToPartition is null");
        this.exchangeSources = ImmutableMap.copyOf(Objects.requireNonNull(map2, "exchangeSources is null"));
        this.attempt = i;
    }

    /**
     * Registers the state-change listener that, for every reported state in which no more tasks
     * can be scheduled, tells the task lifecycle listener that this fragment has no more tasks and
     * marks the source output buffer managers as having no more buffers.
     */
    private void initialize() {
        stateMachine.addStateChangeListener(newState -> {
            if (!newState.canScheduleMoreTasks()) {
                taskLifecycleListener.noMoreTasks(stage.getFragment().getId());
                updateSourceTasksOutputBuffers(manager -> manager.noMoreBuffers());
            }
        });
    }

    /** Returns the current state as reported by the state machine. */
    @Override
    public StageExecution.State getState() {
        return stateMachine.getState();
    }

    /**
     * Registers a listener on this execution's state machine.
     *
     * @param stateChangeListener listener to register
     */
    @Override
    public void addStateChangeListener(StateMachine.StateChangeListener<StageExecution.State> stateChangeListener) {
        stateMachine.addStateChangeListener(stateChangeListener);
    }

    /**
     * Registers a listener that receives sets of newly completed driver groups.
     *
     * @param consumer listener to register
     */
    @Override
    public void addCompletedDriverGroupsChangedListener(Consumer<Set<Lifespan>> consumer) {
        completedLifespansChangeListeners.addListener(consumer);
    }

    /** Asks the state machine to enter the scheduling state; the outcome of the request is ignored. */
    @Override
    public synchronized void beginScheduling() {
        stateMachine.transitionToScheduling();
    }

    /** Asks the state machine to enter the scheduling-splits state; the outcome of the request is ignored. */
    @Override
    public synchronized void transitionToSchedulingSplits() {
        stateMachine.transitionToSchedulingSplits();
    }

    /**
     * Marks scheduling of the whole stage as complete. Nothing further happens unless the state
     * machine accepts the move to the scheduled state; when it does, the flushing/finished
     * conditions are re-evaluated and every partitioned source of the fragment is marked complete.
     */
    @Override
    public synchronized void schedulingComplete() {
        if (!stateMachine.transitionToScheduled()) {
            return;
        }

        if (isFlushing()) {
            stateMachine.transitionToFlushing();
        }
        if (finishedTasks.containsAll(allTasks)) {
            stateMachine.transitionToFinished();
        }

        for (PlanNodeId partitionedSource : stage.getFragment().getPartitionedSources()) {
            schedulingComplete(partitionedSource);
        }
    }

    /**
     * Tells whether the stage as a whole counts as flushing: at least one task is flushing and
     * every known task is either flushing or finished.
     */
    private synchronized boolean isFlushing() {
        if (flushingTasks.isEmpty()) {
            return false;
        }
        for (TaskId taskId : allTasks) {
            if (!finishedTasks.contains(taskId) && !flushingTasks.contains(taskId)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Signals to every existing task that {@code planNodeId} will receive no more splits and
     * remembers the node so that tasks created later receive the same signal.
     *
     * @param planNodeId the source plan node whose scheduling is complete
     */
    @Override
    public synchronized void schedulingComplete(PlanNodeId planNodeId) {
        for (RemoteTask task : getAllTasks()) {
            task.noMoreSplits(planNodeId);
        }
        completeSources.add(planNodeId);
    }

    /** Requests the canceled state and then cancels every task created so far. */
    @Override
    public synchronized void cancel() {
        stateMachine.transitionToCanceled();
        for (RemoteTask task : getAllTasks()) {
            task.cancel();
        }
    }

    /** Requests the aborted state and then aborts every task created so far. */
    @Override
    public synchronized void abort() {
        stateMachine.transitionToAborted();
        for (RemoteTask task : getAllTasks()) {
            task.abort();
        }
    }

    /**
     * Requests the failed state with the given cause and then aborts every task created so far.
     *
     * @param th the failure cause
     */
    public synchronized void fail(Throwable th) {
        stateMachine.transitionToFailed(th);
        for (RemoteTask task : tasks.values()) {
            task.abort();
        }
    }

    /**
     * Fails the task registered for the partition of {@code taskId} and then fails the whole stage
     * with the same cause.
     *
     * @param taskId id of the task to fail; looked up by its partition id
     * @param th the failure cause
     * @throws NullPointerException if no task is registered for that partition
     */
    @Override
    public synchronized void failTask(TaskId taskId, Throwable th) {
        RemoteTask task = tasks.get(taskId.getPartitionId());
        Objects.requireNonNull(task, () -> "task not found: " + taskId);
        task.fail(th);
        fail(th);
    }

    /**
     * Calls {@code failRemotely} on the task registered for the partition of {@code taskId}.
     * Unlike {@code failTask}, the stage itself is not failed here.
     *
     * @param taskId id of the task to fail; looked up by its partition id
     * @param th the failure cause
     * @throws NullPointerException if no task is registered for that partition
     */
    @Override
    public synchronized void failTaskRemotely(TaskId taskId, Throwable th) {
        RemoteTask task = tasks.get(taskId.getPartitionId());
        Objects.requireNonNull(task, () -> "task not found: " + taskId);
        task.failRemotely(th);
    }

    /**
     * Creates, registers and starts a task for the given partition on the given node.
     *
     * @param internalNode node on which the task is created
     * @param i partition id of the new task; must not already have a task
     * @param multimap initial splits handed to task creation
     * @param multimap2 (plan node, lifespan) pairs for which the new task is immediately told there are no more splits
     * @return the started task, or empty when the stage is already done or the stage did not create a task
     * @throws IllegalArgumentException if a task for partition {@code i} already exists
     */
    @Override
    public synchronized Optional<RemoteTask> scheduleTask(InternalNode internalNode, int i, Multimap<PlanNodeId, Split> multimap, Multimap<PlanNodeId, Lifespan> multimap2) {
        if (stateMachine.getState().isDone()) {
            return Optional.empty();
        }
        Preconditions.checkArgument(!tasks.containsKey(i), "A task for partition %s already exists", i);

        Optional<RemoteTask> optionalTask = stage.createTask(
                internalNode,
                i,
                attempt,
                bucketToPartition,
                outputBufferManagers.get(stage.getFragment().getId()).getOutputBuffers(),
                multimap,
                ImmutableMultimap.of(),
                ImmutableSet.of());
        if (optionalTask.isEmpty()) {
            return Optional.empty();
        }
        RemoteTask task = optionalTask.get();
        tasks.put(i, task);

        // Give the new task one exchange split per already-known source task that has not finished.
        ImmutableMultimap.Builder<PlanNodeId, Split> exchangeSplits = ImmutableMultimap.builder();
        sourceTasks.forEach((sourceFragmentId, sourceTask) -> {
            if (sourceTask.getTaskStatus().getState() != TaskState.FINISHED) {
                exchangeSplits.put(exchangeSources.get(sourceFragmentId).getId(), createExchangeSplit(sourceTask, task));
            }
        });

        allTasks.add(task.getTaskId());

        task.addSplits(exchangeSplits.build());
        multimap2.forEach(task::noMoreSplits);
        // Sources completed before this task existed must be closed for it as well.
        completeSources.forEach(task::noMoreSplits);
        task.addStateChangeListener(this::updateTaskStatus);
        task.addStateChangeListener(this::updateCompletedDriverGroups);

        task.start();

        taskLifecycleListener.taskCreated(stage.getFragment().getId(), task);

        // Source tasks need an output buffer for the new consumer, identified by its partition id.
        OutputBuffers.OutputBufferId outputBufferId = new OutputBuffers.OutputBufferId(task.getTaskId().getPartitionId());
        updateSourceTasksOutputBuffers(outputBufferManager -> outputBufferManager.addOutputBuffer(outputBufferId));

        return Optional.of(task);
    }

    /**
     * Task state listener: folds a task status update into the stage bookkeeping and drives the
     * stage state machine. Updates are ignored once the stage is done.
     *
     * @param taskStatus latest status reported by one of this stage's tasks
     */
    private synchronized void updateTaskStatus(TaskStatus taskStatus) {
        StageExecution.State stageState = stateMachine.getState();
        if (stageState.isDone()) {
            return;
        }

        TaskState taskState = taskStatus.getState();
        switch (taskState) {
            case FAILED: {
                // Use the first reported failure (rewritten if its host is gone) or a generic fallback.
                RuntimeException failure = taskStatus.getFailures().stream()
                        .findFirst()
                        .map(this::rewriteTransportFailure)
                        .map(ExecutionFailureInfo::toException)
                        .orElse(new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason"));
                fail(failure);
                break;
            }
            case CANCELED:
                // The stage is not done here (checked above), so a canceled task is treated as a failure.
                fail(new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "A task is in the CANCELED state but stage is " + stageState));
                break;
            case ABORTED:
                // Same reasoning as for CANCELED.
                fail(new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state but stage is " + stageState));
                break;
            case FLUSHING:
                flushingTasks.add(taskStatus.getTaskId());
                break;
            case FINISHED:
                finishedTasks.add(taskStatus.getTaskId());
                flushingTasks.remove(taskStatus.getTaskId());
                break;
            default:
                // Remaining task states need no bookkeeping.
                break;
        }

        // Decisions below use the stage state sampled on entry, not the state after a possible fail().
        boolean schedulingDone = stageState == StageExecution.State.SCHEDULED
                || stageState == StageExecution.State.RUNNING
                || stageState == StageExecution.State.FLUSHING;
        if (!schedulingDone) {
            return;
        }
        if (taskState == TaskState.RUNNING) {
            stateMachine.transitionToRunning();
        }
        if (isFlushing()) {
            stateMachine.transitionToFlushing();
        }
        if (finishedTasks.containsAll(allTasks)) {
            stateMachine.transitionToFinished();
        }
    }

    /**
     * Task state listener: works out which completed driver groups in {@code taskStatus} have not
     * been seen before, hands them to the completed-lifespan listeners and records them as seen.
     *
     * @param taskStatus latest status reported by one of this stage's tasks
     */
    private synchronized void updateCompletedDriverGroups(TaskStatus taskStatus) {
        // Sets.difference is a live view; snapshot it because completedDriverGroups is mutated
        // below and the listeners are run through the executor.
        Set<Lifespan> newlyCompleted = ImmutableSet.copyOf(
                Sets.difference(taskStatus.getCompletedDriverGroups(), completedDriverGroups));
        if (!newlyCompleted.isEmpty()) {
            completedLifespansChangeListeners.invoke(newlyCompleted, executor);
            completedDriverGroups.addAll(newlyCompleted);
        }
    }

    /**
     * Replaces the error code of a failure with {@code REMOTE_HOST_GONE} when the failure names a
     * remote host and the failure detector reports that host as gone; otherwise returns the
     * failure unchanged.
     *
     * @param executionFailureInfo the failure reported by a task
     * @return the original failure, or a copy carrying the {@code REMOTE_HOST_GONE} error code
     */
    private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo) {
        if (executionFailureInfo.getRemoteHost() == null) {
            return executionFailureInfo;
        }
        if (failureDetector.getState(executionFailureInfo.getRemoteHost()) != FailureDetector.State.GONE) {
            return executionFailureInfo;
        }
        return new ExecutionFailureInfo(
                executionFailureInfo.getType(),
                executionFailureInfo.getMessage(),
                executionFailureInfo.getCause(),
                executionFailureInfo.getSuppressed(),
                executionFailureInfo.getStack(),
                executionFailureInfo.getErrorLocation(),
                StandardErrorCode.REMOTE_HOST_GONE.toErrorCode(),
                executionFailureInfo.getRemoteHost());
    }

    /**
     * Returns a new listener that forwards task-created and no-more-tasks notifications to this
     * execution's source-task handling.
     */
    @Override
    public TaskLifecycleListener getTaskLifecycleListener() {
        return new TaskLifecycleListener() {
            @Override
            public void taskCreated(PlanFragmentId fragmentId, RemoteTask task) {
                sourceTaskCreated(fragmentId, task);
            }

            @Override
            public void noMoreTasks(PlanFragmentId fragmentId) {
                noMoreSourceTasks(fragmentId);
            }
        };
    }

    /**
     * Registers a newly created task of a source fragment: records it, pushes the current output
     * buffers of that fragment to it, and adds an exchange split pointing at it to every existing
     * task of this stage.
     *
     * @param planFragmentId id of the source fragment the task belongs to; must be a known exchange source
     * @param remoteTask the new source task
     * @throws IllegalArgumentException if {@code planFragmentId} is not a known exchange source
     */
    private synchronized void sourceTaskCreated(PlanFragmentId planFragmentId, RemoteTask remoteTask) {
        Objects.requireNonNull(planFragmentId, "fragmentId is null");

        RemoteSourceNode remoteSource = exchangeSources.get(planFragmentId);
        Preconditions.checkArgument(remoteSource != null, "Unknown remote source %s. Known sources are %s", planFragmentId, exchangeSources.keySet());

        sourceTasks.put(planFragmentId, remoteTask);

        OutputBufferManager outputBufferManager = outputBufferManagers.get(planFragmentId);
        remoteTask.setOutputBuffers(outputBufferManager.getOutputBuffers());

        getAllTasks().forEach(destinationTask ->
                destinationTask.addSplits(ImmutableMultimap.of(remoteSource.getId(), createExchangeSplit(remoteTask, destinationTask))));
    }

    /**
     * Records that a source fragment will create no more tasks. Once all source fragments of the
     * corresponding remote source node are complete, the node is marked complete and every
     * existing task is told it will receive no more splits for it.
     *
     * @param planFragmentId id of the source fragment; must be a known exchange source
     * @throws IllegalArgumentException if {@code planFragmentId} is not a known exchange source
     */
    private synchronized void noMoreSourceTasks(PlanFragmentId planFragmentId) {
        RemoteSourceNode remoteSource = exchangeSources.get(planFragmentId);
        Preconditions.checkArgument(remoteSource != null, "Unknown remote source %s. Known sources are %s", planFragmentId, exchangeSources.keySet());

        completeSourceFragments.add(planFragmentId);

        // A remote source may read from several fragments; wait until all of them are complete.
        if (!completeSourceFragments.containsAll(remoteSource.getSourceFragmentIds())) {
            return;
        }

        completeSources.add(remoteSource.getId());
        for (RemoteTask task : getAllTasks()) {
            task.noMoreSplits(remoteSource.getId());
        }
    }

    /**
     * Applies {@code consumer} to the output buffer manager of every exchange source fragment and
     * then pushes that manager's output buffers to each known task of the fragment.
     *
     * @param consumer update to apply to each source fragment's output buffer manager
     */
    private synchronized void updateSourceTasksOutputBuffers(Consumer<OutputBufferManager> consumer) {
        for (PlanFragmentId sourceFragmentId : exchangeSources.keySet()) {
            OutputBufferManager manager = outputBufferManagers.get(sourceFragmentId);
            consumer.accept(manager);
            for (RemoteTask sourceTask : sourceTasks.get(sourceFragmentId)) {
                sourceTask.setOutputBuffers(manager.getOutputBuffers());
            }
        }
    }

    /** Returns an immutable snapshot of the tasks created so far. */
    @Override
    public List<RemoteTask> getAllTasks() {
        return ImmutableList.copyOf(tasks.values());
    }

    /** Returns the current status of every task created so far, as an immutable list. */
    @Override
    public List<TaskStatus> getTaskStatuses() {
        return getAllTasks().stream()
                .map(RemoteTask::getTaskStatus)
                .collect(ImmutableList.toImmutableList());
    }

    /** Returns true when at least one task reports its output buffer as overutilized. */
    @Override
    public boolean isAnyTaskBlocked() {
        for (TaskStatus status : getTaskStatuses()) {
            if (status.isOutputBufferOverutilized()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Forwards a get-split time measurement to the underlying stage.
     *
     * @param j the measurement, passed through unchanged
     */
    @Override
    public void recordGetSplitTime(long j) {
        stage.recordGetSplitTime(j);
    }

    /** Returns the id of the underlying stage. */
    @Override
    public StageId getStageId() {
        return stage.getStageId();
    }

    /** Returns the attempt number this execution was created with. */
    @Override
    public int getAttemptId() {
        return attempt;
    }

    /** Returns the plan fragment of the underlying stage. */
    @Override
    public PlanFragment getFragment() {
        return stage.getFragment();
    }

    /** Returns the failure cause recorded by the state machine, if any. */
    @Override
    public Optional<ExecutionFailureInfo> getFailureCause() {
        return stateMachine.getFailureCause();
    }

    /**
     * Builds the exchange split through which one task reads another task's output.
     * The location is the producing task's own URI with {@code results/<consumer partition id>} appended.
     *
     * @param remoteTask the producing (source) task
     * @param remoteTask2 the consuming (destination) task; its partition id selects the result buffer
     * @return a task-wide split on the remote connector pointing at that location
     */
    private static Split createExchangeSplit(RemoteTask remoteTask, RemoteTask remoteTask2) {
        TaskId sourceTaskId = remoteTask.getTaskId();
        String splitLocation = HttpUriBuilder.uriBuilderFrom(remoteTask.getTaskStatus().getSelf())
                .appendPath("results")
                .appendPath(String.valueOf(remoteTask2.getTaskId().getPartitionId()))
                .build()
                .toString();
        RemoteSplit.DirectExchangeInput exchangeInput = new RemoteSplit.DirectExchangeInput(sourceTaskId, splitLocation);
        return new Split(ExchangeOperator.REMOTE_CONNECTOR_ID, new RemoteSplit(exchangeInput), Lifespan.taskWide());
    }
}
