package io.trino.tests;

import com.google.common.collect.MoreCollectors;
import io.airlift.concurrent.Threads;
import io.trino.execution.QueryManager;
import io.trino.execution.QueryState;
import io.trino.server.BasicQueryInfo;
import io.trino.server.testing.TestingTrinoServer;
import io.trino.testing.DistributedQueryRunner;
import io.trino.testing.assertions.Assert;
import io.trino.tests.tpch.TpchQueryRunner;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;

/**
 * Tests that restart (or start) a worker of a TPCH-backed {@link DistributedQueryRunner} before or
 * while a query is running, and check how that query and subsequent queries behave.
 *
 * <p>Each test owns its query runner and executor through try-with-resources. Any queries still known
 * to the coordinator are cancelled before those resources are closed, on both success and failure paths.
 */
@Execution(ExecutionMode.SAME_THREAD)
public class TestWorkerRestart {
    /** Number of times each test is repeated by {@link RepeatedTest}. */
    private static final int TEST_ITERATIONS = 1;

    /** Expected result of {@code SELECT count(*) FROM tpch.tiny.lineitem}. */
    private static final long TINY_LINEITEM_COUNT = 60175L;

    /**
     * Restarts a worker before submitting a query, then requires that query to complete and the
     * cluster to keep answering a small query correctly.
     */
    @Timeout(90)
    @RepeatedTest(TEST_ITERATIONS)
    public void testRestartBeforeQuery() throws Exception {
        try (DistributedQueryRunner queryRunner = TpchQueryRunner.builder().build();
                ExecutorService executor = newExecutor()) {
            try {
                // Make sure the cluster works before disturbing it
                assertTinyLineitemCount(queryRunner);
                restartWorker(queryRunner);
                // get() rethrows any failure of the query submitted after the restart
                executor.submit(() -> queryRunner.execute("SELECT count(*) FROM tpch.sf1.lineitem -- " + UUID.randomUUID())).get();
                assertTinyLineitemCount(queryRunner);
            } finally {
                cancelQueries(queryRunner);
            }
        }
    }

    /**
     * Restarts a worker while a large query is RUNNING. It then requires that query to fail with an
     * HTTP 500 response error, and the cluster to keep answering a small query correctly.
     */
    @Timeout(90)
    @RepeatedTest(TEST_ITERATIONS)
    public void testRestartDuringQuery() throws Exception {
        try (DistributedQueryRunner queryRunner = TpchQueryRunner.builder().build();
                ExecutorService executor = newExecutor()) {
            try {
                assertTinyLineitemCount(queryRunner);
                // The random suffix makes the query text unique, so waitForQueryStart matches exactly this query
                String sql = "SELECT count(*) FROM tpch.sf1000000000.lineitem -- " + UUID.randomUUID();
                Future<?> future = executor.submit(() -> queryRunner.execute(sql));
                waitForQueryStart(queryRunner, sql);
                restartWorker(queryRunner);
                Assertions.assertThatThrownBy(future::get)
                        .isInstanceOf(ExecutionException.class)
                        .cause()
                        .hasMessageFindingMatch("^Expected response code from \\S+ to be 200, but was 500|Error fetching \\S+: Expected response code to be 200, but was 500");
                assertTinyLineitemCount(queryRunner);
            } finally {
                cancelQueries(queryRunner);
            }
        }
    }

    /**
     * Closes a worker, submits a query, then starts the worker again via
     * {@link DistributedQueryRunner#restartWorker}. The submitted query must complete and the cluster
     * must keep answering a small query correctly.
     */
    @Timeout(90)
    @RepeatedTest(TEST_ITERATIONS)
    public void testStartDuringQuery() throws Exception {
        try (DistributedQueryRunner queryRunner = TpchQueryRunner.builder().build();
                ExecutorService executor = newExecutor()) {
            try {
                assertTinyLineitemCount(queryRunner);
                TestingTrinoServer worker = findWorker(queryRunner);
                worker.close();
                Future<?> future = executor.submit(() -> queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem -- " + UUID.randomUUID()));
                queryRunner.restartWorker(worker);
                // get() rethrows any failure of the query submitted while the worker was down
                future.get();
                assertTinyLineitemCount(queryRunner);
            } finally {
                cancelQueries(queryRunner);
            }
        }
    }

    /** Creates a cached thread pool whose threads are named after this test class. */
    private ExecutorService newExecutor() {
        return Executors.newCachedThreadPool(Threads.daemonThreadsNamed(getClass().getSimpleName() + "-%d"));
    }

    /** Asserts that {@code tpch.tiny.lineitem} has the expected row count. */
    private static void assertTinyLineitemCount(DistributedQueryRunner queryRunner) {
        Assertions.assertThat((Long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue())
                .isEqualTo(TINY_LINEITEM_COUNT);
    }

    /**
     * Waits, using {@link Assert#assertEventually}, until exactly one query with text {@code sql}
     * is known to the coordinator and is in state RUNNING.
     */
    private static void waitForQueryStart(DistributedQueryRunner queryRunner, String sql) {
        Assert.assertEventually(() -> {
            BasicQueryInfo query = queryRunner.getCoordinator().getQueryManager().getQueries().stream()
                    .filter(info -> info.getQuery().equals(sql))
                    .collect(MoreCollectors.onlyElement());
            Assertions.assertThat(query.getState()).isEqualTo(QueryState.RUNNING);
        });
    }

    /** Returns the first non-coordinator server; throws NoSuchElementException if there is none. */
    private static TestingTrinoServer findWorker(DistributedQueryRunner queryRunner) {
        return queryRunner.getServers().stream()
                .filter(server -> !server.isCoordinator())
                .findFirst()
                .orElseThrow();
    }

    /** Restarts the first non-coordinator server of the runner. */
    private static void restartWorker(DistributedQueryRunner queryRunner) throws Exception {
        queryRunner.restartWorker(findWorker(queryRunner));
    }

    /** Cancels every query currently known to the coordinator's query manager. */
    private static void cancelQueries(DistributedQueryRunner queryRunner) {
        QueryManager queryManager = queryRunner.getCoordinator().getQueryManager();
        queryManager.getQueries().stream()
                .map(BasicQueryInfo::getQueryId)
                .forEach(queryManager::cancelQuery);
    }
}
