package io.trino.server.security.oauth2;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.log.Level;
import io.airlift.log.Logging;
import io.airlift.testing.Closeables;
import io.jsonwebtoken.impl.DefaultClaims;
import io.trino.client.OkHttpUtil;
import io.trino.server.security.jwt.JwtUtil;
import io.trino.server.testing.TestingTrinoServer;
import io.trino.server.ui.OAuth2WebUiAuthenticationFilter;
import io.trino.server.ui.WebUiModule;
import java.io.IOException;
import java.net.CookieManager;
import java.net.CookieStore;
import java.net.HttpCookie;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import javax.ws.rs.core.Response;
import okhttp3.Cookie;
import okhttp3.CookieJar;
import okhttp3.HttpUrl;
import okhttp3.JavaNetCookieJar;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.assertj.core.api.Assertions;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/**
 * Shared tests for the OAuth2 Web UI authentication filter.
 *
 * <p>Subclasses supply the identity provider ({@link #getHydraIdp()}), the Trino server
 * properties that point at it ({@link #getOAuth2Config(String)}), and a check for the access
 * token stored in the UI cookie ({@link #validateAccessToken(String)}). {@link #setup()} starts a
 * coordinator with the Web UI module and registers three clients with the IdP: the Trino client
 * itself, a trusted client, and an untrusted client whose tokens carry a foreign audience.
 *
 * <p>Note: HTTP responses are referenced as {@code okhttp3.Response} and status codes as
 * {@code javax.ws.rs.core.Response.Status} with fully-qualified names, because both types share
 * the simple name {@code Response}.
 */
public abstract class BaseOAuth2WebUiAuthenticationFilterTest {
    protected static final Duration TTL_ACCESS_TOKEN_IN_SECONDS = Duration.ofSeconds(5);
    protected static final String TRINO_CLIENT_ID = "trino-client";
    protected static final String TRINO_CLIENT_SECRET = "trino-secret";
    private static final String TRINO_AUDIENCE = "trino-client";
    private static final String ADDITIONAL_AUDIENCE = "https://external-service.com";
    protected static final String TRUSTED_CLIENT_ID = "trusted-client";
    protected static final String TRUSTED_CLIENT_SECRET = "trusted-secret";
    private static final String UNTRUSTED_CLIENT_ID = "untrusted-client";
    private static final String UNTRUSTED_CLIENT_SECRET = "untrusted-secret";
    private static final String UNTRUSTED_CLIENT_AUDIENCE = "https://untrusted.com";
    // Cookie in which the Web UI filter keeps the OAuth2 access token after a successful login
    private static final String OAUTH2_COOKIE = "__Secure-Trino-OAuth2-Token";
    // Path on the Trino server that the IdP redirects back to
    private static final String CALLBACK_PATH = "/oauth2/callback";

    private final Logging logging = Logging.initialize();
    // Client that does NOT follow redirects, so tests can inspect the 303 sent towards the IdP
    protected final OkHttpClient httpClient;
    protected TestingHydraIdentityProvider hydraIdP;
    private TestingTrinoServer server;
    private URI serverUri;
    private URI uiUri;

    protected BaseOAuth2WebUiAuthenticationFilterTest() {
        this.httpClient = insecureClientBuilder()
                .followRedirects(false)
                .build();
    }

