package io.trino.server.security.oauth2;

import com.google.common.base.Strings;
import com.google.common.base.Verify;
import com.google.common.collect.Ordering;
import com.google.common.hash.Hashing;
import com.google.common.io.BaseEncoding;
import com.google.common.io.Resources;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.SigningKeyResolver;
import io.trino.server.security.oauth2.OAuth2Client;
import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import java.util.Date;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import javax.inject.Inject;

/**
 * Drives the OAuth2 authorization-code flow.
 *
 * <p>Responsibilities:
 * <ul>
 *   <li>Issues HMAC-signed state tokens and builds authorization redirects via {@link OAuth2Client}.</li>
 *   <li>Exchanges callback codes for access tokens and verifies their signatures and (OpenID) nonce.</li>
 *   <li>Renders the success and failure pages shown to the browser.</li>
 * </ul>
 */
public class OAuth2Service {
    public static final String REDIRECT_URI = "redirect_uri";
    public static final String STATE = "state";
    public static final String NONCE = "nonce";
    public static final String OPENID_SCOPE = "openid";

    // Audiences embedded in the state token distinguish the Web UI flow from the REST flow.
    private static final String STATE_AUDIENCE_UI = "trino_oauth_ui";
    private static final String STATE_AUDIENCE_REST = "trino_oauth_rest";
    // Placeholder inside /oauth2/failure.html that is replaced with the (escaped) error message.
    private static final String FAILURE_REPLACEMENT_TEXT = "<!-- ERROR_MESSAGE -->";
    private static final SecureRandom SECURE_RANDOM = new SecureRandom();

    private final OAuth2Client client;
    private final SigningKeyResolver signingKeyResolver;
    private final String successHtml;
    private final String failureHtml;
    private final Set<String> scopes;
    private final TemporalAmount challengeTimeout;
    // Key for signing/verifying state tokens (HS256).
    private final byte[] stateHmac;

    /**
     * Value returned when a Web UI challenge is started.
     */
    public static class OAuthChallenge {
        private final URI redirectUrl;
        private final Instant challengeExpiration;
        private final Optional<String> nonce;

        /**
         * @param redirectUrl authorization URL the user agent should be sent to
         * @param challengeExpiration instant after which the state token is no longer accepted
         * @param nonce raw (unhashed) nonce, present only when the {@code openid} scope is requested
         */
        public OAuthChallenge(URI redirectUrl, Instant challengeExpiration, Optional<String> nonce) {
            this.redirectUrl = Objects.requireNonNull(redirectUrl, "redirectUrl is null");
            this.challengeExpiration = Objects.requireNonNull(challengeExpiration, "challengeExpiration is null");
            this.nonce = Objects.requireNonNull(nonce, "nonce is null");
        }

        public URI getRedirectUrl() {
            return redirectUrl;
        }

        public Instant getChallengeExpiration() {
            return challengeExpiration;
        }

        public Optional<String> getNonce() {
            return nonce;
        }
    }

    /**
     * Outcome of a completed challenge.
     */
    public static class OAuthResult {
        private final Optional<UUID> authId;
        private final String accessToken;
        private final Instant tokenExpiration;

        /**
         * @param authId auth ID of the REST flow, or empty for the Web UI flow
         * @param accessToken the verified access token
         * @param tokenExpiration instant at which the access token should be considered expired
         */
        public OAuthResult(Optional<UUID> authId, String accessToken, Instant tokenExpiration) {
            this.authId = Objects.requireNonNull(authId, "authId is null");
            this.accessToken = Objects.requireNonNull(accessToken, "accessToken is null");
            this.tokenExpiration = Objects.requireNonNull(tokenExpiration, "tokenExpiration is null");
        }

        public Optional<UUID> getAuthId() {
            return authId;
        }

        public String getAccessToken() {
            return accessToken;
        }

        public Instant getTokenExpiration() {
            return tokenExpiration;
        }
    }

    /**
     * Creates the service and loads the success/failure HTML templates from the classpath.
     *
     * @throws IOException if either HTML resource cannot be read
     */
    @Inject
    public OAuth2Service(OAuth2Client client, @ForOAuth2 SigningKeyResolver signingKeyResolver, OAuth2Config oauth2Config) throws IOException {
        this.client = Objects.requireNonNull(client, "client is null");
        this.signingKeyResolver = Objects.requireNonNull(signingKeyResolver, "signingKeyResolver is null");
        Objects.requireNonNull(oauth2Config, "oauth2Config is null");

        this.successHtml = Resources.toString(Resources.getResource(getClass(), "/oauth2/success.html"), StandardCharsets.UTF_8);
        this.failureHtml = Resources.toString(Resources.getResource(getClass(), "/oauth2/failure.html"), StandardCharsets.UTF_8);
        Verify.verify(failureHtml.contains(FAILURE_REPLACEMENT_TEXT), "failure.html does not contain the replacement text");

        this.scopes = oauth2Config.getScopes();
        this.challengeTimeout = Duration.ofMillis(oauth2Config.getChallengeTimeout().toMillis());
        // Use a key derived from the configured state key (SHA-256). Otherwise use a random per-instance key.
        // With a random key, state tokens presumably only validate on the instance that issued them.
        this.stateHmac = oauth2Config.getStateKey()
                .map(stateKey -> Hashing.sha256().hashString(stateKey, StandardCharsets.UTF_8).asBytes())
                .orElseGet(() -> secureRandomBytes(32));
    }

