/*
 * Decompiled with CFR 0.152.
 */
package org.reaktivity.nukleus.oauth.internal;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArrayList;
import org.agrona.LangUtil;
import org.jose4j.jwk.JsonWebKey;
import org.jose4j.jwk.JsonWebKeySet;
import org.jose4j.jws.JsonWebSignature;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.consumer.InvalidJwtException;
import org.jose4j.lang.JoseException;
import org.reaktivity.nukleus.internal.CopyOnWriteHashMap;

/**
 * Registry that maps OAuth realms, and the scopes requested within them, onto the bits of a
 * 64-bit authorization value.
 *
 * <p>Layout of an authorization value, as established by the shifts and mask used below:
 * the upper {@code MAX_REALMS} (16) bits carry a single realm bit identifying one
 * issuer/audience combination of a realm, and the lower {@code MAX_SCOPES} (48) bits carry
 * one bit per scope name assigned within that realm. The value {@code 0} means
 * "no authorization".
 *
 * <p>The instance also exposes the JSON Web Keys it was constructed with, indexed by key id.
 */
public class OAuthRealms {
    private static final List<String> EMPTY_STRING_LIST = Collections.emptyList();
    private static final String ISSUER_CLAIM = "iss";
    private static final String AUDIENCE_CLAIM = "aud";
    private static final String SCOPE_CLAIM = "scope";

    /** Authorization value returned when nothing could be resolved or looked up. */
    private static final long NO_AUTHORIZATION = 0L;

    /** Number of realm bits available in the upper part of an authorization value. */
    private static final int MAX_REALMS = 16;

    /** Selects the realm bits (upper 16 bits) of an authorization value. */
    private static final long REALM_MASK = 0xFFFF_0000_0000_0000L;

    private final Map<String, OAuthRealm> realmsByName = new CopyOnWriteHashMap<>();
    private final Map<String, JsonWebKey> keysByKid;

    // Index of the next unassigned realm bit; only ever incremented.
    // NOTE(review): bits are not reclaimed by unresolve(), so at most MAX_REALMS realm
    // entries can be created over the lifetime of an instance - confirm this is intended.
    private int nextRealmBit = 0;

    /**
     * Creates a registry without any known keys.
     */
    public OAuthRealms() {
        this(Collections.emptyMap());
    }

    /**
     * Creates a registry backed by the given keys.
     *
     * @param keysByKid keys indexed by key id; the map is stored as-is, not copied
     */
    public OAuthRealms(Map<String, JsonWebKey> keysByKid) {
        this.keysByKid = keysByKid;
    }

    /**
     * Resolves (creating it if necessary) the authorization value for a realm restricted to an
     * issuer and audience, with one scope bit per requested scope name.
     *
     * @param realmName    name of the realm
     * @param issuerName   required issuer, or {@code null} to accept any issuer
     * @param audienceName required audience, or {@code null} to accept any audience
     * @param scopeNames   scope names to assign bits for; {@code null} is treated as empty
     * @return the realm bit combined with the scope bits, or {@code 0} when all realm bits have
     *         already been handed out or the realm lacks capacity for the requested scopes
     */
    public long resolve(String realmName, String issuerName, String audienceName, List<String> scopeNames) {
        long authorization = NO_AUTHORIZATION;
        if (this.nextRealmBit < MAX_REALMS) {
            OAuthRealm realm = this.realmsByName.computeIfAbsent(realmName, OAuthRealm::new);
            authorization = realm.resolve(issuerName, audienceName, scopeNames);
        }
        return authorization;
    }

    /**
     * Resolves the authorization value for a realm with no issuer or audience restriction and
     * no scopes.
     *
     * @param realmName name of the realm
     * @return the authorization value, or {@code 0} when it could not be resolved
     */
    public long resolve(String realmName) {
        return this.resolve(realmName, null, null, EMPTY_STRING_LIST);
    }

