package io.trino.execution;

import com.google.common.base.Ticker;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.concurrent.Threads;
import io.airlift.slice.Slice;
import io.airlift.stats.CounterStat;
import io.airlift.stats.TestingGcMonitor;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.SessionTestUtils;
import io.trino.exchange.ExchangeManagerRegistry;
import io.trino.execution.DynamicFiltersCollector;
import io.trino.execution.buffer.BufferResult;
import io.trino.execution.buffer.BufferState;
import io.trino.execution.buffer.PagesSerde;
import io.trino.execution.buffer.PipelinedOutputBuffers;
import io.trino.execution.executor.TaskExecutor;
import io.trino.memory.MemoryPool;
import io.trino.memory.QueryContext;
import io.trino.metadata.ExchangeHandleResolver;
import io.trino.operator.TaskContext;
import io.trino.spi.QueryId;
import io.trino.spi.predicate.Domain;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.VarcharType;
import io.trino.spiller.SpillSpaceTracker;
import io.trino.testing.TestingSession;
import java.net.URI;
import java.util.Optional;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import org.assertj.core.api.Assertions;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.Test;

/**
 * Tests for {@link SqlTask}: task state transitions, output buffer
 * completion, and dynamic filter version reporting.
 *
 * <p>One set of executors is created per test-class instance in the
 * constructor and released in {@link #destroy()}. The class is marked
 * {@code singleThreaded} so TestNG does not run its methods concurrently.
 */
@Test(singleThreaded = true)
/* loaded from: input_file:io/trino/execution/TestSqlTask.class */
public class TestSqlTask {
    // Output buffer id (0) that the tests register on a task and read results from.
    public static final PipelinedOutputBuffers.OutputBufferId OUT = new PipelinedOutputBuffers.OutputBufferId(0);
    // Handed to the execution factory, each QueryContext, each TaskStateMachine and SqlTask.createSqlTask.
    private final ScheduledExecutorService taskNotificationExecutor;
    // Handed to each QueryContext built by createInitialTask().
    private final ScheduledExecutorService driverYieldExecutor;
    // Shared factory passed to every SqlTask created by createInitialTask().
    private final SqlTaskExecutionFactory sqlTaskExecutionFactory;
    // Incremented for every created task so each TaskId is distinct.
    private final AtomicInteger nextTaskId = new AtomicInteger();
    // Started in the constructor, stopped in destroy().
    // NOTE(review): the meaning of the numeric arguments is not visible here - see TaskExecutor.
    private final TaskExecutor taskExecutor = new TaskExecutor(8, 16, 3, 4, Ticker.systemTicker());

    /**
     * Starts the shared task executor, then creates the two scheduled thread
     * pools and the execution factory that all test methods reuse.
     */
    public TestSqlTask() {
        taskExecutor.start();
        taskNotificationExecutor = Executors.newScheduledThreadPool(10, Threads.threadsNamed("task-notification-%s"));
        driverYieldExecutor = Executors.newScheduledThreadPool(2, Threads.threadsNamed("driver-yield-%s"));
        sqlTaskExecutionFactory = new SqlTaskExecutionFactory(
                taskNotificationExecutor,
                taskExecutor,
                TaskTestUtils.createTestingPlanner(),
                TaskTestUtils.createTestSplitMonitor(),
                new TaskManagerConfig());
    }

    /**
     * Releases the executors created by the constructor.
     *
     * <p>The thread pools are shut down in a {@code finally} block so they are
     * released even when stopping the task executor throws. Both pools use
     * {@code shutdownNow()} so that queued or delayed tasks are discarded
     * instead of keeping worker threads alive after the test class finishes.
     */
    @AfterClass(alwaysRun = true)
    public void destroy() {
        try {
            this.taskExecutor.stop();
        } finally {
            this.taskNotificationExecutor.shutdownNow();
            this.driverYieldExecutor.shutdownNow();
        }
    }

    /**
     * A task that receives a plan but no splits reports RUNNING at version 0;
     * after a second update carrying an empty split set for the table scan
     * node (with the trailing flag set to {@code true}) it reports FINISHED.
     */
    @Test(timeOut = 30000)
    public void testEmptyQuery() throws Exception {
        SqlTask task = createInitialTask();
        PipelinedOutputBuffers outputBuffers = PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED)
                .withNoMoreBufferIds();

