package software.amazon.jdbc.plugin.federatedauth;

import java.sql.Connection;
import java.sql.SQLException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.time.temporal.TemporalUnit;
import java.util.Collections;
import java.util.HashSet;
import java.util.Optional;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import shaded.software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import shaded.software.amazon.awssdk.regions.Region;
import shaded.software.amazon.awssdk.services.rds.RdsUtilities;
import software.amazon.jdbc.AwsWrapperProperty;
import software.amazon.jdbc.HostSpec;
import software.amazon.jdbc.JdbcCallable;
import software.amazon.jdbc.PluginService;
import software.amazon.jdbc.PropertyDefinition;
import software.amazon.jdbc.plugin.AbstractConnectionPlugin;
import software.amazon.jdbc.plugin.TokenInfo;
import software.amazon.jdbc.util.IamAuthUtils;
import software.amazon.jdbc.util.Messages;
import software.amazon.jdbc.util.RdsUtils;
import software.amazon.jdbc.util.StringUtils;
import software.amazon.jdbc.util.telemetry.TelemetryContext;
import software.amazon.jdbc.util.telemetry.TelemetryCounter;
import software.amazon.jdbc.util.telemetry.TelemetryFactory;
import software.amazon.jdbc.util.telemetry.TelemetryGauge;
import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel;

/* loaded from: input_file:software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.class */
/**
 * Connection plugin that authenticates to Amazon RDS using IAM database authentication tokens.
 *
 * <p>AWS credentials are obtained from a {@link CredentialsProviderFactory} (configured through the
 * identity-provider properties declared below) and used to sign an RDS authentication token, which
 * is then supplied as the connection password. Generated tokens are cached in {@code tokenCache},
 * keyed by region, IAM host, IAM port and database user, until they expire.
 */
