package io.trino.server.security.oauth2;

import com.google.common.base.Preconditions;
import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.JWEHeader;
import com.nimbusds.jose.JWEObject;
import com.nimbusds.jose.KeyLengthException;
import com.nimbusds.jose.Payload;
import com.nimbusds.jose.crypto.AESDecrypter;
import com.nimbusds.jose.crypto.AESEncrypter;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.CompressionCodec;
import io.jsonwebtoken.CompressionException;
import io.jsonwebtoken.Header;
import io.jsonwebtoken.JwtBuilder;
import io.jsonwebtoken.JwtParser;
import io.trino.server.security.jwt.JwtUtil;
import io.trino.server.security.oauth2.TokenPairSerializer;
import java.security.NoSuchAlgorithmException;
import java.text.ParseException;
import java.time.Clock;
import java.util.Date;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;

/* loaded from: input_file:io/trino/server/security/oauth2/JweTokenSerializer.class */
/**
 * Serializes an OAuth2 access/refresh token pair into an encrypted, compact JWE string and back.
 *
 * <p>The token pair is first written as a signed-less JWT (built via {@link JwtUtil#newJwtBuilder()})
 * carrying the issuer, audience, an expiration, the principal claim, the access token, its expiration
 * and, when available, the refresh token. The JWT payload is compressed with {@link ZstdCodec} and then
 * wrapped in a JWE encrypted with AES key wrap ({@code A256KW}) and {@code A256CBC-HS512} content
 * encryption.
 */
public class JweTokenSerializer implements TokenPairSerializer {
    private static final Logger LOG = Logger.get(JweTokenSerializer.class);
    private static final JWEAlgorithm ALGORITHM = JWEAlgorithm.A256KW;
    private static final EncryptionMethod ENCRYPTION_METHOD = EncryptionMethod.A256CBC_HS512;
    private static final CompressionCodec COMPRESSION_CODEC = new ZstdCodec();
    private static final String ACCESS_TOKEN_KEY = "access_token";
    private static final String EXPIRATION_TIME_KEY = "expiration_time";
    private static final String REFRESH_TOKEN_KEY = "refresh_token";
    private final OAuth2Client client;
    private final Clock clock;
    private final String issuer;
    private final String audience;
    private final Duration tokenExpiration;
    private final JwtParser parser;
    private final AESEncrypter jweEncrypter;
    private final AESDecrypter jweDecrypter;
    private final String principalField;

    /**
     * Creates a serializer.
     *
     * @param config refresh-token configuration; its secret key is used when set, otherwise a random
     *        256-bit AES key is generated for this instance
     * @param client OAuth2 client used to read the claims of an access token during serialization
     * @param issuer value written to, and required in, the {@code iss} claim
     * @param audience value written to, and required in, the {@code aud} claim
     * @param principalField name of the access-token claim that identifies the principal
     * @param clock time source for issuing and validating token expirations
     * @param tokenExpiration lifetime of the serialized token
     * @throws KeyLengthException if the configured key has an unsupported length for AES encryption
     * @throws NoSuchAlgorithmException if no AES key generator is available
     */
    public JweTokenSerializer(RefreshTokensConfig config, OAuth2Client client, String issuer, String audience, String principalField, Clock clock, Duration tokenExpiration) throws KeyLengthException, NoSuchAlgorithmException {
        SecretKey secretKey = createKey(Objects.requireNonNull(config, "config is null"));
        this.jweEncrypter = new AESEncrypter(secretKey);
        this.jweDecrypter = new AESDecrypter(secretKey);
        this.client = Objects.requireNonNull(client, "client is null");
        this.issuer = Objects.requireNonNull(issuer, "issuer is null");
        this.principalField = Objects.requireNonNull(principalField, "principalField is null");
        this.audience = Objects.requireNonNull(audience, "audience is null");
        this.clock = Objects.requireNonNull(clock, "clock is null");
        this.tokenExpiration = Objects.requireNonNull(tokenExpiration, "tokenExpiration is null");
        // The parser must evaluate expiration against the injected clock (not system time) so that
        // serialization and validation agree, and must accept the zstd-compressed payloads we emit.
        this.parser = JwtUtil.newJwtParserBuilder()
                .setClock(() -> Date.from(clock.instant()))
                .requireIssuer(this.issuer)
                .requireAudience(this.audience)
                .setCompressionCodecResolver(JweTokenSerializer::resolveCompressionCodec)
                .build();
    }

