package io.trino.jdbc;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.Resources;
import com.google.inject.Binder;
import com.google.inject.Inject;
import com.google.inject.Key;
import com.google.inject.Scopes;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.airlift.jaxrs.JaxrsBinder;
import io.airlift.log.Logging;
import io.airlift.testing.Closeables;
import io.trino.client.ClientException;
import io.trino.client.auth.external.DesktopBrowserRedirectHandler;
import io.trino.client.auth.external.RedirectException;
import io.trino.client.auth.external.RedirectHandler;
import io.trino.plugin.tpch.TpchPlugin;
import io.trino.server.security.AuthenticationException;
import io.trino.server.security.Authenticator;
import io.trino.server.security.ResourceSecurity;
import io.trino.server.security.ServerSecurityModule;
import io.trino.server.testing.TestingTrinoServer;
import io.trino.spi.security.Identity;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.sql.Statement;
import java.time.Duration;
import java.util.ConcurrentModificationException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Properties;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.IntSupplier;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;

/**
 * End-to-end tests for JDBC external (redirect-based) authentication.
 *
 * <p>A {@link TestingTrinoServer} is started with a dummy external authenticator. It challenges
 * unauthenticated requests with a {@code WWW-Authenticate} header. That header carries a redirect
 * (login) URL and a token-polling URL, both served by {@link DummyExternalAuthResources}.
 *
 * <p>Each test installs a {@link RedirectHandler} into the driver and checks how the driver
 * reacts. The handler either performs an HTTP GET on the login URL, does nothing, or fails.
 *
 * <p>The tests mutate static state: the driver's redirect handler and the fixtures below. They
 * must therefore run on a single thread.
 */
