package io.trino.jdbc;

import com.google.common.collect.ImmutableMap;
import com.google.common.io.Files;
import com.google.common.io.Resources;
import io.airlift.log.Logging;
import io.airlift.security.pem.PemReader;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.security.Keys;
import io.trino.plugin.tpch.TpchPlugin;
import io.trino.server.security.jwt.JwtUtil;
import io.trino.server.testing.TestingTrinoServer;
import java.io.File;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.PrivateKey;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.time.Duration;
import java.util.Base64;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Properties;
import org.assertj.core.api.Assertions;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/* loaded from: input_file:io/trino/jdbc/TestTrinoDriverAuth.class */
/**
 * Integration tests for JWT-based authentication of the Trino JDBC driver against a
 * {@link TestingTrinoServer} configured with HTTPS and a {@code ${KID}.key} key-file pattern.
 *
 * <p>Key material is loaded from the directory that contains the {@code 33.privateKey}
 * test resource. Each test opens its own connection and closes every JDBC resource it creates.
 */
public class TestTrinoDriverAuth {
    private static final String TEST_CATALOG = "test_catalog";
    private TestingTrinoServer server;
    private Key defaultKey;
    private Key hmac222;
    private PrivateKey privateKey33;

    /**
     * Loads the test keys and starts an HTTPS Trino server that authenticates requests with JWT.
     *
     * @throws Exception if a key file cannot be read or the server fails to start
     */
    @BeforeClass
    public void setup() throws Exception {
        Logging.initialize();
        URL resource = getClass().getClassLoader().getResource("33.privateKey");
        Assert.assertNotNull(resource, "key directory not found");
        File keyDirectory = new File(resource.toURI()).getAbsoluteFile().getParentFile();

        defaultKey = loadHmacKey(new File(keyDirectory, "default-key.key"));
        hmac222 = loadHmacKey(new File(keyDirectory, "222.key"));
        privateKey33 = PemReader.loadPrivateKey(new File(keyDirectory, "33.privateKey"), Optional.empty());

        // The explicit type witness is required: without it builder() infers <Object, Object>,
        // which is not assignable to the Map<String, String> expected by setProperties.
        server = TestingTrinoServer.builder()
                .setProperties(ImmutableMap.<String, String>builder()
                        .put("http-server.authentication.type", "JWT")
                        // ${KID} is substituted with the token's "kid" header to select the key file.
                        .put("http-server.authentication.jwt.key-file", new File(keyDirectory, "${KID}.key").getPath())
                        .put("http-server.https.enabled", "true")
                        .put("http-server.https.keystore.path", new File(Resources.getResource("localhost.keystore").toURI()).getPath())
                        .put("http-server.https.keystore.key", "changeit")
                        .buildOrThrow())
                .build();
        server.installPlugin(new TpchPlugin());
        server.createCatalog(TEST_CATALOG, "tpch");
        server.waitForNodeRefresh(Duration.ofSeconds(10L));
    }

    /**
     * Shuts down the test server.
     *
     * <p>The null check is needed because {@code alwaysRun = true} also runs this method
     * when {@link #setup()} failed before the server was created. Without the check an NPE
     * would hide the original setup failure.
     */
    @AfterClass(alwaysRun = true)
    public void teardown() throws Exception {
        if (server != null) {
            server.close();
            server = null;
        }
    }

    @Test
    public void testSuccessDefaultKey() throws Exception {
        String accessToken = JwtUtil.newJwtBuilder()
                .setSubject("test")
                .signWith(defaultKey)
                .compact();
        assertSelect123Succeeds(ImmutableMap.of("accessToken", accessToken));
    }

    @Test
    public void testSuccessHmac() throws Exception {
        String accessToken = JwtUtil.newJwtBuilder()
                .setSubject("test")
                .setHeaderParam("kid", "222")
                .signWith(hmac222)
                .compact();
        assertSelect123Succeeds(ImmutableMap.of("accessToken", accessToken));
    }

    @Test
    public void testSuccessPublicKey() throws Exception {
        String accessToken = JwtUtil.newJwtBuilder()
                .setSubject("test")
                .setHeaderParam("kid", "33")
                .signWith(privateKey33)
                .compact();
        assertSelect123Succeeds(ImmutableMap.of("accessToken", accessToken));
    }

    @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Authentication failed: Unauthorized")
    public void testFailedNoToken() throws Exception {
        executeSelect123(ImmutableMap.of());
    }

    @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Authentication failed: Unsigned Claims JWTs are not supported.")
    public void testFailedUnsigned() throws Exception {
        String accessToken = JwtUtil.newJwtBuilder()
                .setSubject("test")
                .compact();
        executeSelect123(ImmutableMap.of("accessToken", accessToken));
    }

    @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Authentication failed: JWT signature does not match.*")
    public void testFailedBadHmacSignature() throws Exception {
        // A freshly generated key that the server cannot know about.
        String accessToken = JwtUtil.newJwtBuilder()
                .setSubject("test")
                .signWith(Keys.secretKeyFor(SignatureAlgorithm.HS512))
                .compact();
        executeSelect123(ImmutableMap.of("accessToken", accessToken));
    }

    @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Authentication failed: JWT signature does not match.*")
    public void testFailedWrongPublicKey() throws Exception {
        // Signed with key 33 but claims to be key 42.
        String accessToken = JwtUtil.newJwtBuilder()
                .setSubject("test")
                .setHeaderParam("kid", "42")
                .signWith(privateKey33)
                .compact();
        executeSelect123(ImmutableMap.of("accessToken", accessToken));
    }