    /**
     * Decrypts and parses a token produced by {@link #serialize(TokenPairSerializer.TokenPair)}.
     *
     * <p>Validation failures reported by the JWT parser (issuer/audience mismatch, expiration, etc.)
     * propagate as thrown by the parser.
     *
     * @param token the compact JWE string
     * @return the recovered access token, its expiration and the optional refresh token
     * @throws IllegalArgumentException if the token is not a well-formed JWE or cannot be decrypted
     */
    @Override // io.trino.server.security.oauth2.TokenPairSerializer
    public TokenPairSerializer.TokenPair deserialize(String token) {
        Objects.requireNonNull(token, "token is null");
        try {
            JWEObject jwe = JWEObject.parse(token);
            jwe.decrypt(this.jweDecrypter);
            Claims claims = this.parser.parseClaimsJwt(jwe.getPayload().toString()).getBody();
            return TokenPairSerializer.TokenPair.accessAndRefreshTokens(
                    claims.get(ACCESS_TOKEN_KEY, String.class),
                    claims.get(EXPIRATION_TIME_KEY, Date.class),
                    claims.get(REFRESH_TOKEN_KEY, String.class));
        } catch (ParseException e) {
            throw new IllegalArgumentException("Malformed jwt token", e);
        } catch (JOSEException e) {
            throw new IllegalArgumentException("Decryption failed", e);
        }
    }

    /**
     * Encodes a token pair as an encrypted, compressed JWE string.
     *
     * @param tokenPair the access token (with expiration) and optional refresh token
     * @return the compact JWE serialization
     * @throws IllegalArgumentException if the access token has no claims, or if the principal claim
     *         is absent or {@code null}
     * @throws IllegalStateException if encryption fails
     */
    @Override // io.trino.server.security.oauth2.TokenPairSerializer
    public String serialize(TokenPairSerializer.TokenPair tokenPair) {
        Objects.requireNonNull(tokenPair, "tokenPair is null");
        Map<String, Object> accessTokenClaims = this.client.getClaims(tokenPair.getAccessToken())
                .orElseThrow(() -> new IllegalArgumentException("Claims are missing"));
        // A claim that is present but null is treated the same as a missing one; calling
        // toString() on it would otherwise fail with an uninformative NullPointerException.
        Object principal = accessTokenClaims.get(this.principalField);
        if (principal == null) {
            throw new IllegalArgumentException(String.format("%s field is missing", this.principalField));
        }
        Date expiration = Date.from(this.clock.instant().plusMillis(this.tokenExpiration.toMillis()));
        JwtBuilder jwt = JwtUtil.newJwtBuilder()
                .setExpiration(expiration)
                .claim(this.principalField, principal.toString())
                .setAudience(this.audience)
                .setIssuer(this.issuer)
                .claim(ACCESS_TOKEN_KEY, tokenPair.getAccessToken())
                .claim(EXPIRATION_TIME_KEY, tokenPair.getExpiration())
                .compressWith(COMPRESSION_CODEC);
        tokenPair.getRefreshToken().ifPresentOrElse(
                refreshToken -> jwt.claim(REFRESH_TOKEN_KEY, refreshToken),
                () -> LOG.info("No refresh token has been issued, although coordinator expects one. Please check your IdP whether that is correct behaviour"));
        try {
            JWEObject jwe = new JWEObject(new JWEHeader(ALGORITHM, ENCRYPTION_METHOD), new Payload(jwt.compact()));
            jwe.encrypt(this.jweEncrypter);
            return jwe.serialize();
        } catch (JOSEException e) {
            throw new IllegalStateException("Encryption failed", e);
        }
    }

    /**
     * Returns the configured secret key, or a freshly generated random 256-bit AES key when none is
     * configured. A generated key exists only in this instance, so tokens encrypted with it cannot be
     * decrypted by a serializer built with a different key.
     */
    private static SecretKey createKey(RefreshTokensConfig config) throws NoSuchAlgorithmException {
        SecretKey configuredKey = config.getSecretKey();
        if (configuredKey != null) {
            return configuredKey;
        }
        KeyGenerator generator = KeyGenerator.getInstance("AES");
        generator.init(256);
        return generator.generateKey();
    }

    /**
     * Resolves the compression codec declared in a JWT header: {@code null} for uncompressed tokens,
     * the zstd codec for {@link ZstdCodec#CODEC_NAME}, and an {@link IllegalStateException} for any
     * other algorithm.
     */
    private static CompressionCodec resolveCompressionCodec(Header header) throws CompressionException {
        String algorithm = header.getCompressionAlgorithm();
        if (algorithm == null) {
            return null;
        }
        Preconditions.checkState(algorithm.equals(ZstdCodec.CODEC_NAME), "Unknown codec '%s' used for token compression", algorithm);
        return COMPRESSION_CODEC;
    }
}
