package io.trino.client;

import com.google.common.collect.ImmutableList;
import com.google.common.net.MediaType;
import io.airlift.json.JsonCodec;
import java.io.IOException;
import java.net.URI;
import java.time.Duration;
import java.time.ZoneId;
import java.util.List;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import okhttp3.OkHttpClient;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.SocketPolicy;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

/**
 * Tests that the statement client retries requests whose HTTP exchange fails, either because the
 * server stalls before responding or because the connection drops mid-body.
 */
@TestInstance(TestInstance.Lifecycle.PER_METHOD)
public class TestRetry {
    private static final JsonCodec<QueryResults> QUERY_RESULTS_CODEC = JsonCodec.jsonCodec(QueryResults.class);
    private static final String QUERY_ID = "20160128_214710_00012_rk68b";
    // Deliberately short so that a stalled or broken exchange times out quickly in the test.
    private static final Duration HTTP_TIMEOUT = Duration.ofMillis(100L);

    private MockWebServer server;

    @BeforeEach
    public void setup() throws IOException {
        server = new MockWebServer();
        server.start();
    }

    @AfterEach
    public void teardown() throws IOException {
        server.close();
        server = null;
    }

    @Test
    public void testRetryOnInitial() {
        // The first response stalls before sending anything; the second one finishes the query.
        server.enqueue(statusAndBody(200, newQueryResults("RUNNING")).setSocketPolicy(SocketPolicy.STALL_SOCKET_AT_START));
        server.enqueue(statusAndBody(200, newQueryResults("FINISHED")));

        runQueryToCompletion();

        // One failed attempt plus one successful retry.
        Assertions.assertThat(server.getRequestCount()).isEqualTo(2);
    }

    @Test
    public void testRetryOnBrokenStream() {
        // Initial request succeeds; the follow-up page disconnects mid-body and must be retried.
        server.enqueue(statusAndBody(200, newQueryResults("RUNNING")));
        server.enqueue(statusAndBody(200, newQueryResults("FINISHED")).setSocketPolicy(SocketPolicy.DISCONNECT_DURING_RESPONSE_BODY));
        server.enqueue(statusAndBody(200, newQueryResults("FINISHED")));

        runQueryToCompletion();

        Assertions.assertThat(server.getRequestCount()).isEqualTo(3);
    }

    /**
     * Runs {@code SELECT 1} against the mock server, drains every result page and asserts the
     * client ended in the finished state. Responses must be enqueued before calling this, because
     * creating the client already issues the first request. The client is closed on every path,
     * including when {@code advance()} or the assertion throws.
     */
    private void runQueryToCompletion() {
        OkHttpClient httpClient = newHttpClient();
        ClientSession session = newSession();
        try (StatementClient client = StatementClientFactory.newStatementClient(httpClient, session, "SELECT 1", Optional.empty())) {
            while (client.advance()) {
                // Intentionally empty: only reaching the end of the stream matters here.
            }
            Assertions.assertThat(client.isFinished()).isTrue();
        }
    }

    /** Builds an OkHttp client whose connect/read/write/call timeouts are all {@link #HTTP_TIMEOUT}. */
    private static OkHttpClient newHttpClient() {
        return new OkHttpClient.Builder()
                .connectTimeout(HTTP_TIMEOUT)
                .readTimeout(HTTP_TIMEOUT)
                .writeTimeout(HTTP_TIMEOUT)
                .callTimeout(HTTP_TIMEOUT)
                .build();
    }

    /** Builds a UTC client session that points at the mock server, with a 2s client request timeout. */
    private ClientSession newSession() {
        return ClientSession.builder()
                .server(URI.create("http://" + server.getHostName() + ":" + server.getPort()))
                .timeZone(ZoneId.of("UTC"))
                .clientRequestTimeout(io.airlift.units.Duration.valueOf("2s"))
                .build();
    }

    /**
     * Serializes a canned {@link QueryResults} page in the given state.
     *
     * @param state query state name; only {@code "RUNNING"} produces a next URI, so any other value
     *              marks the final page
     * @return the JSON body for the mock response
     */
    private String newQueryResults(String state) {
        URI nextUri = state.equals("RUNNING")
                ? server.url(String.format("/v1/statement/%s/%s", QUERY_ID, "aa")).uri()
                : null;
        List<Column> columns = ImmutableList.of(
                new Column("id", "integer", new ClientTypeSignature("integer")),
                new Column("name", "varchar", new ClientTypeSignature("varchar")));
        List<List<Object>> data = IntStream.range(0, 10)
                .mapToObj(i -> Stream.of((Object) i, "a").collect(Collectors.toList()))
                .collect(Collectors.toList());
        StatementStats stats = new StatementStats(state, state.equals("QUEUED"), true, OptionalDouble.of(0.0d), OptionalDouble.of(0.0d), 0, 0, 0, 0, 0, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, (StageStats) null);

        QueryResults results = new QueryResults(
                QUERY_ID,
                server.url("/query.html?" + QUERY_ID).uri(),
                (URI) null,
                nextUri,
                columns,
                data,
                stats,
                (QueryError) null,
                ImmutableList.of(),
                (String) null,
                (Long) null);
        return QUERY_RESULTS_CODEC.toJson(results);
    }

    /** Creates a mock response with the given status code and a JSON (UTF-8) body. */
    private static MockResponse statusAndBody(int status, String body) {
        return new MockResponse().setResponseCode(status).addHeader("Content-Type", MediaType.JSON_UTF_8).setBody(body);
    }
}
