package io.trino.execution;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.base.Ticker;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.concurrent.Threads;
import io.airlift.slice.Slice;
import io.airlift.stats.TestingGcMonitor;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.SessionTestUtils;
import io.trino.block.BlockAssertions;
import io.trino.execution.buffer.BufferResult;
import io.trino.execution.buffer.BufferState;
import io.trino.execution.buffer.OutputBuffer;
import io.trino.execution.buffer.OutputBufferStateMachine;
import io.trino.execution.buffer.PagesSerde;
import io.trino.execution.buffer.PagesSerdeFactory;
import io.trino.execution.buffer.PartitionedOutputBuffer;
import io.trino.execution.buffer.PipelinedOutputBuffers;
import io.trino.execution.executor.TaskExecutor;
import io.trino.memory.MemoryPool;
import io.trino.memory.QueryContext;
import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.memory.context.SimpleLocalMemoryContext;
import io.trino.metadata.Split;
import io.trino.operator.DriverContext;
import io.trino.operator.DriverFactory;
import io.trino.operator.OperatorContext;
import io.trino.operator.SourceOperator;
import io.trino.operator.SourceOperatorFactory;
import io.trino.operator.TaskContext;
import io.trino.operator.output.TaskOutputOperator;
import io.trino.spi.HostAddress;
import io.trino.spi.Page;
import io.trino.spi.QueryId;
import io.trino.spi.block.Block;
import io.trino.spi.block.TestingBlockEncodingSerde;
import io.trino.spi.connector.ConnectorSplit;
import io.trino.spi.connector.UpdatablePageSource;
import io.trino.spiller.SpillSpaceTracker;
import io.trino.sql.planner.LocalExecutionPlanner;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.TestingHandles;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.function.Supplier;
import org.openjdk.jol.info.ClassLayout;
import org.testng.Assert;
import org.testng.annotations.Test;

@Test(singleThreaded = true)
public class TestSqlTaskExecution {
    private static final PipelinedOutputBuffers.OutputBufferId OUTPUT_BUFFER_ID = new PipelinedOutputBuffers.OutputBufferId(0);
    private static final Duration ASSERT_WAIT_TIMEOUT = new Duration(1.0d, TimeUnit.HOURS);
    // Maximum amount of data requested from the output buffer in a single poll.
    private static final DataSize MAX_BUFFER_READ_SIZE = DataSize.of(1L, DataSize.Unit.MEGABYTE);
    public static final TaskId TASK_ID = new TaskId(new StageId("query", 0), 0, 0);