        // First update: plan only, no split assignments.
        TaskInfo afterPlanOnly = task.updateTask(
                SessionTestUtils.TEST_SESSION,
                Optional.of(TaskTestUtils.PLAN_FRAGMENT),
                ImmutableList.of(),
                outputBuffers,
                ImmutableMap.of());
        Assert.assertEquals(afterPlanOnly.getTaskStatus().getState(), TaskState.RUNNING);
        Assert.assertEquals(afterPlanOnly.getTaskStatus().getVersion(), 0L);

        TaskInfo polled = task.getTaskInfo();
        Assert.assertEquals(polled.getTaskStatus().getState(), TaskState.RUNNING);
        Assert.assertEquals(polled.getTaskStatus().getVersion(), 0L);

        // Second update: empty split set for the table scan node.
        // NOTE(review): the trailing `true` presumably means "no more splits" - confirm against SplitAssignment.
        SplitAssignment emptyFinalAssignment = new SplitAssignment(TaskTestUtils.TABLE_SCAN_NODE_ID, ImmutableSet.of(), true);
        TaskInfo afterSourceClosed = task.updateTask(
                SessionTestUtils.TEST_SESSION,
                Optional.of(TaskTestUtils.PLAN_FRAGMENT),
                ImmutableList.of(emptyFinalAssignment),
                outputBuffers,
                ImmutableMap.of());
        Assert.assertEquals(afterSourceClosed.getTaskStatus().getState(), TaskState.FINISHED);

