package io.trino.execution.executor.dedicated;

import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.AbstractFuture;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.testing.TestingTicker;
import io.airlift.tracing.Tracing;
import io.airlift.units.Duration;
import io.opentelemetry.api.trace.Span;
import io.trino.execution.SplitRunner;
import io.trino.execution.StageId;
import io.trino.execution.TaskId;
import io.trino.execution.TaskManagerConfig;
import io.trino.execution.executor.TaskHandle;
import io.trino.execution.executor.scheduler.FairScheduler;
import io.trino.version.EmbedVersion;
import java.util.List;
import java.util.OptionalInt;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Phaser;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

/* loaded from: input_file:io/trino/execution/executor/dedicated/TestThreadPerDriverTaskExecutor.class */
/**
 * Tests for {@link ThreadPerDriverTaskExecutor} covering three scenarios: removing a task while
 * one of its splits is inside {@code processFor}, resuming a split after the future it blocked on
 * completes, and running a split's next invocation after the scheduler ticker has been advanced
 * past a quantum.
 */
public class TestThreadPerDriverTaskExecutor {

    /**
     * A settable future that lets a test wait until a listener has been registered on it (in
     * these tests, presumably by the executor while the split is blocked on this future).
     */
    private static class TestFuture extends AbstractFuture<Void> {
        // Released the first time a listener is added.
        private final CountDownLatch listenerAdded = new CountDownLatch(1);

        private TestFuture() {
        }

        public void addListener(Runnable listener, Executor executor) {
            super.addListener(listener, executor);
            listenerAdded.countDown();
        }

        /** Exposes {@code AbstractFuture#set} publicly so a test can complete this future. */
        public boolean set(Void value) {
            return super.set(value);
        }

        /** Blocks until at least one listener has been added to this future. */
        public void awaitListenerAdded() throws InterruptedException {
            listenerAdded.await();
        }
    }

    /**
     * A {@link SplitRunner} that replays a fixed script: the i-th call to {@code processFor}
     * applies the i-th function of {@code invocations} and returns its future. The runner
     * reports itself finished after the last scripted call, or as soon as {@link #close()} is
     * called; closing also interrupts the thread currently inside {@code processFor}, if any.
     */
    private static class TestingSplitRunner implements SplitRunner {
        private final List<Function<Duration, ListenableFuture<Void>>> invocations;
        // Index of the next scripted invocation to run.
        private int invocation;
        private volatile boolean finished;
        // Thread currently executing processFor, or null; read by close() to interrupt it.
        private volatile Thread runnerThread;

        public TestingSplitRunner(List<Function<Duration, ListenableFuture<Void>>> invocations) {
            this.invocations = invocations;
        }

        public final int getPipelineId() {
            return 0;
        }

        public final Span getPipelineSpan() {
            return Span.getInvalid();
        }

        public final boolean isFinished() {
            return finished;
        }

        /**
         * Runs the next scripted invocation with {@code duration} and returns its future. The
         * runner is marked finished once the last scripted invocation has completed.
         */
        public final ListenableFuture<Void> processFor(Duration duration) {
            // Publish the calling thread so close() can interrupt a long-running invocation.
            runnerThread = Thread.currentThread();
            try {
                ListenableFuture<Void> result = invocations.get(invocation).apply(duration);
                invocation++;
                if (invocation == invocations.size()) {
                    finished = true;
                }
                return result;
            } finally {
                runnerThread = null;
            }
        }

        public final String getInfo() {
            return "";
        }

        /** Marks the runner finished and interrupts the thread inside {@code processFor}, if any. */
        public final void close() {
            finished = true;
            Thread thread = runnerThread;
            if (thread != null) {
                thread.interrupt();
            }
        }
    }

    /**
     * Adds a task with a fixed {@link TaskId} and the {@code addTask} arguments shared by every
     * test in this class.
     */
    private static TaskHandle addTestTask(ThreadPerDriverTaskExecutor executor) {
        TaskId taskId = new TaskId(new StageId("query", 1), 1, 1);
        return executor.addTask(taskId, () -> 0.0, 10, new Duration(1.0, TimeUnit.MILLISECONDS), OptionalInt.empty());
    }