@Execution(ExecutionMode.SAME_THREAD)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class TestJdbcExternalAuthentication {
    private static final String TEST_CATALOG = "test_catalog";
    private TestingTrinoServer server;

    /**
     * In-memory registry of login sessions and issued tokens, bound as a Guice singleton.
     *
     * <p>A session starts with an empty token. {@link #logIn} replaces that with a real token and
     * marks it as valid. Backed by concurrent collections because it is used both while serving
     * HTTP requests and directly from the test methods.
     *
     * <p>Note: the class must not declare a private no-arg constructor. Guice refuses to use a
     * private no-arg constructor on a non-private class.
     */
    public static class Authentications {
        private final Map<String, String> loginSessions = new ConcurrentHashMap<>();
        private final Set<String> validTokens = ConcurrentHashMap.newKeySet();

        /** Opens a new login session with no token yet and returns its id. */
        public String startAuthentication() {
            String sessionId = UUID.randomUUID().toString();
            loginSessions.put(sessionId, "");
            return sessionId;
        }

        /** Completes the login for {@code sessionId} by issuing and registering a token for it. */
        public void logIn(String sessionId) {
            String token = sessionId + "_token";
            validTokens.add(token);
            loginSessions.put(sessionId, token);
        }

        /**
         * Returns the token for a session.
         *
         * @return the token, or empty if the session is unknown or its login is not finished yet
         */
        public Optional<String> getToken(String sessionId) throws IllegalArgumentException {
            return Optional.ofNullable(loginSessions.get(sessionId))
                    .filter(token -> !token.isEmpty());
        }

        /** Returns whether {@code token} was issued and has not been invalidated since. */
        public boolean verifyToken(String token) {
            return validTokens.contains(token);
        }

        /** Invalidates every issued token. Existing sessions are left in place. */
        public void invalidateAllTokens() {
            validTokens.clear();
        }

        /** Returns the number of currently valid tokens. */
        public int countValidTokens() {
            return validTokens.size();
        }
    }

    /**
     * Authenticator that accepts any {@code Bearer} token registered in {@link Authentications}.
     *
     * <p>Any other request is challenged with an external-authentication header. The header points
     * at {@link DummyExternalAuthResources}, unless {@link WwwAuthenticateHeaderFixture} overrides it.
     */
    public static class DummyAuthenticator implements Authenticator {
        private static final String BEARER_PREFIX = "Bearer ";

        private final IntSupplier port;
        private final Authentications authentications;

        @Inject
        public DummyAuthenticator(IntSupplier port, Authentications authentications) {
            this.port = Objects.requireNonNull(port, "port is null");
            this.authentications = Objects.requireNonNull(authentications, "authentications is null");
        }

        @Override
        public Identity authenticate(ContainerRequestContext request) throws AuthenticationException {
            List<String> authorizationHeaders = request.getHeaders().getOrDefault("Authorization", ImmutableList.of());
            boolean authorized = authorizationHeaders.stream()
                    .filter(header -> header.startsWith(BEARER_PREFIX))
                    .anyMatch(header -> authentications.verifyToken(header.substring(BEARER_PREFIX.length())));
            if (authorized) {
                return Identity.ofUser("user");
            }

            // A new session is opened on every challenge, even when the header is overridden below.
            String sessionId = authentications.startAuthentication();
            String overriddenHeader = WwwAuthenticateHeaderFixture.HEADER.get();
            if (overriddenHeader != null) {
                throw new AuthenticationException("Authentication required", overriddenHeader);
            }
            int serverPort = port.getAsInt();
            throw new AuthenticationException("Authentication required", String.format(
                    "Bearer x_redirect_server=\"http://localhost:%s/v1/authentications/dummy/logins/%s\", x_token_server=\"http://localhost:%s/v1/authentications/dummy/%s\"",
                    serverPort, sessionId, serverPort, sessionId));
        }
    }

    /**
     * Registers {@link DummyAuthenticator} under the authentication type {@code dummy-external}.
     * It also binds the shared {@link Authentications}, the port supplier, and the REST resources.
     */
    private static class DummyExternalAuthModule extends AbstractConfigurationAwareModule {
        private final IntSupplier port;

        public DummyExternalAuthModule(IntSupplier port) {
            this.port = Objects.requireNonNull(port, "port is null");
        }

        @Override
        protected void setup(Binder binder) {
            install(ServerSecurityModule.authenticatorModule("dummy-external", DummyAuthenticator.class, authBinder -> {
                authBinder.bind(Authentications.class).in(Scopes.SINGLETON);
                authBinder.bind(IntSupplier.class).toInstance(port);
                JaxrsBinder.jaxrsBinder(authBinder).bind(DummyExternalAuthResources.class);
            }));
        }
    }

    /**
     * Public REST endpoints referenced by the challenge header:
     * <ul>
     *   <li>{@code logins/{sessionId}} completes a login.</li>
     *   <li>{@code {sessionId}} is polled by the driver for the token.</li>
     * </ul>
     */
    @Path("/v1/authentications/dummy")
    public static class DummyExternalAuthResources {
        private final Authentications authentications;

        @Inject
        public DummyExternalAuthResources(Authentications authentications) {
            this.authentications = Objects.requireNonNull(authentications, "authentications is null");
        }

        /** Marks the session as logged in; this is the URL a redirect handler "visits". */
        @Produces(MediaType.TEXT_PLAIN)
        @ResourceSecurity(ResourceSecurity.AccessType.PUBLIC)
        @GET
        @Path("logins/{sessionId}")
        public String logInUser(@PathParam("sessionId") String sessionId) {
            authentications.logIn(sessionId);
            return "User has been successfully logged in during " + sessionId + " session";
        }

        /**
         * Token-polling endpoint. The JSON body is chosen as follows:
         * <ul>
         *   <li>An {@code error} if {@link TokenPollingErrorFixture} is active.</li>
         *   <li>Otherwise the {@code token} once the session is logged in.</li>
         *   <li>Otherwise a {@code nextUri} pointing back at this same URI, so the client keeps polling.</li>
         * </ul>
         */
        @ResourceSecurity(ResourceSecurity.AccessType.PUBLIC)
        @GET
        @Path("{sessionId}")
        public Response getToken(@PathParam("sessionId") String sessionId, @Context HttpServletRequest request) {
            try {
                String pollingError = TokenPollingErrorFixture.ERROR.get();
                if (pollingError != null) {
                    return jsonResponse(String.format("{ \"error\" : \"%s\"}", pollingError));
                }
                return authentications.getToken(sessionId)
                        .map(token -> jsonResponse(String.format("{ \"token\" : \"%s\"}", token)))
                        .orElseGet(() -> jsonResponse(String.format("{ \"nextUri\" : \"%s\" }", request.getRequestURI())));
            } catch (IllegalArgumentException e) {
                // Authentications.getToken declares IllegalArgumentException; report it as 404.
                return Response.status(Response.Status.NOT_FOUND).build();
            }
        }

        /** Builds a 200 response with {@code body} as its {@code application/json} entity. */
        private static Response jsonResponse(String body) {
            return Response.ok(body, MediaType.APPLICATION_JSON_TYPE).build();
        }
    }

    /** Redirect handler that always fails, so the driver surfaces a {@link RedirectException}. */
    public static class FailingRedirectHandler implements RedirectHandler {
        @Override
        public void redirectTo(URI uri) throws RedirectException {
            throw new RedirectException("Redirect to uri has failed " + uri);
        }
    }

    /**
     * Redirect handler that completes the login by issuing a plain HTTP GET to the redirect URI.
     * Any status other than 200 is reported as a failure.
     */
    public static class HttpGetOnlyRedirectHandler implements RedirectHandler {
        @Override
        public void redirectTo(URI uri) throws RedirectException {
            Request request = new Request.Builder().url(HttpUrl.get(uri.toString())).build();
            // try-with-resources closes the response on every path, including the non-200 one.
            try (okhttp3.Response response = new OkHttpClient().newCall(request).execute()) {
                if (response.code() != 200) {
                    throw new RedirectException("HTTP GET failed with status " + response.code());
                }
            } catch (IOException e) {
                throw new RedirectException("Redirection failed", e);
            }
        }
    }

    /** Redirect handler that does nothing, so the login never completes. */
    public static class NoOpRedirectHandler implements RedirectHandler {
        @Override
        public void redirectTo(URI uri) throws RedirectException {
        }
    }

    /**
     * Installs a redirect handler into {@link TrinoDriverUri} for the duration of a
     * try-with-resources block. On close it restores the {@link DesktopBrowserRedirectHandler}.
     */
    static class RedirectHandlerFixture implements AutoCloseable {
        private static final RedirectHandlerFixture INSTANCE = new RedirectHandlerFixture();

        private RedirectHandlerFixture() {
        }

        public static RedirectHandlerFixture withHandler(RedirectHandler redirectHandler) {
            TrinoDriverUri.setRedirectHandler(redirectHandler);
            return INSTANCE;
        }

        @Override
        public void close() {
            TrinoDriverUri.setRedirectHandler(new DesktopBrowserRedirectHandler());
        }
    }

    /**
     * Makes the token-polling endpoint return {@code error} until closed.
     * At most one such fixture may be active at a time.
     */
    static class TokenPollingErrorFixture implements AutoCloseable {
        private static final AtomicReference<String> ERROR = new AtomicReference<>(null);

        TokenPollingErrorFixture() {
        }

        /**
         * Activates the polling error.
         *
         * @throws ConcurrentModificationException if another polling error is already active
         */
        public static AutoCloseable withPollingError(String error) {
            if (!ERROR.compareAndSet(null, error)) {
                throw new ConcurrentModificationException("polling errors can't be invoked in parallel");
            }
            return new TokenPollingErrorFixture();
        }

        @Override
        public void close() {
            ERROR.set(null);
        }
    }

    /**
     * Replaces the authenticator's {@code WWW-Authenticate} challenge with a fixed value until
     * closed. At most one such fixture may be active at a time.
     */
    static class WwwAuthenticateHeaderFixture implements AutoCloseable {
        private static final AtomicReference<String> HEADER = new AtomicReference<>(null);

        WwwAuthenticateHeaderFixture() {
        }

        /**
         * Activates the header override.
         *
         * @throws ConcurrentModificationException if another override is already active
         */
        public static AutoCloseable withWwwAuthenticate(String header) {
            if (!HEADER.compareAndSet(null, header)) {
                throw new ConcurrentModificationException("with WWW-Authenticate header can't be invoked in parallel");
            }
            return new WwwAuthenticateHeaderFixture();
        }

        @Override
        public void close() {
            HEADER.set(null);
        }
    }

    /** Starts an HTTPS-enabled test server that uses the dummy external authenticator. */
    @BeforeAll
    public void setup() throws Exception {
        Logging.initialize();
        // The port supplier is evaluated lazily, after `server` has been assigned.
        server = TestingTrinoServer.builder()
                .setAdditionalModule(new DummyExternalAuthModule(() -> server.getAddress().getPort()))
                .setProperties(ImmutableMap.<String, String>builder()
                        .put("http-server.authentication.type", "dummy-external")
                        .put("http-server.https.enabled", "true")
                        .put("http-server.https.keystore.path", new File(Resources.getResource("localhost.keystore").toURI()).getPath())
                        .put("http-server.https.keystore.key", "changeit")
                        .put("web-ui.enabled", "false")
                        .buildOrThrow())
                .build();
        server.installPlugin(new TpchPlugin());
        server.createCatalog(TEST_CATALOG, "tpch");
        server.waitForNodeRefresh(Duration.ofSeconds(10));
    }

    @AfterAll
    public void teardown() throws Exception {
        Closeables.closeAll(server);
        server = null;
    }

    @Test
    public void testSuccessfulAuthenticationWithHttpGetOnlyRedirectHandler() throws Exception {
        invalidateAllTokens();
        try (RedirectHandlerFixture ignoredHandler = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler());
                Connection connection = createConnection();
                Statement statement = connection.createStatement()) {
            Assertions.assertThat(statement.execute("SELECT 123")).isTrue();
        }
    }

    /** Requires a real desktop browser and a manual login, hence disabled. */
    @Disabled
    @Test
    public void testSuccessfulAuthenticationWithDefaultBrowserRedirect() throws Exception {
        invalidateAllTokens();
        try (Connection connection = createConnection();
                Statement statement = connection.createStatement()) {
            Assertions.assertThat(statement.execute("SELECT 123")).isTrue();
        }
    }

    @Test
    public void testAuthenticationFailsAfterUnfinishedRedirect() throws Exception {
        invalidateAllTokens();
        try (RedirectHandlerFixture ignoredHandler = RedirectHandlerFixture.withHandler(new NoOpRedirectHandler());
                Connection connection = createConnection();
                Statement statement = connection.createStatement()) {
            Assertions.assertThatThrownBy(() -> statement.execute("SELECT 123"))
                    .isInstanceOf(SQLException.class);
        }
    }

    @Test
    public void testAuthenticationFailsAfterRedirectException() throws Exception {
        invalidateAllTokens();
        try (RedirectHandlerFixture ignoredHandler = RedirectHandlerFixture.withHandler(new FailingRedirectHandler());
                Connection connection = createConnection();
                Statement statement = connection.createStatement()) {
            Assertions.assertThatThrownBy(() -> statement.execute("SELECT 123"))
                    .isInstanceOf(SQLException.class)
                    .hasCauseExactlyInstanceOf(RedirectException.class);
        }
    }

    @Test
    public void testAuthenticationFailsAfterServerAuthenticationFailure() throws Exception {
        invalidateAllTokens();
        try (RedirectHandlerFixture ignoredHandler = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler());
                AutoCloseable ignoredPollingError = TokenPollingErrorFixture.withPollingError("error occurred during token polling");
                Connection connection = createConnection();
                Statement statement = connection.createStatement()) {
            Assertions.assertThatThrownBy(() -> statement.execute("SELECT 123"))
                    .isInstanceOf(SQLException.class)
                    .hasMessage("error occurred during token polling");
        }
    }

    @Test
    public void testAuthenticationFailsAfterReceivingMalformedHeaderFromServer() throws Exception {
        invalidateAllTokens();
        try (RedirectHandlerFixture ignoredHandler = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler());
                AutoCloseable ignoredHeader = WwwAuthenticateHeaderFixture.withWwwAuthenticate("Bearer no-valid-fields");
                Connection connection = createConnection();
                Statement statement = connection.createStatement()) {
            Assertions.assertThatThrownBy(() -> statement.execute("SELECT 123"))
                    .isInstanceOf(SQLException.class)
                    .hasCauseInstanceOf(ClientException.class)
                    .hasMessage("Authentication failed: Authentication required");
        }
    }

    @Test
    public void testAuthenticationReusesObtainedTokenPerConnection() throws Exception {
        invalidateAllTokens();
        try (RedirectHandlerFixture ignoredHandler = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler());
                Connection connection = createConnection();
                Statement statement = connection.createStatement()) {
            statement.execute("SELECT 123");
            statement.execute("SELECT 123");
            statement.execute("SELECT 123");
            Assertions.assertThat(countIssuedTokens()).isEqualTo(1);
        }
    }

    @Test
    public void testAuthenticationAfterInitialTokenHasBeenInvalidated() throws Exception {
        invalidateAllTokens();
        try (RedirectHandlerFixture ignoredHandler = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler());
                Connection connection = createConnection();
                Statement statement = connection.createStatement()) {
            statement.execute("SELECT 123");
            invalidateAllTokens();
            Assertions.assertThat(countIssuedTokens()).isEqualTo(0);
            Assertions.assertThat(statement.execute("SELECT 123")).isTrue();
        }
    }

    /**
     * Opens an SSL JDBC connection to the test server.
     * External authentication is enabled with a short (2s) timeout.
     */
    private Connection createConnection() throws Exception {
        String url = String.format("jdbc:trino://localhost:%s", server.getHttpsAddress().getPort());
        Properties properties = new Properties();
        properties.setProperty("SSL", "true");
        properties.setProperty("SSLTrustStorePath", new File(Resources.getResource("localhost.truststore").toURI()).getPath());
        properties.setProperty("SSLTrustStorePassword", "changeit");
        properties.setProperty("externalAuthentication", "true");
        properties.setProperty("externalAuthenticationTimeout", "2s");
        return DriverManager.getConnection(url, properties);
    }

    private void invalidateAllTokens() {
        authentications().invalidateAllTokens();
    }

    private int countIssuedTokens() {
        return authentications().countValidTokens();
    }

    /** Returns the server's singleton {@link Authentications} instance. */
    private Authentications authentications() {
        return server.getInstance(Key.get(Authentications.class));
    }
}