public class FederatedAuthPlugin extends AbstractConnectionPlugin {
    private final CredentialsProviderFactory credentialsProviderFactory;
    protected static final String SAML_RESPONSE_PATTERN_GROUP = "saml";
    private static final String TELEMETRY_FETCH_TOKEN = "fetch IAM token";
    protected final PluginService pluginService;
    protected final RdsUtils rdsUtils = new RdsUtils();
    private final TelemetryFactory telemetryFactory;
    // Held so the gauge registered in the constructor stays referenced for the plugin's lifetime.
    private final TelemetryGauge cacheSizeGauge;
    private final TelemetryCounter fetchTokenCounter;
    // Shared across all plugin instances; key format is produced by getCacheKey.
    static final ConcurrentHashMap<String, TokenInfo> tokenCache = new ConcurrentHashMap<>();
    public static final AwsWrapperProperty IDP_ENDPOINT = new AwsWrapperProperty("idpEndpoint", null, "The hosting URL of the Identity Provider");
    public static final AwsWrapperProperty IDP_PORT = new AwsWrapperProperty("idpPort", "443", "The hosting port of Identity Provider");
    public static final AwsWrapperProperty RELAYING_PARTY_ID = new AwsWrapperProperty("rpIdentifier", "urn:amazon:webservices", "The relying party identifier");
    public static final AwsWrapperProperty IAM_ROLE_ARN = new AwsWrapperProperty("iamRoleArn", null, "The ARN of the IAM Role that is to be assumed.");
    public static final AwsWrapperProperty IAM_IDP_ARN = new AwsWrapperProperty("iamIdpArn", null, "The ARN of the Identity Provider");
    public static final AwsWrapperProperty IAM_REGION = new AwsWrapperProperty("iamRegion", null, "Overrides AWS region that is used to generate the IAM token");
    private static final int DEFAULT_TOKEN_EXPIRATION_SEC = 870;
    public static final AwsWrapperProperty IAM_TOKEN_EXPIRATION = new AwsWrapperProperty("iamTokenExpiration", String.valueOf(DEFAULT_TOKEN_EXPIRATION_SEC), "IAM token cache expiration in seconds");
    public static final AwsWrapperProperty IDP_USERNAME = new AwsWrapperProperty("idpUsername", null, "The federated user name");
    public static final AwsWrapperProperty IDP_PASSWORD = new AwsWrapperProperty("idpPassword", null, "The federated user password");
    public static final AwsWrapperProperty IAM_HOST = new AwsWrapperProperty("iamHost", null, "Overrides the host that is used to generate the IAM token");
    public static final AwsWrapperProperty IAM_DEFAULT_PORT = new AwsWrapperProperty("iamDefaultPort", "-1", "Overrides default port that is used to generate the IAM token");
    private static final int DEFAULT_HTTP_TIMEOUT_MILLIS = 60000;
    public static final AwsWrapperProperty HTTP_CLIENT_SOCKET_TIMEOUT = new AwsWrapperProperty("httpClientSocketTimeout", String.valueOf(DEFAULT_HTTP_TIMEOUT_MILLIS), "The socket timeout value in milliseconds for the HttpClient used by the FederatedAuthPlugin");
    public static final AwsWrapperProperty HTTP_CLIENT_CONNECT_TIMEOUT = new AwsWrapperProperty("httpClientConnectTimeout", String.valueOf(DEFAULT_HTTP_TIMEOUT_MILLIS), "The connect timeout value in milliseconds for the HttpClient used by the FederatedAuthPlugin");
    public static final AwsWrapperProperty SSL_INSECURE = new AwsWrapperProperty("sslInsecure", "true", "Whether or not the SSL session is to be secure and the server's certificates will be verified");
    // final like every other property: a public mutable static could be reassigned by any caller.
    public static final AwsWrapperProperty IDP_NAME = new AwsWrapperProperty("idpName", null, "The name of the Identity Provider implementation used");
    public static final AwsWrapperProperty DB_USER = new AwsWrapperProperty("dbUser", null, "The database user used to access the database");
    protected static final Pattern SAML_RESPONSE_PATTERN = Pattern.compile("SAMLResponse\\W+value=\"(?<saml>[^\"]+)\"");
    protected static final Pattern HTTPS_URL_PATTERN = Pattern.compile("^(https)://[-a-zA-Z0-9+&@#/%?=~_!:,.']*[-a-zA-Z0-9+&@#/%=~_']");
    private static final Logger LOGGER = Logger.getLogger(FederatedAuthPlugin.class.getName());
    private static final Set<String> subscribedMethods = Collections.unmodifiableSet(new HashSet<String>() {
        {
            add("connect");
            add("forceConnect");
        }
    });

    @Override
    public Set<String> getSubscribedMethods() {
        return subscribedMethods;
    }

    /**
     * Creates the plugin and registers its telemetry instruments.
     *
     * @param pluginService service supplying the dialect and telemetry factory
     * @param credentialsProviderFactory factory producing the AWS credentials used to sign tokens
     * @throws RuntimeException if the (shaded) AWS STS SDK is not on the classpath
     */
    public FederatedAuthPlugin(final PluginService pluginService, final CredentialsProviderFactory credentialsProviderFactory) {
        try {
            Class.forName("shaded.software.amazon.awssdk.services.sts.model.AssumeRoleWithSamlRequest");
        } catch (final ClassNotFoundException e) {
            throw new RuntimeException(Messages.get("FederatedAuthPlugin.javaStsSdkNotInClasspath"));
        }
        this.pluginService = pluginService;
        this.credentialsProviderFactory = credentialsProviderFactory;
        this.telemetryFactory = pluginService.getTelemetryFactory();
        this.cacheSizeGauge = this.telemetryFactory.createGauge("federatedAuth.tokenCache.size",
                () -> Long.valueOf(tokenCache.size()));
        this.fetchTokenCounter = this.telemetryFactory.createCounter("federatedAuth.fetchToken.count");
    }