    /**
     * Converts a relative timeout into an absolute deadline on the {@link System#nanoTime()} clock.
     * {@link TimeUnit#toNanos(long)} saturates instead of silently overflowing.
     */
    private static long deadlineNanos(Duration timeout) {
        return System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeout.toMillis());
    }

    /**
     * Reads serialized pages from one partition of an {@link OutputBuffer}.
     * It tracks how many positions have been received compared with how many the test has asked for.
     */
    private static class OutputBufferConsumer {
        private final OutputBuffer outputBuffer;
        private final PipelinedOutputBuffers.OutputBufferId outputBufferId;
        // Sequence id (token) passed to the next OutputBuffer#get call.
        private int sequenceId;
        // Positions received but not yet claimed by consume(); negative while more are still needed.
        private int surplusPositions;
        // Last value of BufferResult#isBufferComplete observed.
        private boolean bufferComplete;

        public OutputBufferConsumer(OutputBuffer outputBuffer, PipelinedOutputBuffers.OutputBufferId outputBufferId) {
            this.outputBuffer = outputBuffer;
            this.outputBufferId = outputBufferId;
        }

        /**
         * Polls the buffer until at least {@code positionCount} additional positions have been received.
         *
         * @param positionCount number of positions the caller expects to arrive
         * @param timeout overall time budget for all polls
         * @throws TimeoutException if the positions do not arrive before the deadline
         */
        public void consume(int positionCount, Duration timeout) throws ExecutionException, InterruptedException, TimeoutException {
            long deadline = deadlineNanos(timeout);
            surplusPositions -= positionCount;
            while (surplusPositions < 0) {
                Assert.assertFalse(bufferComplete, "bufferComplete is set before enough positions are consumed");
                BufferResult result = fetchNext(deadline);
                for (Slice page : result.getSerializedPages()) {
                    surplusPositions += PagesSerde.getSerializedPagePositionCount(page);
                }
            }
        }

        /**
         * Asserts that every consumed position has been claimed.
         * It then keeps polling until the buffer reports completion, asserting that any further pages are empty.
         */
        public void assertBufferComplete(Duration timeout) throws InterruptedException, ExecutionException, TimeoutException {
            Assert.assertEquals(surplusPositions, 0);
            long deadline = deadlineNanos(timeout);
            while (!bufferComplete) {
                BufferResult result = fetchNext(deadline);
                for (Slice page : result.getSerializedPages()) {
                    Assert.assertEquals(PagesSerde.getSerializedPagePositionCount(page), 0);
                }
            }
        }

        /** Destroys this consumer's buffer and asserts that the output buffer then reports FINISHED. */
        public void abort() {
            outputBuffer.destroy(outputBufferId);
            Assert.assertEquals(outputBuffer.getInfo().getState(), BufferState.FINISHED);
        }

        /**
         * Fetches the next batch starting at {@link #sequenceId}, waiting no later than {@code deadline}.
         * It records the completion flag and advances the sequence id past the returned pages.
         */
        private BufferResult fetchNext(long deadline) throws InterruptedException, ExecutionException, TimeoutException {
            BufferResult result = outputBuffer.get(outputBufferId, sequenceId, MAX_BUFFER_READ_SIZE)
                    .get(deadline - System.nanoTime(), TimeUnit.NANOSECONDS);
            bufferComplete = result.isBufferComplete();
            sequenceId += result.getSerializedPages().size();
            return result;
        }
    }

    /**
     * A gate that makes {@link #await()} block between {@link #pause()} and {@link #resume()}.
     * It starts in the resumed (open) state.
     */
    public static class Pauser {
        private volatile SettableFuture<Void> future = SettableFuture.create();

        public Pauser() {
            // Start open: await() returns immediately until pause() is called.
            future.set(null);
        }

        /** Closes the gate if it is currently open; a no-op when already paused. */
        public void pause() {
            if (future.isDone()) {
                future = SettableFuture.create();
            }
        }

        /** Opens the gate, releasing any threads blocked in {@link #await()}; a no-op when already open. */
        public void resume() {
            if (!future.isDone()) {
                future.set(null);
            }
        }

        /**
         * Blocks until the gate is open.
         *
         * @throws RuntimeException wrapping an interruption (with the interrupt flag restored) or a failed future
         */
        public void await() {
            try {
                future.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } catch (ExecutionException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /** A source operator factory whose operators emit one page per assigned {@link TestingSplit}. */
    public static class TestingScanOperatorFactory implements SourceOperatorFactory {
        private final int operatorId;
        private final PlanNodeId sourceId;
        // Shared by all operators created by this factory; getOutput() waits on it.
        private final Pauser pauser = new Pauser();
        private boolean overallNoMoreOperators;

        /**
         * A source operator that accepts at most one split.
         * It stays blocked until a split or noMoreSplits() arrives, then emits a single page and finishes.
         */
        public class TestingScanOperator implements SourceOperator {
            private final OperatorContext operatorContext;
            private final PlanNodeId planNodeId;
            // Completed once a split is added or noMoreSplits() is called.
            private final SettableFuture<Void> blocked = SettableFuture.create();
            private TestingSplit split;
            private boolean finished;

            public TestingScanOperator(OperatorContext operatorContext, PlanNodeId planNodeId) {
                this.operatorContext = Objects.requireNonNull(operatorContext, "operatorContext is null");
                this.planNodeId = Objects.requireNonNull(planNodeId, "planNodeId is null");
            }

            @Override
            public OperatorContext getOperatorContext() {
                return operatorContext;
            }

            @Override
            public PlanNodeId getSourceId() {
                return planNodeId;
            }

            @Override
            public Supplier<Optional<UpdatablePageSource>> addSplit(Split split) {
                Objects.requireNonNull(split, "split is null");
                Preconditions.checkState(this.split == null, "Table scan split already set");
                if (finished) {
                    return Optional::empty;
                }
                this.split = (TestingSplit) split.getConnectorSplit();
                blocked.set(null);
                return Optional::empty;
            }

            @Override
            public void noMoreSplits() {
                // With no split assigned there is nothing to produce, so finish right away.
                if (split == null) {
                    finish();
                }
                blocked.set(null);
            }

            @Override
            public void close() {
                finish();
            }

            @Override
            public void finish() {
                finished = true;
            }

            @Override
            public boolean isFinished() {
                return finished;
            }

            @Override
            public ListenableFuture<Void> isBlocked() {
                return blocked;
            }

            @Override
            public boolean needsInput() {
                return false;
            }

            @Override
            public void addInput(Page page) {
                throw new UnsupportedOperationException(getClass().getName() + " cannot take input");
            }

            /**
             * Returns null until a split is assigned.
             * Otherwise it waits on the factory's {@link Pauser}, then returns one page built by
             * {@code BlockAssertions.createStringSequenceBlock(begin, end)} and finishes.
             */
            @Override
            public Page getOutput() {
                if (split == null) {
                    return null;
                }
                pauser.await();
                Page page = new Page(BlockAssertions.createStringSequenceBlock(split.getBegin(), split.getEnd()));
                finish();
                return page;
            }
        }

        public TestingScanOperatorFactory(int operatorId, PlanNodeId sourceId) {
            this.operatorId = operatorId;
            this.sourceId = Objects.requireNonNull(sourceId, "sourceId is null");
        }

        @Override
        public PlanNodeId getSourceId() {
            return sourceId;
        }

        @Override
        public SourceOperator createOperator(DriverContext driverContext) {
            Preconditions.checkState(!overallNoMoreOperators, "noMoreOperators() has been called");
            OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, sourceId, TestingScanOperator.class.getSimpleName());
            return new TestingScanOperator(operatorContext, sourceId);
        }

        @Override
        public void noMoreOperators() {
            overallNoMoreOperators = true;
        }

        /** Returns whether {@link #noMoreOperators()} has been called on this factory. */
        public boolean isOverallNoMoreOperators() {
            return overallNoMoreOperators;
        }

        public Pauser getPauser() {
            return pauser;
        }
    }

    /** A connector split describing the half-open value range {@code [begin, end)} for the testing scan operator. */
    public static class TestingSplit implements ConnectorSplit {
        private static final int INSTANCE_SIZE = Math.toIntExact(ClassLayout.parseClass(TestingSplit.class).instanceSize());
        private final int begin;
        private final int end;

        @JsonCreator
        public TestingSplit(@JsonProperty("begin") int begin, @JsonProperty("end") int end) {
            this.begin = begin;
            this.end = end;
        }

        @Override
        public boolean isRemotelyAccessible() {
            return true;
        }

        @Override
        public List<HostAddress> getAddresses() {
            return ImmutableList.of();
        }

        @Override
        public Object getInfo() {
            return this;
        }

        @Override
        public long getRetainedSizeInBytes() {
            return INSTANCE_SIZE;
        }

        public int getBegin() {
            return begin;
        }

        public int getEnd() {
            return end;
        }
    }

    @Test
    public void testSimple() throws Exception {
        ScheduledExecutorService taskNotificationExecutor = Executors.newScheduledThreadPool(10, Threads.threadsNamed("task-notification-%s"));
        ScheduledExecutorService driverYieldExecutor = Executors.newScheduledThreadPool(2, Threads.threadsNamed("driver-yield-%s"));
        TaskExecutor taskExecutor = new TaskExecutor(5, 10, 3, 4, Ticker.systemTicker());
        taskExecutor.start();
        try {
            TaskStateMachine taskStateMachine = new TaskStateMachine(TASK_ID, taskNotificationExecutor);
            PartitionedOutputBuffer outputBuffer = newTestingOutputBuffer(taskNotificationExecutor);
            OutputBufferConsumer outputBufferConsumer = new OutputBufferConsumer(outputBuffer, OUTPUT_BUFFER_ID);

            // Single pipeline: testing scan operator -> task output operator writing into outputBuffer.
            TestingScanOperatorFactory scanOperatorFactory = new TestingScanOperatorFactory(0, TaskTestUtils.TABLE_SCAN_NODE_ID);
            LocalExecutionPlanner.LocalExecutionPlan localExecutionPlan = new LocalExecutionPlanner.LocalExecutionPlan(
                    ImmutableList.of(new DriverFactory(
                            0,
                            true,
                            true,
                            ImmutableList.of(
                                    scanOperatorFactory,
                                    new TaskOutputOperator.TaskOutputOperatorFactory(
                                            1,
                                            TaskTestUtils.TABLE_SCAN_NODE_ID,
                                            outputBuffer,
                                            Function.identity(),
                                            new PagesSerdeFactory(new TestingBlockEncodingSerde(), false))),
                            OptionalInt.empty())),
                    ImmutableList.of(TaskTestUtils.TABLE_SCAN_NODE_ID));

            SqlTaskExecution sqlTaskExecution = new SqlTaskExecution(
                    taskStateMachine,
                    newTestingTaskContext(taskNotificationExecutor, driverYieldExecutor, taskStateMachine),
                    outputBuffer,
                    localExecutionPlan,
                    taskExecutor,
                    TaskTestUtils.createTestSplitMonitor(),
                    taskNotificationExecutor);
            sqlTaskExecution.start();
            Assert.assertEquals(taskStateMachine.getState(), TaskState.RUNNING);

            // First split (123 positions); more splits follow.
            sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment(
                    TaskTestUtils.TABLE_SCAN_NODE_ID,
                    ImmutableSet.of(newScheduledSplit(0, TaskTestUtils.TABLE_SCAN_NODE_ID, 100000, 123)),
                    false)));
            outputBufferConsumer.consume(123, ASSERT_WAIT_TIMEOUT);

            // Hold the scan operators while the remaining splits (300 + 200 positions) are assigned.
            scanOperatorFactory.getPauser().pause();
            sqlTaskExecution.addSplitAssignments(ImmutableList.of(new SplitAssignment(
                    TaskTestUtils.TABLE_SCAN_NODE_ID,
                    ImmutableSet.of(
                            newScheduledSplit(1, TaskTestUtils.TABLE_SCAN_NODE_ID, 200000, 300),
                            newScheduledSplit(2, TaskTestUtils.TABLE_SCAN_NODE_ID, 300000, 200)),
                    true)));
            waitUntilEquals(scanOperatorFactory::isOverallNoMoreOperators, true, ASSERT_WAIT_TIMEOUT);
            scanOperatorFactory.getPauser().resume();

            outputBufferConsumer.consume(500, ASSERT_WAIT_TIMEOUT);
            outputBufferConsumer.assertBufferComplete(ASSERT_WAIT_TIMEOUT);
            Assert.assertEquals(taskStateMachine.getStateChange(TaskState.RUNNING).get(10L, TimeUnit.SECONDS), TaskState.FLUSHING);

            // After the consumer's buffer is destroyed, the test expects the task to reach FINISHED.
            outputBufferConsumer.abort();
            Assert.assertEquals(taskStateMachine.getStateChange(TaskState.FLUSHING).get(10L, TimeUnit.SECONDS), TaskState.FINISHED);
        } finally {
            taskExecutor.stop();
            taskNotificationExecutor.shutdownNow();
            driverYieldExecutor.shutdown();
        }
    }

    /** Builds a TaskContext backed by a fresh QueryContext with small memory limits for the test. */
    private TaskContext newTestingTaskContext(ScheduledExecutorService taskNotificationExecutor, ScheduledExecutorService driverYieldExecutor, TaskStateMachine taskStateMachine) {
        QueryContext queryContext = new QueryContext(
                new QueryId("queryid"),
                DataSize.of(1L, DataSize.Unit.MEGABYTE),
                new MemoryPool(DataSize.of(1L, DataSize.Unit.GIGABYTE)),
                new TestingGcMonitor(),
                taskNotificationExecutor,
                driverYieldExecutor,
                DataSize.of(1L, DataSize.Unit.MEGABYTE),
                new SpillSpaceTracker(DataSize.of(1L, DataSize.Unit.GIGABYTE)));
        return queryContext.addTaskContext(taskStateMachine, SessionTestUtils.TEST_SESSION, () -> {
        }, false, false);
    }

    /** Creates a partitioned output buffer with exactly one buffer, {@link #OUTPUT_BUFFER_ID}, and no further buffer ids. */
    private PartitionedOutputBuffer newTestingOutputBuffer(ScheduledExecutorService executor) {
        return new PartitionedOutputBuffer(
                TASK_ID.toString(),
                new OutputBufferStateMachine(TASK_ID, executor),
                PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED)
                        .withBuffer(OUTPUT_BUFFER_ID, 0)
                        .withNoMoreBufferIds(),
                DataSize.of(1L, DataSize.Unit.MEGABYTE),
                () -> new SimpleLocalMemoryContext(AggregatedMemoryContext.newSimpleAggregatedMemoryContext(), "test"),
                executor);
    }

    /**
     * Polls {@code actual} every 10ms until it equals {@code expected} or {@code timeout} elapses.
     * After the timeout it asserts once more so the failure message shows the last observed value.
     *
     * @throws RuntimeException if the waiting thread is interrupted (the interrupt flag is restored)
     */
    private <T> void waitUntilEquals(Supplier<T> actual, T expected, Duration timeout) {
        long deadline = deadlineNanos(timeout);
        // Compare via subtraction so nanoTime wrap-around is handled correctly.
        while (System.nanoTime() - deadline < 0) {
            if (expected.equals(actual.get())) {
                return;
            }
            try {
                Thread.sleep(10L);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }
        Assert.assertEquals(actual.get(), expected);
    }

    /** Wraps a {@link TestingSplit} covering {@code [begin, begin + count)} in a ScheduledSplit. */
    private ScheduledSplit newScheduledSplit(int sequenceId, PlanNodeId planNodeId, int begin, int count) {
        return new ScheduledSplit(sequenceId, planNodeId, new Split(TestingHandles.TEST_CATALOG_HANDLE, new TestingSplit(begin, begin + count)));
    }
}