        TaskInfo finalInfo = task.getTaskInfo(0L).get();
        Assert.assertEquals(finalInfo.getTaskStatus().getState(), TaskState.FINISHED);
    }

    /**
     * Runs a single split through a task, reads the one produced page from the
     * {@link #OUT} buffer, drains the buffer until it reports complete, and
     * checks that destroying the buffer takes the task to FINISHED.
     */
    @Test(timeOut = 30000)
    public void testSimpleQuery() throws Exception {
        SqlTask task = createInitialTask();
        Assert.assertEquals(task.getTaskStatus().getState(), TaskState.RUNNING);
        Assert.assertEquals(task.getTaskStatus().getVersion(), 0L);

        SplitAssignment singleSplit = new SplitAssignment(TaskTestUtils.TABLE_SCAN_NODE_ID, ImmutableSet.of(TaskTestUtils.SPLIT), true);
        task.updateTask(
                SessionTestUtils.TEST_SESSION,
                Optional.of(TaskTestUtils.PLAN_FRAGMENT),
                ImmutableList.of(singleSplit),
                PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds(),
                ImmutableMap.of());

        // Waiting past version 0 yields the task in FLUSHING at version 1.
        TaskInfo flushing = task.getTaskInfo(0L).get();
        Assert.assertEquals(flushing.getTaskStatus().getState(), TaskState.FLUSHING);
        Assert.assertEquals(flushing.getTaskStatus().getVersion(), 1L);
        Assert.assertTrue(task.getTaskInfo(0L).isDone());

        BufferResult page = task.getTaskResults(OUT, 0L, DataSize.of(1L, DataSize.Unit.MEGABYTE)).get();
        Assert.assertFalse(page.isBufferComplete());
        Assert.assertEquals(page.getSerializedPages().size(), 1);
        Slice firstPage = page.getSerializedPages().get(0);
        Assert.assertEquals(PagesSerde.getSerializedPagePositionCount(firstPage), 1);

        // Keep requesting from the next token until the buffer says it is complete.
        do {
            long nextToken = page.getToken() + page.getSerializedPages().size();
            page = task.getTaskResults(OUT, nextToken, DataSize.of(1L, DataSize.Unit.MEGABYTE)).get();
        } while (!page.isBufferComplete());
        Assert.assertEquals(page.getSerializedPages().size(), 0);

        TaskInfo afterDestroy = task.destroyTaskResults(OUT);
        Assert.assertEquals(afterDestroy.getOutputBuffers().getState(), BufferState.FINISHED);

        TaskInfo finished = task.getTaskInfo(afterDestroy.getTaskStatus().getVersion()).get();
        Assert.assertEquals(finished.getTaskStatus().getState(), TaskState.FINISHED);

        // The test expects that asking for a version far ahead (100) completes immediately once the task is done.
        Assert.assertTrue(task.getTaskInfo(100L).isDone());
        Assert.assertEquals(task.getTaskInfo().getTaskStatus().getState(), TaskState.FINISHED);
    }

    /**
     * A running task has no end time; after {@code cancel()} both the returned
     * info and a fresh {@code getTaskInfo()} report CANCELED with an end time set.
     */
    @Test
    public void testCancel() {
        SqlTask task = createInitialTask();

        TaskInfo started = task.updateTask(
                SessionTestUtils.TEST_SESSION,
                Optional.of(TaskTestUtils.PLAN_FRAGMENT),
                ImmutableList.of(),
                PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds(),
                ImmutableMap.of());
        Assert.assertEquals(started.getTaskStatus().getState(), TaskState.RUNNING);
        Assert.assertNull(started.getStats().getEndTime());

        TaskInfo beforeCancel = task.getTaskInfo();
        Assert.assertEquals(beforeCancel.getTaskStatus().getState(), TaskState.RUNNING);
        Assert.assertNull(beforeCancel.getStats().getEndTime());

        TaskInfo cancelResult = task.cancel();
        Assert.assertEquals(cancelResult.getTaskStatus().getState(), TaskState.CANCELED);
        Assert.assertNotNull(cancelResult.getStats().getEndTime());

        TaskInfo afterCancel = task.getTaskInfo();
        Assert.assertEquals(afterCancel.getTaskStatus().getState(), TaskState.CANCELED);
        Assert.assertNotNull(afterCancel.getStats().getEndTime());
    }

    /**
     * Runs a single split until the task is FLUSHING, destroys the {@link #OUT}
     * buffer without reading it, and checks that the task ends up FINISHED.
     *
     * <p>NOTE(review): despite the name, no abort call is made here; the task is
     * completed by destroying its output buffer.
     */
    @Test(timeOut = 30000)
    public void testAbort() throws Exception {
        SqlTask task = createInitialTask();
        Assert.assertEquals(task.getTaskStatus().getState(), TaskState.RUNNING);
        Assert.assertEquals(task.getTaskStatus().getVersion(), 0L);

        SplitAssignment singleSplit = new SplitAssignment(TaskTestUtils.TABLE_SCAN_NODE_ID, ImmutableSet.of(TaskTestUtils.SPLIT), true);
        task.updateTask(
                SessionTestUtils.TEST_SESSION,
                Optional.of(TaskTestUtils.PLAN_FRAGMENT),
                ImmutableList.of(singleSplit),
                PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds(),
                ImmutableMap.of());

        TaskInfo flushing = task.getTaskInfo(0L).get();
        Assert.assertEquals(flushing.getTaskStatus().getState(), TaskState.FLUSHING);
        Assert.assertEquals(flushing.getTaskStatus().getVersion(), 1L);

        task.destroyTaskResults(OUT);

        TaskInfo finished = task.getTaskInfo(flushing.getTaskStatus().getVersion()).get();
        Assert.assertEquals(finished.getTaskStatus().getState(), TaskState.FINISHED);
        Assert.assertEquals(task.getTaskInfo().getTaskStatus().getState(), TaskState.FINISHED);
    }

    /**
     * A read issued before any data exists stays pending; once the source is
     * closed with an empty split set and the {@link #OUT} buffer is destroyed,
     * that read completes within a second and a new read returns immediately
     * with the buffer marked complete.
     */
    @Test
    public void testBufferCloseOnFinish() throws Exception {
        SqlTask task = createInitialTask();
        PipelinedOutputBuffers outputBuffers = PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED)
                .withBuffer(OUT, 0)
                .withNoMoreBufferIds();
        TaskTestUtils.updateTask(task, TaskTestUtils.EMPTY_SPLIT_ASSIGNMENTS, outputBuffers);

        // Parameterized instead of raw so the result needs no unchecked cast.
        ListenableFuture<BufferResult> pendingRead = task.getTaskResults(OUT, 0L, DataSize.of(1L, DataSize.Unit.MEGABYTE));
        Assert.assertFalse(pendingRead.isDone());

        TaskTestUtils.updateTask(
                task,
                ImmutableList.of(new SplitAssignment(TaskTestUtils.TABLE_SCAN_NODE_ID, ImmutableSet.of(), true)),
                outputBuffers);
        task.destroyTaskResults(OUT);

        // Bounded wait: throws TimeoutException if the pending read was not released.
        pendingRead.get(1L, TimeUnit.SECONDS);

        ListenableFuture<BufferResult> readAfterClose = task.getTaskResults(OUT, 0L, DataSize.of(1L, DataSize.Unit.MEGABYTE));
        Assert.assertTrue(readAfterClose.isDone());
        Assert.assertTrue(readAfterClose.get().isBufferComplete());
    }

    /**
     * A read that is pending when the task is canceled completes within a
     * second, and a subsequent read returns immediately with the buffer
     * marked complete.
     */
    @Test
    public void testBufferCloseOnCancel() throws Exception {
        SqlTask task = createInitialTask();
        TaskTestUtils.updateTask(
                task,
                TaskTestUtils.EMPTY_SPLIT_ASSIGNMENTS,
                PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds());

        // Parameterized instead of raw so the result needs no unchecked cast.
        ListenableFuture<BufferResult> pendingRead = task.getTaskResults(OUT, 0L, DataSize.of(1L, DataSize.Unit.MEGABYTE));
        Assert.assertFalse(pendingRead.isDone());

        task.cancel();
        Assert.assertEquals(task.getTaskInfo().getTaskStatus().getState(), TaskState.CANCELED);

        // Bounded wait: throws TimeoutException if the pending read was not released.
        pendingRead.get(1L, TimeUnit.SECONDS);

        ListenableFuture<BufferResult> readAfterCancel = task.getTaskResults(OUT, 0L, DataSize.of(1L, DataSize.Unit.MEGABYTE));
        Assert.assertTrue(readAfterCancel.isDone());
        Assert.assertTrue(readAfterCancel.get().isBufferComplete());
    }

    /**
     * Failing a task must not complete outstanding or new buffer reads: the
     * read issued before the failure still times out after one second, and a
     * read issued afterwards is not done.
     */
    @Test(timeOut = 30000)
    public void testBufferNotCloseOnFail() throws Exception {
        SqlTask task = createInitialTask();
        TaskTestUtils.updateTask(
                task,
                TaskTestUtils.EMPTY_SPLIT_ASSIGNMENTS,
                PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds());

        // Parameterized instead of raw; must stay effectively final for the lambda below.
        ListenableFuture<BufferResult> pendingRead = task.getTaskResults(OUT, 0L, DataSize.of(1L, DataSize.Unit.MEGABYTE));
        Assert.assertFalse(pendingRead.isDone());

        long versionBeforeFailure = task.getTaskInfo().getTaskStatus().getVersion();
        task.failed(new Exception("test"));
        TaskInfo failedInfo = task.getTaskInfo(versionBeforeFailure).get();
        Assert.assertEquals(failedInfo.getTaskStatus().getState(), TaskState.FAILED);

        // The pre-failure read is expected to remain blocked, so the bounded get() times out.
        Assertions.assertThatThrownBy(() -> pendingRead.get(1L, TimeUnit.SECONDS))
                .isInstanceOf(TimeoutException.class)
                .hasMessageContaining("Waited 1 seconds");

        Assert.assertFalse(task.getTaskResults(OUT, 0L, DataSize.of(1L, DataSize.Unit.MEGABYTE)).isDone());
    }

    /**
     * Publishing a dynamic filter domain through the task context bumps both
     * the task status version and the dynamic filters version from 0 to 1 and
     * releases a status request that was waiting past version 0.
     */
    @Test(timeOut = 30000)
    public void testDynamicFilters() throws Exception {
        SqlTask task = createInitialTask();
        task.updateTask(
                SessionTestUtils.TEST_SESSION,
                Optional.of(TaskTestUtils.PLAN_FRAGMENT_WITH_DYNAMIC_FILTER_SOURCE),
                ImmutableList.of(new SplitAssignment(TaskTestUtils.TABLE_SCAN_NODE_ID, ImmutableSet.of(TaskTestUtils.SPLIT), false)),
                PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds(),
                ImmutableMap.of());
        Assert.assertEquals(task.getTaskStatus().getDynamicFiltersVersion(), 0L);

        TaskContext taskContext = task.getQueryContext().getTaskContextByTaskId(task.getTaskId());

        // Parameterized instead of raw; requested before the domain update so it must still be pending.
        ListenableFuture<TaskStatus> pendingStatus = task.getTaskStatus(0L);
        Assert.assertFalse(pendingStatus.isDone());

        taskContext.updateDomains(ImmutableMap.of(TaskTestUtils.DYNAMIC_FILTER_SOURCE_ID, Domain.none(BigintType.BIGINT)));
        Assert.assertEquals(task.getTaskStatus().getVersion(), 1L);
        Assert.assertEquals(task.getTaskStatus().getDynamicFiltersVersion(), 1L);

        // Blocks until the waiting status request completes; bounded by the test timeOut.
        pendingStatus.get();
    }

    /**
     * After a task with a dynamic filter source finishes without processing any
     * splits, the dynamic filter domains can still be fetched: version 1 with
     * {@code Domain.none(VARCHAR)} for the filter source id.
     */
    @Test(timeOut = 30000)
    public void testDynamicFilterFetchAfterTaskDone() throws Exception {
        SqlTask task = createInitialTask();
        PipelinedOutputBuffers outputBuffers = PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED)
                .withBuffer(OUT, 0)
                .withNoMoreBufferIds();

        task.updateTask(
                SessionTestUtils.TEST_SESSION,
                Optional.of(TaskTestUtils.PLAN_FRAGMENT_WITH_DYNAMIC_FILTER_SOURCE),
                ImmutableList.of(new SplitAssignment(TaskTestUtils.TABLE_SCAN_NODE_ID, ImmutableSet.of(), false)),
                outputBuffers,
                ImmutableMap.of());
        Assert.assertEquals(task.getTaskStatus().getDynamicFiltersVersion(), 0L);

        // Second update repeats the empty split set, this time with the trailing flag set to true.
        TaskTestUtils.updateTask(
                task,
                ImmutableList.of(new SplitAssignment(TaskTestUtils.TABLE_SCAN_NODE_ID, ImmutableSet.of(), true)),
                outputBuffers);

        TaskInfo afterDestroy = task.destroyTaskResults(OUT);
        Assert.assertEquals(afterDestroy.getOutputBuffers().getState(), BufferState.FINISHED);

        // Retried for up to 10 seconds until the assertions inside pass.
        io.trino.testing.assertions.Assert.assertEventually(new Duration(10.0d, TimeUnit.SECONDS), () -> {
            TaskStatus status = task.getTaskStatus(afterDestroy.getTaskStatus().getVersion()).get();
            Assert.assertEquals(status.getState(), TaskState.FINISHED);
            Assert.assertEquals(status.getDynamicFiltersVersion(), 1L);
        });

        DynamicFiltersCollector.VersionedDynamicFilterDomains fetched = task.acknowledgeAndGetNewDynamicFilterDomains(0L);
        Assert.assertEquals(fetched.getVersion(), 1L);
        Assert.assertEquals(
                fetched.getDynamicFilterDomains(),
                ImmutableMap.of(TaskTestUtils.DYNAMIC_FILTER_SOURCE_ID, Domain.none(VarcharType.VARCHAR)));
    }

    /**
     * Builds a fresh {@link SqlTask} for a test.
     *
     * <p>Each call uses the next value of {@code nextTaskId} for the task id,
     * creates a new {@code QueryContext} for query id {@code "query"},
     * registers a task context on it, and then creates the task with no-op
     * callbacks.
     *
     * @return the newly created task
     */
    private SqlTask createInitialTask() {
        TaskId id = new TaskId(new StageId("query", 0), nextTaskId.incrementAndGet(), 0);
        URI location = URI.create("fake://task/" + id);

        QueryContext context = new QueryContext(
                new QueryId("query"),
                DataSize.of(1L, DataSize.Unit.MEGABYTE),
                new MemoryPool(DataSize.of(1L, DataSize.Unit.GIGABYTE)),
                new TestingGcMonitor(),
                taskNotificationExecutor,
                driverYieldExecutor,
                DataSize.of(1L, DataSize.Unit.MEGABYTE),
                new SpillSpaceTracker(DataSize.of(1L, DataSize.Unit.GIGABYTE)));

        // The returned task context is intentionally not kept; the callback does nothing.
        context.addTaskContext(
                new TaskStateMachine(id, taskNotificationExecutor),
                TestingSession.testSessionBuilder().build(),
                () -> { },
                false,
                false);

        return SqlTask.createSqlTask(
                id,
                location,
                "fake",
                context,
                sqlTaskExecutionFactory,
                taskNotificationExecutor,
                ignoredTask -> { },
                DataSize.of(32L, DataSize.Unit.MEGABYTE),
                DataSize.of(200L, DataSize.Unit.MEGABYTE),
                new ExchangeManagerRegistry(new ExchangeHandleResolver()),
                new CounterStat());
    }
}
