package io.trino.memory;

import com.google.common.collect.ImmutableMap;
import io.airlift.concurrent.Threads;
import io.airlift.stats.TestingGcMonitor;
import io.airlift.units.DataSize;
import io.trino.SessionTestUtils;
import io.trino.execution.TaskId;
import io.trino.execution.TaskStateMachine;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.spi.QueryId;
import io.trino.spiller.SpillSpaceTracker;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.LocalQueryRunner;
import java.util.Map;
import java.util.OptionalInt;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/**
 * Tests for {@link QueryContext}, focused on moving a query's memory reservations from one
 * {@link MemoryPool} to another via {@link QueryContext#setMemoryPool(MemoryPool)}.
 */
public class TestQueryContext {
    /** Single-threaded scheduler shared by the tests in this class; shut down in {@link #tearDown()}. */
    private static final ScheduledExecutorService TEST_EXECUTOR =
            Executors.newScheduledThreadPool(1, Threads.threadsNamed("test-executor-%s"));

    @AfterClass(alwaysRun = true)
    public void tearDown() {
        TEST_EXECUTOR.shutdownNow();
    }

    /**
     * Supplies the {@code useReservedPool} argument for {@link #testSetMemoryPool(boolean)}.
     *
     * @return one row per invocation; each inner array is the argument list for that invocation
     */
    @DataProvider
    public Object[][] testSetMemoryPoolOptions() {
        // Must be a genuine two-dimensional array: the declared return type is Object[][],
        // and a flat Object[] holding nested arrays is not assignable to it.
        return new Object[][] {
                {false},
                {true},
        };
    }

    /**
     * Allocates user and revocable memory for a query in a general pool, then moves the query to a
     * reserved pool and releases the memory.
     *
     * @param useReservedPool when {@code true}, a different query ("second") first reserves all but
     *     one byte of the reserved pool's capacity, so the move happens while that pool is nearly full
     */
    @Test(dataProvider = "testSetMemoryPoolOptions")
    public void testSetMemoryPool(boolean useReservedPool) {
        QueryId secondQuery = new QueryId("second");
        MemoryPool reservedPool = new MemoryPool(LocalMemoryManager.RESERVED_POOL, DataSize.ofBytes(10L));
        long secondQueryMemory = reservedPool.getMaxBytes() - 1;
        if (useReservedPool) {
            Assert.assertTrue(reservedPool.reserve(secondQuery, "test", secondQueryMemory).isDone());
        }

        // try-with-resources guarantees the runner is closed exactly once, even if an assertion fails.
        try (LocalQueryRunner queryRunner = LocalQueryRunner.create(SessionTestUtils.TEST_SESSION)) {
            QueryContext queryContext = new QueryContext(
                    new QueryId("query"),
                    DataSize.ofBytes(10L),
                    DataSize.ofBytes(20L),
                    new MemoryPool(LocalMemoryManager.GENERAL_POOL, DataSize.ofBytes(10L)),
                    new TestingGcMonitor(),
                    queryRunner.getExecutor(),
                    queryRunner.getScheduler(),
                    DataSize.ofBytes(0L),
                    new SpillSpaceTracker(DataSize.ofBytes(0L)));

            // Use memory: 3 bytes of user memory and 5 bytes of revocable memory in the general pool.
            queryContext.getQueryMemoryContext().initializeLocalMemoryContexts("test");
            LocalMemoryContext userMemoryContext = queryContext.getQueryMemoryContext().localUserMemoryContext();
            LocalMemoryContext revocableMemoryContext = queryContext.getQueryMemoryContext().localRevocableMemoryContext();
            Assert.assertTrue(userMemoryContext.setBytes(3L).isDone());
            Assert.assertTrue(revocableMemoryContext.setBytes(5L).isDone());

            queryContext.setMemoryPool(reservedPool);

            if (useReservedPool) {
                reservedPool.free(secondQuery, "test", secondQueryMemory);
            }

            // Free memory
            userMemoryContext.close();
            revocableMemoryContext.close();
        }
    }

    /**
     * Verifies that tagged allocations follow the query when it is moved to another pool, and that
     * freeing the memory afterwards returns it to the new pool.
     */
    @Test
    public void testMoveTaggedAllocations() {
        MemoryPool generalPool = new MemoryPool(LocalMemoryManager.GENERAL_POOL, DataSize.ofBytes(10000L));
        MemoryPool reservedPool = new MemoryPool(LocalMemoryManager.RESERVED_POOL, DataSize.ofBytes(10000L));
        QueryId queryId = new QueryId("query");
        QueryContext queryContext = createQueryContext(queryId, generalPool);
        LocalMemoryContext memoryContext = queryContext
                .addTaskContext(
                        new TaskStateMachine(TaskId.valueOf("task-id"), TEST_EXECUTOR),
                        SessionTestUtils.TEST_SESSION,
                        () -> {},
                        false,
                        false,
                        OptionalInt.empty())
                .addPipelineContext(0, false, false, false)
                .addDriverContext()
                .addOperatorContext(0, new PlanNodeId("test"), "test")
                .aggregateUserMemoryContext()
                .newLocalMemoryContext("test_context");
        Map<String, Long> expectedAllocations = ImmutableMap.of("test_context", 1000L);

        memoryContext.setBytes(1000L);
        Assert.assertEquals(generalPool.getTaggedMemoryAllocations().get(queryId), expectedAllocations);

        // After the move, the tagged allocation is reported only by the new pool.
        queryContext.setMemoryPool(reservedPool);
        Assert.assertNull(generalPool.getTaggedMemoryAllocations().get(queryId));
        Assert.assertEquals(reservedPool.getTaggedMemoryAllocations().get(queryId), expectedAllocations);
        Assert.assertEquals(generalPool.getFreeBytes(), 10000L);
        Assert.assertEquals(reservedPool.getFreeBytes(), 9000L);

        // Closing the context releases the memory back to the pool the query now lives in.
        memoryContext.close();
        Assert.assertEquals(generalPool.getFreeBytes(), 10000L);
        Assert.assertEquals(reservedPool.getFreeBytes(), 10000L);
    }

    /**
     * Creates a {@link QueryContext} with 10000-byte memory limits, no spill space, and
     * {@link #TEST_EXECUTOR} as both its executor and scheduler.
     *
     * @param queryId id of the query the context belongs to
     * @param memoryPool pool the context initially allocates from
     * @return the new query context
     */
    private static QueryContext createQueryContext(QueryId queryId, MemoryPool memoryPool) {
        return new QueryContext(
                queryId,
                DataSize.ofBytes(10000L),
                DataSize.ofBytes(10000L),
                memoryPool,
                new TestingGcMonitor(),
                TEST_EXECUTOR,
                TEST_EXECUTOR,
                DataSize.ofBytes(0L),
                new SpillSpaceTracker(DataSize.ofBytes(0L)));
    }
}