    @Override
    public Connection connect(final String driverProtocol, final HostSpec hostSpec, final Properties props,
            final boolean isInitialConnection, final JdbcCallable<Connection, SQLException> connectFunc) throws SQLException {
        return connectInternal(hostSpec, props, connectFunc);
    }

    @Override
    public Connection forceConnect(final String driverProtocol, final HostSpec hostSpec, final Properties props,
            final boolean isInitialConnection, final JdbcCallable<Connection, SQLException> forceConnectFunc) throws SQLException {
        return connectInternal(hostSpec, props, forceConnectFunc);
    }

    /**
     * Puts an IAM token (cached or freshly generated) into the connection properties and connects.
     *
     * <p>If a connection made with a cached token fails with a {@link SQLException}, the token is
     * regenerated and the connection is attempted once more. A failure with a freshly generated
     * token is rethrown immediately, since regenerating it would not change the outcome.
     */
    private Connection connectInternal(final HostSpec hostSpec, final Properties props,
            final JdbcCallable<Connection, SQLException> connectFunc) throws SQLException {
        checkIdpCredentialsWithFallback(props);

        // Resolve host and port once so the cache key and the signed token always agree.
        final String iamHost = IamAuthUtils.getIamHost(IAM_HOST.getString(props), hostSpec);
        final int iamPort = resolveIamPort(hostSpec, props);
        final Region region = getRegion(iamHost, props);
        final String cacheKey = getCacheKey(DB_USER.getString(props), iamHost, iamPort, region);

        final TokenInfo tokenInfo = tokenCache.get(cacheKey);
        final boolean isCachedToken = tokenInfo != null && !tokenInfo.isExpired();
        if (isCachedToken) {
            // NOTE(review): logs the full token at FINEST level; consider masking it.
            LOGGER.finest(() -> Messages.get("FederatedAuthPlugin.useCachedIamToken", new Object[]{tokenInfo.getToken()}));
            PropertyDefinition.PASSWORD.set(props, tokenInfo.getToken());
        } else {
            updateAuthenticationToken(hostSpec, props, region, cacheKey, iamHost, iamPort);
        }
        PropertyDefinition.USER.set(props, DB_USER.getString(props));

        try {
            return connectFunc.call();
        } catch (final SQLException e) {
            if (!isCachedToken) {
                throw e;
            }
            // The cached token may be stale; refresh it and try exactly once more.
            updateAuthenticationToken(hostSpec, props, region, cacheKey, iamHost, iamPort);
            return connectFunc.call();
        } catch (final Exception e) {
            LOGGER.warning(() -> Messages.get("FederatedAuthPlugin.unhandledException", new Object[]{e}));
            throw new SQLException(e);
        }
    }

    /**
     * Computes the port used for IAM token generation and the cache key.
     *
     * <p>An unset or empty {@code iamDefaultPort} is passed to {@link IamAuthUtils#getIamPort} as 0.
     */
    private int resolveIamPort(final HostSpec hostSpec, final Properties props) {
        final int iamDefaultPort = StringUtils.isNullOrEmpty(IAM_DEFAULT_PORT.getString(props))
                ? 0
                : IAM_DEFAULT_PORT.getInteger(props);
        return IamAuthUtils.getIamPort(iamDefaultPort, hostSpec, this.pluginService.getDialect().getDefaultPort());
    }

    /** Falls back to the regular user/password properties when the IdP-specific ones are absent. */
    private void checkIdpCredentialsWithFallback(final Properties props) {
        if (IDP_USERNAME.getString(props) == null) {
            IDP_USERNAME.set(props, PropertyDefinition.USER.getString(props));
        }
        if (IDP_PASSWORD.getString(props) == null) {
            IDP_PASSWORD.set(props, PropertyDefinition.PASSWORD.getString(props));
        }
    }