    /**
     * Starts the IdP and a Trino coordinator configured for OAuth2, then registers the Trino,
     * trusted, and untrusted clients with the IdP.
     */
    @BeforeClass
    public void setup() throws Exception {
        logging.setLevel(OAuth2WebUiAuthenticationFilter.class.getName(), Level.DEBUG);
        logging.setLevel(OAuth2Service.class.getName(), Level.DEBUG);
        hydraIdP = getHydraIdp();
        server = TestingTrinoServer.builder()
                .setCoordinator(true)
                .setAdditionalModule(new WebUiModule())
                .setProperties(getOAuth2Config("https://localhost:" + hydraIdP.getAuthPort()))
                .build();
        server.waitForNodeRefresh(Duration.ofSeconds(10));
        serverUri = server.getHttpsBaseUrl();
        uiUri = serverUri.resolve("/ui/");

        String trinoCallback = serverUri + CALLBACK_PATH;
        hydraIdP.createClient(
                TRINO_CLIENT_ID,
                TRINO_CLIENT_SECRET,
                TokenEndpointAuthMethod.CLIENT_SECRET_BASIC,
                ImmutableList.of(TRINO_AUDIENCE, ADDITIONAL_AUDIENCE),
                trinoCallback);
        hydraIdP.createClient(
                TRUSTED_CLIENT_ID,
                TRUSTED_CLIENT_SECRET,
                TokenEndpointAuthMethod.CLIENT_SECRET_BASIC,
                ImmutableList.of(TRUSTED_CLIENT_ID),
                trinoCallback);
        hydraIdP.createClient(
                UNTRUSTED_CLIENT_ID,
                UNTRUSTED_CLIENT_SECRET,
                TokenEndpointAuthMethod.CLIENT_SECRET_BASIC,
                ImmutableList.of(UNTRUSTED_CLIENT_AUDIENCE),
                "https://untrusted.com/callback");
    }

    /**
     * Returns the Trino server properties that enable OAuth2 against the test IdP.
     *
     * @param idpUrl base URL of the IdP, of the form {@code https://localhost:<auth port>}
     */
    protected abstract ImmutableMap<String, String> getOAuth2Config(String idpUrl);

    /**
     * Creates (and presumably starts) the identity provider used by this test class.
     */
    protected abstract TestingHydraIdentityProvider getHydraIdp() throws Exception;

    @AfterClass(alwaysRun = true)
    public void tearDown() throws Exception {
        logging.clearLevel(OAuth2WebUiAuthenticationFilter.class.getName());
        logging.clearLevel(OAuth2Service.class.getName());
        Closeables.closeAll(server, hydraIdP);
    }

    @Test
    public void testUnauthorizedApiCall() throws IOException {
        try (okhttp3.Response response = httpClient.newCall(apiCall().build()).execute()) {
            assertUnauthorizedResponse(response);
        }
    }

    @Test
    public void testUnauthorizedUICall() throws IOException {
        try (okhttp3.Response response = httpClient.newCall(uiCall().build()).execute()) {
            assertRedirectResponse(response);
        }
    }

    @Test
    public void testUnsignedToken() throws NoSuchAlgorithmException, IOException {
        // Sign the token with a freshly generated key instead of the IdP's key, so the filter is
        // expected to reject it and redirect to the IdP. 2048 bits is the RS256 minimum and is
        // much cheaper to generate than 4096.
        KeyPairGenerator keyGenerator = KeyPairGenerator.getInstance("RSA");
        keyGenerator.initialize(2048);
        long now = Instant.now().getEpochSecond();
        String token = JwtUtil.newJwtBuilder()
                .setHeaderParam("alg", "RS256")
                .setHeaderParam("kid", "public:f467aa08-1c1b-4cde-ba45-84b0ef5d2ba8")
                .setHeaderParam("typ", "JWT")
                .setClaims(new DefaultClaims(ImmutableMap.<String, Object>builder()
                        .put("aud", ImmutableList.of())
                        .put("client_id", TRINO_CLIENT_ID)
                        .put("exp", now + 60)
                        .put("iat", now)
                        .put("iss", "https://hydra:4444/")
                        .put("jti", UUID.randomUUID())
                        .put("nbf", now)
                        .put("scp", ImmutableList.of("openid"))
                        .put("sub", "foo@bar.com")
                        .buildOrThrow()))
                .signWith(keyGenerator.generateKeyPair().getPrivate())
                .compact();
        try (okhttp3.Response response = httpClientWithOAuth2Cookie(token, false).newCall(uiCall().build()).execute()) {
            assertRedirectResponse(response);
        }
    }