    /**
     * Computes the authorization value granted by a signed token.
     *
     * <p>The realm is selected by the token's key id header value; the {@code iss}, {@code aud}
     * and {@code scope} claims of the payload are then matched against the realm's entries.
     *
     * @param verified the token; presumably already signature-verified by the caller
     * @return the realm bit combined with the bits of all known scopes named by the token, or
     *         {@code 0} when no realm matches or the payload cannot be read or parsed
     */
    public long lookup(JsonWebSignature verified) {
        long authorization = NO_AUTHORIZATION;
        OAuthRealm realm = this.realmsByName.get(verified.getKeyIdHeaderValue());
        if (realm != null) {
            try {
                JwtClaims claims = JwtClaims.parse(verified.getPayload());
                Object issuerClaim = claims.getClaimValue(ISSUER_CLAIM);
                String issuerName = issuerClaim != null ? issuerClaim.toString() : null;
                List<String> audienceNames = toAudienceNames(claims.getClaimValue(AUDIENCE_CLAIM));
                List<String> scopeNames = toScopeNames(claims.getClaimValue(SCOPE_CLAIM));
                authorization = realm.lookup(issuerName, audienceNames, scopeNames);
            }
            catch (InvalidJwtException | JoseException ignored) {
                // Best effort by design: a token whose payload cannot be read or parsed
                // simply grants no authorization.
            }
        }
        return authorization;
    }

    /**
     * Removes the realm entry identified by the realm bits of the given authorization value and
     * drops every realm that no longer has any entries.
     *
     * @param authorization authorization value; only its realm bits are considered
     * @return {@code true} when an entry was removed and at most one realm bit was set
     */
    public boolean unresolve(long authorization) {
        long realmId = authorization & REALM_MASK;
        Collection<OAuthRealm> realms = this.realmsByName.values();

        // Stop at the first realm that owned the id, exactly as a short-circuiting search would.
        boolean unresolved = false;
        for (OAuthRealm realm : realms) {
            if (realm.unresolve(realmId)) {
                unresolved = true;
                break;
            }
        }

        realms.removeIf(OAuthRealm::isEmpty);
        return Long.bitCount(realmId) <= 1 && unresolved;
    }

    /**
     * Returns the key registered under the given key id.
     *
     * @param kid key id
     * @return the key, or {@code null} when the backing map has no mapping for {@code kid}
     */
    public JsonWebKey lookupKey(String kid) {
        return this.keysByKid.get(kid);
    }

    /**
     * Reads a JWK set from a UTF-8 encoded file and indexes its keys by key id.
     *
     * @param keyFile path of the JWK set file
     * @return an unmodifiable map of keys by key id; empty when the file does not exist
     * @throws IllegalArgumentException when a key lacks a kid or alg, or a kid is duplicated
     */
    static Map<String, JsonWebKey> parseKeyMap(Path keyFile) {
        Map<String, JsonWebKey> keysByKid = Collections.emptyMap();
        if (Files.exists(keyFile)) {
            try {
                byte[] rawKeys = Files.readAllBytes(keyFile);
                keysByKid = toKeyMap(new String(rawKeys, StandardCharsets.UTF_8));
            }
            catch (IOException ex) {
                // Propagates the I/O failure to the caller without declaring it.
                LangUtil.rethrowUnchecked(ex);
            }
        }
        return keysByKid;
    }

    /**
     * Parses a JWK set document and indexes its keys by key id, preserving document order.
     *
     * @param keysAsJwkSet JWK set as JSON text
     * @return an unmodifiable map of keys by key id
     * @throws IllegalArgumentException when a key lacks a kid or alg, or a kid is duplicated
     */
    private static Map<String, JsonWebKey> toKeyMap(String keysAsJwkSet) {
        Map<String, JsonWebKey> keysByKid = Collections.emptyMap();
        try {
            JsonWebKeySet keys = new JsonWebKeySet(keysAsJwkSet);
            Map<String, JsonWebKey> parsedKeys = new LinkedHashMap<>();
            for (JsonWebKey key : keys.getJsonWebKeys()) {
                String kid = key.getKeyId();
                if (kid == null) {
                    throw new IllegalArgumentException("Key without kid");
                }
                if (key.getAlgorithm() == null) {
                    throw new IllegalArgumentException("Key without alg");
                }
                if (parsedKeys.putIfAbsent(kid, key) != null) {
                    throw new IllegalArgumentException("Key with duplicate kid");
                }
            }
            keysByKid = Collections.unmodifiableMap(parsedKeys);
        }
        catch (JoseException ex) {
            // Propagates the parse failure to the caller without declaring it.
            LangUtil.rethrowUnchecked(ex);
        }
        return keysByKid;
    }