    /**
     * Removes a task while its only split is blocked inside {@code processFor} and expects the
     * split to be finished and its completion future to resolve.
     */
    @Timeout(10)
    @Test
    public void testCancellationWhileProcessing() throws ExecutionException, InterruptedException {
        ThreadPerDriverTaskExecutor executor = new ThreadPerDriverTaskExecutor(
                new TaskManagerConfig(), Tracing.noopTracer(), EmbedVersion.testingVersionEmbedder());
        executor.start();
        try {
            TaskHandle task = addTestTask(executor);

            CountDownLatch started = new CountDownLatch(1);
            TestingSplitRunner split = new TestingSplitRunner(ImmutableList.of(duration -> {
                started.countDown();
                try {
                    // A thread joining itself never completes on its own: this blocks until the
                    // thread is interrupted (presumably via TestingSplitRunner.close() when the
                    // task is removed).
                    Thread.currentThread().join();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                return Futures.immediateVoidFuture();
            }));

            ListenableFuture<Void> splitDone = executor.enqueueSplits(task, false, ImmutableList.of(split)).get(0);

            started.await();
            executor.removeTask(task);

            splitDone.get();
            Assertions.assertThat(split.isFinished()).isTrue();
        } finally {
            executor.stop();
        }
    }

    /**
     * Runs a split whose first invocation returns a pending future and checks that the split
     * runs its second invocation and finishes once that future is completed.
     */
    @Timeout(10)
    @Test
    public void testBlocking() throws ExecutionException, InterruptedException {
        ThreadPerDriverTaskExecutor executor = new ThreadPerDriverTaskExecutor(
                new TaskManagerConfig(), Tracing.noopTracer(), EmbedVersion.testingVersionEmbedder());
        executor.start();
        try {
            TaskHandle task = addTestTask(executor);

            TestFuture blocked = new TestFuture();
            TestingSplitRunner split = new TestingSplitRunner(ImmutableList.of(
                    duration -> blocked,
                    duration -> Futures.immediateVoidFuture()));

            ListenableFuture<Void> splitDone = executor.enqueueSplits(task, false, ImmutableList.of(split)).get(0);

            // Complete the future only after a listener is registered, so the completion is
            // observed through that listener rather than before anyone is waiting on it.
            blocked.awaitListenerAdded();
            blocked.set(null);

            splitDone.get();
            Assertions.assertThat(split.isFinished()).isTrue();
        } finally {
            executor.stop();
        }
    }

    /**
     * Advances the scheduler's ticker past {@code 2 * FairScheduler.QUANTUM_NANOS} while the
     * split's first invocation is running, then checks that the second invocation still runs
     * and the split finishes.
     */
    @Timeout(10)
    @Test
    public void testYielding() throws ExecutionException, InterruptedException {
        TestingTicker ticker = new TestingTicker();
        ThreadPerDriverTaskExecutor executor = new ThreadPerDriverTaskExecutor(
                Tracing.noopTracer(), EmbedVersion.testingVersionEmbedder(), new FairScheduler(1, "Runner-%d", ticker));
        executor.start();
        try {
            TaskHandle task = addTestTask(executor);

            // Two parties: the test thread and whichever thread runs the split.
            Phaser phaser = new Phaser(2);
            TestingSplitRunner split = new TestingSplitRunner(ImmutableList.of(
                    duration -> {
                        phaser.arriveAndAwaitAdvance();
                        phaser.arriveAndAwaitAdvance();
                        return Futures.immediateVoidFuture();
                    },
                    duration -> {
                        phaser.arriveAndAwaitAdvance();
                        return Futures.immediateVoidFuture();
                    }));

            ListenableFuture<Void> splitDone = executor.enqueueSplits(task, false, ImmutableList.of(split)).get(0);

            // Barrier 1: the first invocation has started.
            phaser.arriveAndAwaitAdvance();
            // While the first invocation is parked, make it appear to have run for two quanta.
            ticker.increment(FairScheduler.QUANTUM_NANOS * 2, TimeUnit.NANOSECONDS);
            // Barrier 2: release the first invocation so it returns.
            phaser.arriveAndAwaitAdvance();
            // Barrier 3: the second invocation has run; the phaser has now advanced three times.
            Assertions.assertThat(phaser.arriveAndAwaitAdvance()).isEqualTo(3);

            splitDone.get();
            Assertions.assertThat(split.isFinished()).isTrue();
        } finally {
            executor.stop();
        }
    }
}