    @Test
    public void testTokenWithInvalidAudience() throws IOException {
        String token = hydraIdP.getToken(UNTRUSTED_CLIENT_ID, UNTRUSTED_CLIENT_SECRET, ImmutableList.of(UNTRUSTED_CLIENT_AUDIENCE));
        try (okhttp3.Response response = httpClientWithOAuth2Cookie(token, true).newCall(uiCall().build()).execute()) {
            assertUnauthorizedResponse(response);
        }
    }

    @Test
    public void testTokenFromTrustedClient() throws IOException {
        assertUICallWithCookie(hydraIdP.getToken(TRUSTED_CLIENT_ID, TRUSTED_CLIENT_SECRET, ImmutableList.of(TRUSTED_CLIENT_ID)));
    }

    @Test
    public void testTokenWithMultipleAudiences() throws IOException {
        assertUICallWithCookie(hydraIdP.getToken(TRINO_CLIENT_ID, TRINO_CLIENT_SECRET, ImmutableList.of(TRINO_AUDIENCE, ADDITIONAL_AUDIENCE)));
    }

    @Test
    public void testSuccessfulFlow() throws Exception {
        CookieManager cookieManager = new CookieManager();
        CookieStore cookieStore = cookieManager.getCookieStore();
        // Follow redirects and keep cookies so the client can complete the login and end up back on the UI page
        OkHttpClient client = insecureClientBuilder()
                .followRedirects(true)
                .cookieJar(new JavaNetCookieJar(cookieManager))
                .build();
        Assertions.assertThat(cookieStore.get(uiUri)).isEmpty();

        try (okhttp3.Response response = client.newCall(uiCall().build()).execute()) {
            Assertions.assertThat(response.code()).isEqualTo(200);
            Assertions.assertThat(response.request().url().toString()).isEqualTo(uiUri.toString());
        }

        Optional<HttpCookie> oauth2Cookie = cookieStore.get(uiUri).stream()
                .filter(cookie -> cookie.getName().equals(OAUTH2_COOKIE))
                .findFirst();
        Assertions.assertThat(oauth2Cookie).isNotEmpty();
        assertTrinoCookie(oauth2Cookie.get());
        assertUICallWithCookie(oauth2Cookie.get().getValue());
    }

    @Test
    public void testExpiredAccessToken() throws Exception {
        String token = hydraIdP.getToken(TRINO_CLIENT_ID, TRINO_CLIENT_SECRET, ImmutableList.of(TRINO_AUDIENCE));
        assertUICallWithCookie(token);
        // Sleep one second past TTL_ACCESS_TOKEN_IN_SECONDS; assumes the IdP issues tokens with that lifetime
        Thread.sleep(TTL_ACCESS_TOKEN_IN_SECONDS.plusSeconds(1).toMillis());
        try (okhttp3.Response response = httpClientWithOAuth2Cookie(token, false).newCall(uiCall().build()).execute()) {
            assertRedirectResponse(response);
        }
    }

    // Client builder with OkHttpUtil's insecure SSL setup applied; both Trino and the IdP are reached over HTTPS
    private static OkHttpClient.Builder insecureClientBuilder() {
        OkHttpClient.Builder builder = new OkHttpClient.Builder();
        OkHttpUtil.setupInsecureSsl(builder);
        return builder;
    }

    private Request.Builder uiCall() {
        return new Request.Builder().url(uiUri.toString()).get();
    }

    private Request.Builder apiCall() {
        return new Request.Builder().url(serverUri.resolve("/ui/api/cluster").toString()).get();
    }

    /**
     * Asserts the attributes of the OAuth2 cookie issued by the filter. It also delegates
     * validation of its token value to {@link #validateAccessToken(String)}.
     */
    private void assertTrinoCookie(HttpCookie cookie) {
        Assertions.assertThat(cookie.getName()).isEqualTo(OAUTH2_COOKIE);
        Assertions.assertThat(cookie.getDomain()).isIn("127.0.0.1", "::1");
        Assertions.assertThat(cookie.getPath()).isEqualTo("/ui/");
        Assertions.assertThat(cookie.getSecure()).isTrue();
        Assertions.assertThat(cookie.isHttpOnly()).isTrue();
        Assertions.assertThat(cookie.getMaxAge()).isLessThanOrEqualTo(TTL_ACCESS_TOKEN_IN_SECONDS.getSeconds());
        validateAccessToken(cookie.getValue());
    }