    /**
     * Normalizes an {@code aud} claim value: a list is used as-is, a single string becomes a
     * one-element list, anything else (including {@code null}) becomes an empty list.
     */
    @SuppressWarnings("unchecked")
    private static List<String> toAudienceNames(Object audienceClaim) {
        if (audienceClaim instanceof List) {
            // Unchecked by necessity; the list is only queried with contains(), which is safe
            // for any element type.
            return (List<String>) audienceClaim;
        }
        if (audienceClaim instanceof String) {
            return Collections.singletonList((String) audienceClaim);
        }
        return EMPTY_STRING_LIST;
    }

    /**
     * Splits a {@code scope} claim value on runs of whitespace; {@code null} yields an empty list.
     */
    private static List<String> toScopeNames(Object scopeClaim) {
        if (scopeClaim == null) {
            return EMPTY_STRING_LIST;
        }
        return Arrays.asList(scopeClaim.toString().split("\\s+"));
    }

    /**
     * A named realm holding one entry per distinct issuer/audience combination. Scope bits are
     * numbered by a counter shared by all entries of the realm.
     */
    private final class OAuthRealm {
        /** Number of scope bits available in the lower part of an authorization value. */
        private static final int MAX_SCOPES = 48;

        private final List<OAuthRealmInfo> realmInfos = new CopyOnWriteArrayList<>();
        private final String realmName;

        // Index of the next unassigned scope bit within this realm.
        private int nextScopeBit;

        private OAuthRealm(String realmName) {
            assert OAuthRealms.this.nextRealmBit < MAX_REALMS;
            this.realmName = realmName;
        }

        /**
         * Finds or creates the entry for the given issuer and audience and assigns scope bits.
         *
         * @return the entry's realm bit combined with the scope bits, or {@code 0} when the
         *         requested scopes do not fit
         */
        private long resolve(String issuerName, String audienceName, List<String> scopeNames) {
            assert OAuthRealms.this.nextRealmBit < MAX_REALMS;
            List<String> requestedScopes = scopeNames != null ? scopeNames : EMPTY_STRING_LIST;
            long authorization = NO_AUTHORIZATION;

            // Conservative capacity check: counts every requested name, even ones that already
            // have a bit. NOTE(review): the strict '<' means the last scope bit is never
            // assigned - confirm whether '<=' was intended.
            if (this.nextScopeBit + requestedScopes.size() < MAX_SCOPES) {
                OAuthRealmInfo realmInfo = this.realmInfos.stream()
                    .filter(info -> info.containsClaims(issuerName, audienceName))
                    .findFirst()
                    .orElseGet(() -> newRealmInfo(issuerName, audienceName));
                authorization = realmInfo.realmId;
                for (String scopeName : requestedScopes) {
                    authorization |= realmInfo.supplyScopeBit(scopeName);
                }
            }
            return authorization;
        }

        /**
         * Finds the first entry whose required claims are satisfied and combines its realm bit
         * with the bits of the scopes it knows; unknown scopes contribute nothing.
         *
         * @return the authorization value, or {@code 0} when no entry matches
         */
        private long lookup(String issuerName, List<String> audienceNames, List<String> scopeNames) {
            OAuthRealmInfo realmInfo = this.realmInfos.stream()
                .filter(info -> info.containsClaims(issuerName, audienceNames))
                .findFirst()
                .orElse(null);
            if (realmInfo == null) {
                return NO_AUTHORIZATION;
            }

            long authorization = realmInfo.realmId;
            for (String scopeName : scopeNames) {
                authorization |= realmInfo.scopeBit(scopeName);
            }
            return authorization;
        }