    /**
     * Starts a Web UI challenge.
     *
     * <p>The state token carries the UI audience and the challenge expiration. When the {@code openid}
     * scope is configured, a random nonce is generated. Only its SHA-256 hash is sent to the
     * authorization server; the raw nonce is returned to the caller.
     */
    public OAuthChallenge startWebUiChallenge(URI callbackUri) {
        Instant challengeExpiration = Instant.now().plus(challengeTimeout);
        String state = Jwts.builder()
                .signWith(SignatureAlgorithm.HS256, stateHmac)
                .setAudience(STATE_AUDIENCE_UI)
                .setExpiration(Date.from(challengeExpiration))
                .compact();
        Optional<String> nonce = scopes.contains(OPENID_SCOPE) ? Optional.of(randomNonce()) : Optional.empty();
        URI redirectUrl = client.getAuthorizationUri(state, callbackUri, nonce.map(OAuth2Service::hashNonce));
        return new OAuthChallenge(redirectUrl, challengeExpiration, nonce);
    }

    /**
     * Starts a REST challenge and returns the authorization URL.
     *
     * <p>The state token carries {@code authId} as its JWT ID and the REST audience. No nonce is used.
     */
    public URI startRestChallenge(URI callbackUri, UUID authId) {
        String state = Jwts.builder()
                .signWith(SignatureAlgorithm.HS256, stateHmac)
                .setId(authId.toString())
                .setAudience(STATE_AUDIENCE_REST)
                .setExpiration(Date.from(Instant.now().plus(challengeTimeout)))
                .compact();
        return client.getAuthorizationUri(state, callbackUri, Optional.empty());
    }

    /**
     * Exchanges an authorization code for an access token and verifies the token.
     *
     * <p>The access token's signature is checked using the configured signing key resolver. The
     * nonce is validated for the Web UI flow.
     *
     * @param authId REST-flow auth ID, or empty for the Web UI flow
     * @param code authorization code received on the callback
     * @param callbackUri callback URI used when the challenge was started
     * @param nonce raw nonce issued by {@link #startWebUiChallenge}, if any
     * @throws ChallengeFailedException if the nonce cannot be validated, or if neither the token
     *         endpoint nor the token itself provides an expiration
     */
    public OAuthResult finishChallenge(Optional<UUID> authId, String code, URI callbackUri, Optional<String> nonce) throws ChallengeFailedException {
        Objects.requireNonNull(callbackUri, "callbackUri is null");
        Objects.requireNonNull(authId, "authId is null");
        Objects.requireNonNull(code, "code is null");

        OAuth2Client.AccessToken accessToken = client.getAccessToken(code, callbackUri);
        // Parsing verifies the signature; jjwt signals an invalid token with a runtime exception.
        Claims claims = parseClaimsJws(accessToken.getAccessToken()).getBody();
        validateNonce(authId, accessToken, nonce);
        return new OAuthResult(authId, accessToken.getAccessToken(), determineExpiration(accessToken, claims));
    }

    // Chooses the earlier of the token endpoint's reported validity and the token's "exp" claim.
    // Either source may be missing; fails only if both are.
    private static Instant determineExpiration(OAuth2Client.AccessToken accessToken, Claims claims) throws ChallengeFailedException {
        Optional<Instant> claimedExpiration = Optional.ofNullable(claims.getExpiration()).map(Date::toInstant);
        Optional<Instant> validUntil = accessToken.getValidUntil();
        if (!validUntil.isPresent()) {
            return claimedExpiration.orElseThrow(() -> new ChallengeFailedException("Access token does not specify an expiration"));
        }
        Instant reportedValidUntil = validUntil.get();
        return claimedExpiration
                .map(expiration -> Ordering.natural().min(reportedValidUntil, expiration))
                .orElse(reportedValidUntil);
    }

    /**
     * Validates the state token and extracts the REST-flow auth ID from it.
     *
     * @return empty for a Web UI state, or the auth ID for a REST state
     * @throws ChallengeFailedException if the state is invalid, has an unexpected audience, or a
     *         REST state lacks a valid UUID auth ID
     */
    public Optional<UUID> getAuthId(String state) throws ChallengeFailedException {
        Claims stateClaims = parseState(state);
        String audience = stateClaims.getAudience();
        if (STATE_AUDIENCE_UI.equals(audience)) {
            return Optional.empty();
        }
        if (!STATE_AUDIENCE_REST.equals(audience)) {
            throw new ChallengeFailedException("Unexpected state audience");
        }
        String id = stateClaims.getId();
        // UUID.fromString(null) throws NullPointerException, not IllegalArgumentException, so check explicitly.
        if (id == null) {
            throw new ChallengeFailedException("State does not contain an auth ID");
        }
        try {
            return Optional.of(UUID.fromString(id));
        } catch (IllegalArgumentException e) {
            throw new ChallengeFailedException("State does not contain an auth ID", e);
        }
    }