    /**
     * Validates the access token found in the OAuth2 cookie after a successful login.
     *
     * @param accessToken the raw cookie value
     */
    protected abstract void validateAccessToken(String accessToken);

    // Asserts that a UI request carrying the given token in the OAuth2 cookie is served with 200 OK
    private void assertUICallWithCookie(String accessToken) throws IOException {
        try (okhttp3.Response response = httpClientWithOAuth2Cookie(accessToken, true).newCall(uiCall().build()).execute()) {
            Assertions.assertThat(response.code()).isEqualTo(javax.ws.rs.core.Response.Status.OK.getStatusCode());
        }
    }

    /**
     * Builds a client that sends {@code cookieValue} as the OAuth2 cookie on requests to {@code /ui/}.
     * It sends no cookies on other paths and discards any cookies the server sets.
     */
    private OkHttpClient httpClientWithOAuth2Cookie(String cookieValue, boolean followRedirects) {
        return insecureClientBuilder()
                .followRedirects(followRedirects)
                .cookieJar(new CookieJar() {
                    @Override
                    public void saveFromResponse(HttpUrl url, List<Cookie> cookies) {
                        // Intentionally ignored: the test controls the cookie explicitly
                    }

                    @Override
                    public List<Cookie> loadForRequest(HttpUrl url) {
                        if (!url.encodedPath().equals("/ui/")) {
                            return ImmutableList.of();
                        }
                        return ImmutableList.of(new Cookie.Builder()
                                .domain(serverUri.getHost())
                                .path("/ui/")
                                .name(OAUTH2_COOKIE)
                                .value(cookieValue)
                                .httpOnly()
                                .secure()
                                .build());
                    }
                })
                .build();
    }

    // Asserts a 303 whose Location points at the IdP's authorization endpoint
    private void assertRedirectResponse(okhttp3.Response response) {
        Assertions.assertThat(response.code()).isEqualTo(javax.ws.rs.core.Response.Status.SEE_OTHER.getStatusCode());
        assertRedirectUrl(response.header("Location"));
    }

    private void assertUnauthorizedResponse(okhttp3.Response response) throws IOException {
        Assertions.assertThat(response.code()).isEqualTo(javax.ws.rs.core.Response.Status.UNAUTHORIZED.getStatusCode());
        Assertions.assertThat(response.body()).isNotNull();
        Assertions.assertThat(response.body().string()).isEqualTo("Unauthorized");
    }

    // Asserts the redirect is an authorization-code request for the Trino client, with a callback to this server
    private void assertRedirectUrl(String redirectUrl) {
        Assertions.assertThat(redirectUrl).isNotNull();
        HttpUrl url = HttpUrl.parse(redirectUrl);
        Assertions.assertThat(url).isNotNull();
        Assertions.assertThat(url.scheme()).isEqualTo("https");
        Assertions.assertThat(url.host()).isEqualTo("localhost");
        Assertions.assertThat(url.port()).isEqualTo(hydraIdP.getAuthPort());
        Assertions.assertThat(url.encodedPath()).isEqualTo("/oauth2/auth");
        Assertions.assertThat(url.queryParameterValues("response_type")).isEqualTo(ImmutableList.of("code"));
        Assertions.assertThat(url.queryParameterValues("scope")).isEqualTo(ImmutableList.of("openid"));
        Assertions.assertThat(url.queryParameterValues("redirect_uri")).isEqualTo(ImmutableList.of(serverUri + CALLBACK_PATH));
        Assertions.assertThat(url.queryParameterValues("client_id")).isEqualTo(ImmutableList.of(TRINO_CLIENT_ID));
        // queryParameterValues never returns null, so check that exactly one state value is present
        Assertions.assertThat(url.queryParameterValues("state")).hasSize(1);
    }
}