    @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Authentication failed: Unknown signing key ID")
    public void testFailedUnknownPublicKey() throws Exception {
        String accessToken = JwtUtil.newJwtBuilder()
                .setSubject("test")
                .setHeaderParam("kid", "unknown")
                .signWith(privateKey33)
                .compact();
        executeSelect123(ImmutableMap.of("accessToken", accessToken));
    }

    @Test
    public void testSuccessFullSslVerification() throws Exception {
        String accessToken = JwtUtil.newJwtBuilder()
                .setSubject("test")
                .setHeaderParam("kid", "33")
                .signWith(privateKey33)
                .compact();
        assertSelect123Succeeds(ImmutableMap.of("accessToken", accessToken, "SSLVerification", "FULL"));
    }

    @Test
    public void testSuccessCaSslVerification() throws Exception {
        String accessToken = JwtUtil.newJwtBuilder()
                .setSubject("test")
                .setHeaderParam("kid", "33")
                .signWith(privateKey33)
                .compact();
        assertSelect123Succeeds(ImmutableMap.of("accessToken", accessToken, "SSLVerification", "CA"));
    }

    @Test
    public void testFailedFullSslVerificationWithoutSSL() {
        Assertions.assertThatThrownBy(() -> createBasicConnection(ImmutableMap.of("SSLVerification", "FULL")))
                .isInstanceOf(SQLException.class)
                .hasMessage("Connection property 'SSLVerification' is not allowed");
    }

    @Test
    public void testFailedCaSslVerificationWithoutSSL() {
        Assertions.assertThatThrownBy(() -> createBasicConnection(ImmutableMap.of("SSLVerification", "CA")))
                .isInstanceOf(SQLException.class)
                .hasMessage("Connection property 'SSLVerification' is not allowed");
    }

    @Test
    public void testFailedNoneSslVerificationWithSSL() {
        Assertions.assertThatThrownBy(() -> createConnection(ImmutableMap.of("SSLVerification", "NONE")))
                .isInstanceOf(SQLException.class)
                .hasMessage("Connection property 'SSLTrustStorePath' is not allowed");
    }

    @Test
    public void testFailedNoneSslVerificationWithSSLUnsigned() throws Exception {
        // The connection and statement are closed even when the assertion fails;
        // the original version leaked both.
        try (Connection connection = createBasicConnection(ImmutableMap.of("SSL", "true", "SSLVerification", "NONE"));
                Statement statement = connection.createStatement()) {
            Assertions.assertThatThrownBy(() -> statement.execute("SELECT 123"))
                    .isInstanceOf(SQLException.class)
                    .hasMessage("Authentication failed: Unauthorized");
        }
    }

    /**
     * Runs {@code SELECT 123} over a TLS connection and asserts it returns exactly one row with
     * the value 123. The connection, statement and result set are all closed afterwards.
     *
     * @param additionalProperties driver properties layered over the TLS defaults
     */
    private void assertSelect123Succeeds(Map<String, String> additionalProperties) throws Exception {
        try (Connection connection = createConnection(additionalProperties);
                Statement statement = connection.createStatement()) {
            Assert.assertTrue(statement.execute("SELECT 123"));
            try (ResultSet resultSet = statement.getResultSet()) {
                Assert.assertTrue(resultSet.next());
                Assert.assertEquals(resultSet.getLong(1), 123L);
                Assert.assertFalse(resultSet.next());
            }
        }
    }

    /**
     * Runs {@code SELECT 123} over a TLS connection, propagating any exception. This is used by
     * tests that expect authentication to fail. If closing a resource also fails, that failure
     * is attached as a suppressed exception rather than replacing the original one.
     *
     * @param additionalProperties driver properties layered over the TLS defaults
     */
    private void executeSelect123(Map<String, String> additionalProperties) throws Exception {
        try (Connection connection = createConnection(additionalProperties);
                Statement statement = connection.createStatement()) {
            statement.execute("SELECT 123");
        }
    }

    /**
     * Decodes a MIME Base64, US-ASCII text file into an HMAC-SHA key.
     *
     * @param keyFile file containing the Base64-encoded key bytes
     * @return the decoded HMAC key
     * @throws Exception if the file cannot be read or the key is unsuitable for HMAC-SHA
     */
    private static Key loadHmacKey(File keyFile) throws Exception {
        String encoded = Files.asCharSource(keyFile, StandardCharsets.US_ASCII).read();
        return Keys.hmacShaKeyFor(Base64.getMimeDecoder().decode(encoded.getBytes(StandardCharsets.US_ASCII)));
    }

    /** Returns the JDBC URL for the test server's HTTPS port. */
    private String jdbcUrl() {
        return String.format("jdbc:trino://localhost:%s", server.getHttpsAddress().getPort());
    }

    /**
     * Opens a connection with SSL enabled and the test truststore configured. The entries in
     * {@code additionalProperties} are applied last, so they override these defaults.
     *
     * @param additionalProperties extra driver properties, such as {@code accessToken}
     */
    private Connection createConnection(Map<String, String> additionalProperties) throws Exception {
        Properties properties = new Properties();
        properties.setProperty("SSL", "true");
        properties.setProperty("SSLTrustStorePath", new File(Resources.getResource("localhost.truststore").toURI()).getPath());
        properties.setProperty("SSLTrustStorePassword", "changeit");
        additionalProperties.forEach(properties::setProperty);
        return DriverManager.getConnection(jdbcUrl(), properties);
    }

    /**
     * Opens a connection using only the given properties, with no SSL or truststore defaults.
     *
     * @param additionalProperties the complete set of driver properties
     */
    private Connection createBasicConnection(Map<String, String> additionalProperties) throws SQLException {
        Properties properties = new Properties();
        additionalProperties.forEach(properties::setProperty);
        return DriverManager.getConnection(jdbcUrl(), properties);
    }
}