    // Verifies the state token's HMAC signature (and, via jjwt, its expiration) and returns its claims.
    private Claims parseState(String state) throws ChallengeFailedException {
        try {
            return Jwts.parser()
                    .setSigningKey(stateHmac)
                    .parseClaimsJws(state)
                    .getBody();
        } catch (RuntimeException e) {
            throw new ChallengeFailedException("State validation failed", e);
        }
    }

    /**
     * Parses and verifies a signed JWT using the configured signing key resolver.
     */
    public Jws<Claims> parseClaimsJws(String token) {
        return Jwts.parser()
                .setSigningKeyResolver(signingKeyResolver)
                .parseClaimsJws(token);
    }

    public String getSuccessHtml() {
        return successHtml;
    }

    /**
     * Renders the failure page for an OAuth2 {@code error} code returned on the callback.
     */
    public String getCallbackErrorHtml(String errorCode) {
        return renderFailureHtml(getOAuth2ErrorMessage(errorCode));
    }

    /**
     * Renders the failure page with a plain-text message; {@code null} renders as an empty message.
     */
    public String getInternalFailureHtml(String message) {
        return renderFailureHtml(Strings.nullToEmpty(message));
    }

    // Substitutes the HTML-escaped message into the failure template.
    // Escaping is required because getOAuth2ErrorMessage echoes the error code, which is presumably
    // the callback's "error" query parameter and therefore untrusted.
    private String renderFailureHtml(String message) {
        return failureHtml.replace(FAILURE_REPLACEMENT_TEXT, escapeHtml(message));
    }

    // Minimal HTML escaping of the five significant characters, safe in text and quoted attributes.
    private static String escapeHtml(String text) {
        StringBuilder escaped = new StringBuilder(text.length());
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            switch (c) {
                case '&':
                    escaped.append("&amp;");
                    break;
                case '<':
                    escaped.append("&lt;");
                    break;
                case '>':
                    escaped.append("&gt;");
                    break;
                case '"':
                    escaped.append("&quot;");
                    break;
                case '\'':
                    escaped.append("&#39;");
                    break;
                default:
                    escaped.append(c);
            }
        }
        return escaped.toString();
    }

    // Skips nonce validation for the REST flow (authId present); startRestChallenge never issues a nonce.
    // For the Web UI flow, a nonce and an ID token must either both be present or both be absent.
    // When present, the ID token must carry the hashed nonce; jjwt throws if it does not.
    private void validateNonce(Optional<UUID> authId, OAuth2Client.AccessToken accessToken, Optional<String> nonce) throws ChallengeFailedException {
        if (authId.isPresent()) {
            return;
        }
        if (nonce.isPresent() != accessToken.getIdToken().isPresent()) {
            throw new ChallengeFailedException("Cannot validate nonce parameter");
        }
        if (nonce.isPresent()) {
            Jwts.parser()
                    .setSigningKeyResolver(signingKeyResolver)
                    .require(NONCE, hashNonce(nonce.get()))
                    .parseClaimsJws(accessToken.getIdToken().get());
        }
    }

    /**
     * Returns {@code length} bytes from a {@link SecureRandom}.
     */
    public static byte[] secureRandomBytes(int length) {
        byte[] bytes = new byte[length];
        SECURE_RANDOM.nextBytes(bytes);
        return bytes;
    }

    // Maps the standard OAuth2 callback error codes to user-facing messages; null/unknown codes are echoed.
    private static String getOAuth2ErrorMessage(String errorCode) {
        switch (Strings.nullToEmpty(errorCode)) {
            case "access_denied":
                return "OAuth2 server denied the login";
            case "unauthorized_client":
                return "OAuth2 server does not allow request from this Trino server";
            case "server_error":
                return "OAuth2 server had a failure";
            case "temporarily_unavailable":
                return "OAuth2 server is temporarily unavailable";
            default:
                return "OAuth2 unknown error code: " + errorCode;
        }
    }

    // 18 random bytes encode to exactly 24 base64url characters, with no padding.
    private static String randomNonce() {
        return BaseEncoding.base64Url().encode(secureRandomBytes(18));
    }

    // Hex-encoded SHA-256 of the nonce; this is what is sent to, and expected back from, the server.
    private static String hashNonce(String nonce) {
        return Hashing.sha256().hashString(nonce, StandardCharsets.UTF_8).toString();
    }
}