    /**
     * Generates a new IAM token for {@code iamHost}:{@code iamPort}, sets it as the connection
     * password, and stores it in the cache under {@code cacheKey}.
     *
     * @throws SQLException if the credentials provider cannot be created
     */
    private void updateAuthenticationToken(final HostSpec hostSpec, final Properties props, final Region region,
            final String cacheKey, final String iamHost, final int iamPort) throws SQLException {
        final Instant tokenExpiry = Instant.now().plus(IAM_TOKEN_EXPIRATION.getInteger(props), ChronoUnit.SECONDS);
        final AwsCredentialsProvider credentialsProvider =
                this.credentialsProviderFactory.getAwsCredentialsProvider(hostSpec.getHost(), region, props);
        final String token = generateAuthenticationToken(props, iamHost, iamPort, region, credentialsProvider);
        // NOTE(review): logs the full token at FINEST level; consider masking it.
        LOGGER.finest(() -> Messages.get("FederatedAuthPlugin.generatedNewIamToken", new Object[]{token}));
        PropertyDefinition.PASSWORD.set(props, token);
        tokenCache.put(cacheKey, new TokenInfo(token, tokenExpiry));
    }

    /**
     * Determines the AWS region: the {@code iamRegion} property if set, otherwise the region parsed
     * from {@code hostname}.
     *
     * @throws SQLException if no region can be parsed from the hostname or it is not a known region
     */
    private Region getRegion(final String hostname, final Properties props) throws SQLException {
        final String regionOverride = IAM_REGION.getString(props);
        if (!StringUtils.isNullOrEmpty(regionOverride)) {
            return Region.of(regionOverride);
        }

        final String rdsRegion = this.rdsUtils.getRdsRegion(hostname);
        if (StringUtils.isNullOrEmpty(rdsRegion)) {
            final String message = Messages.get("FederatedAuthPlugin.unsupportedHostname", new Object[]{hostname});
            LOGGER.fine(message);
            throw new SQLException(message);
        }

        final Optional<Region> knownRegion = Region.regions().stream()
                .filter(r -> r.id().equalsIgnoreCase(rdsRegion))
                .findFirst();
        if (knownRegion.isPresent()) {
            return knownRegion.get();
        }

        final String message = Messages.get("AwsSdk.unsupportedRegion", new Object[]{rdsRegion});
        LOGGER.fine(message);
        throw new SQLException(message);
    }

    /**
     * Signs an RDS IAM authentication token for the {@code dbUser} property, recording a
     * telemetry context around the call.
     *
     * @param props connection properties ({@code dbUser} is read from here)
     * @param hostname host the token is signed for
     * @param port port the token is signed for
     * @param region AWS region used for signing
     * @param credentialsProvider credentials used for signing
     * @return the generated token
     */
    String generateAuthenticationToken(final Properties props, final String hostname, final int port,
            final Region region, final AwsCredentialsProvider credentialsProvider) {
        final TelemetryContext telemetryContext =
                this.pluginService.getTelemetryFactory().openTelemetryContext(TELEMETRY_FETCH_TOKEN, TelemetryTraceLevel.NESTED);
        this.fetchTokenCounter.inc();
        try {
            final String user = DB_USER.getString(props);
            return RdsUtilities.builder()
                    .credentialsProvider(credentialsProvider)
                    .region(region)
                    .build()
                    .generateAuthenticationToken(builder -> builder.hostname(hostname).port(port).username(user));
        } catch (final Exception e) {
            telemetryContext.setSuccess(false);
            telemetryContext.setException(e);
            throw e;
        } finally {
            telemetryContext.closeContext();
        }
    }

    /** Builds the token-cache key as {@code region:host:port:user}. */
    private String getCacheKey(final String user, final String hostname, final int port, final Region region) {
        return String.format("%s:%s:%d:%s", region, hostname, port, user);
    }

    /** Removes all cached tokens. */
    public static void clearCache() {
        tokenCache.clear();
    }

    // Placed after all static field initializers so every property exists when registered.
    static {
        PropertyDefinition.registerPluginProperties((Class<?>) FederatedAuthPlugin.class);
    }
}