        /**
         * Removes every entry with the given realm id.
         *
         * @return {@code true} when at least one entry was removed
         */
        private boolean unresolve(long realmId) {
            return this.realmInfos.removeIf(info -> info.realmId == realmId);
        }

        private boolean isEmpty() {
            return this.realmInfos.isEmpty();
        }

        /** Creates an entry, consuming the next realm bit of the enclosing registry. */
        private OAuthRealmInfo newRealmInfo(String issuerName, String audienceName) {
            int realmBit = OAuthRealms.this.nextRealmBit++;
            // Two separate shifts on purpose: an out-of-range realm bit then yields 0 instead
            // of wrapping around into the scope bits.
            long realmId = (1L << realmBit) << MAX_SCOPES;
            OAuthRealmInfo realmInfo = new OAuthRealmInfo(realmId, issuerName, audienceName);
            this.realmInfos.add(realmInfo);
            return realmInfo;
        }

        @Override
        public String toString() {
            return String.format("Realm name: \"%s\",\tRealm info: %s\n", this.realmName, this.realmInfos);
        }

        /**
         * One issuer/audience combination of a realm: its realm bit, the claims a token must
         * carry, and the scope bits assigned so far.
         */
        private final class OAuthRealmInfo {
            private final Map<String, Long> scopeBitsByName = new CopyOnWriteHashMap<>();
            private final long realmId;
            private final Claims requiredClaims;

            private OAuthRealmInfo(long realmId, String issuerName, String audienceName) {
                this.realmId = realmId;
                this.requiredClaims = new Claims(issuerName, audienceName);
            }

            /** Returns the bit assigned to the scope, or {@code 0} when none was assigned. */
            private long scopeBit(String scopeName) {
                return this.scopeBitsByName.getOrDefault(scopeName, 0L);
            }

            /** Returns the bit assigned to the scope, assigning the next free bit if needed. */
            private long supplyScopeBit(String scopeName) {
                return this.scopeBitsByName.computeIfAbsent(scopeName, this::assignScopeBit);
            }

            private boolean containsClaims(String issuerName, String audienceName) {
                return this.requiredClaims.containsClaims(issuerName, audienceName);
            }

            private boolean containsClaims(String issuerName, List<String> audienceNames) {
                return this.requiredClaims.containsClaims(issuerName, audienceNames);
            }

            /** Consumes the realm's next scope bit; the scope name itself is not used. */
            private long assignScopeBit(String scopeName) {
                assert OAuthRealm.this.nextScopeBit < MAX_SCOPES;
                return 1L << OAuthRealm.this.nextScopeBit++;
            }

            @Override
            public String toString() {
                return String.format("Info: realm id=%d, claims=[%s], scope bits=%s", this.realmId, this.requiredClaims, this.scopeBitsByName);
            }

            /**
             * Issuer and audience required by an entry; a {@code null} requirement matches
             * anything.
             */
            private final class Claims {
                final String issuerName;
                final String audienceName;

                private Claims(String issuerName, String audienceName) {
                    this.issuerName = issuerName;
                    this.audienceName = audienceName;
                }

                /**
                 * Matches when the issuer requirement is met and the required audience, if any,
                 * is one of the given audiences.
                 */
                private boolean containsClaims(String issuerName, List<String> audienceNames) {
                    return matchesIssuer(issuerName)
                        && (this.audienceName == null || audienceNames.contains(this.audienceName));
                }

                /**
                 * Matches when the issuer requirement is met and the required audience, if any,
                 * equals the given audience.
                 */
                private boolean containsClaims(String issuerName, String audienceName) {
                    return matchesIssuer(issuerName)
                        && (this.audienceName == null || Objects.equals(this.audienceName, audienceName));
                }

                /** True when no issuer is required or the given issuer equals the required one. */
                private boolean matchesIssuer(String issuerName) {
                    return this.issuerName == null || Objects.equals(this.issuerName, issuerName);
                }

                @Override
                public String toString() {
                    return String.format("issuer=\"%s\", audience=\"%s\"", this.issuerName, this.audienceName);
                }
            }
        }
    }
}

