/*
 * Decompiled with CFR 0.152.
 */
package net.luminis.quic.crypto;

import at.favre.lib.crypto.HKDF;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import javax.crypto.AEADBadTagException;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import net.luminis.quic.DecryptionException;
import net.luminis.quic.QuicRuntimeException;
import net.luminis.quic.Role;
import net.luminis.quic.Version;
import net.luminis.quic.log.Logger;
import net.luminis.tls.TrafficSecrets;

/**
 * Packet-protection key material for the endpoint identified by {@code nodeRole} at one
 * encryption level.
 *
 * <p>Holds the AEAD key, the IV and the header-protection key. All three are derived from a TLS
 * traffic secret with {@link #hkdfExpandLabel} using the "quic key", "quic iv" and "quic hp"
 * labels.
 *
 * <p>Application-data key updates are tracked with a counter whose low bit is the key phase.
 * A peer-initiated update is computed tentatively ({@link #checkKeyPhase}) and then either
 * installed ({@link #confirmKeyUpdateIfInProgress}) or discarded
 * ({@link #cancelKeyUpdateIfInProgress}). The counterpart instance for the opposite direction can
 * be registered with {@link #setPeerKeys}; a confirmed update then brings that instance in sync.
 *
 * <p>State-changing operations are {@code synchronized} on the instance.
 */
public class Keys {
    public static final Charset ISO_8859_1 = Charset.forName("ISO-8859-1");
    private final Role nodeRole;
    private final Logger log;
    private final Version quicVersion;
    // Currently installed traffic secret; the base for deriving the next key-update generation.
    private byte[] trafficSecret;
    // Next-generation secret while a peer-initiated key update awaits confirmation.
    private byte[] newApplicationTrafficSecret;
    protected byte[] writeKey;
    protected byte[] newKey;
    protected byte[] writeIV;
    protected byte[] newIV;
    protected byte[] hp;
    protected Cipher hpCipher;
    protected SecretKeySpec writeKeySpec;
    protected SecretKeySpec newWriteKeySpec;
    protected Cipher writeCipher;
    // Number of installed key updates; its low bit is the current key phase.
    private int keyUpdateCounter = 0;
    // True while tentatively computed keys (newKey/newIV) are being used for a possible peer update.
    private boolean possibleKeyUpdateInProgresss = false;
    private volatile Keys peerKeys;

    /**
     * Creates an instance without key material. One of the {@code compute...Keys} methods must
     * be called before the keys can be used.
     */
    public Keys(Version quicVersion, Role nodeRole, Logger log) {
        this.nodeRole = nodeRole;
        this.log = log;
        this.quicVersion = quicVersion;
    }

    /**
     * Creates an instance whose keys are derived from the connection's initial secret.
     *
     * @param initialSecret the initial secret; expanded with the "client in" or "server in"
     *                      label depending on {@code nodeRole}
     */
    public Keys(Version quicVersion, byte[] initialSecret, Role nodeRole, Logger log) {
        this.nodeRole = nodeRole;
        this.log = log;
        this.quicVersion = quicVersion;
        byte[] initialNodeSecret = Keys.hkdfExpandLabel(quicVersion, initialSecret, nodeRole == Role.Client ? "client in" : "server in", "", (short)32);
        log.secret(nodeRole + " initial secret", initialNodeSecret);
        this.computeKeys(initialNodeSecret, true, true);
    }

    /** Derives (0-RTT) keys from the client early traffic secret. */
    public synchronized void computeZeroRttKeys(TrafficSecrets secrets) {
        byte[] earlySecret = secrets.getClientEarlyTrafficSecret();
        this.computeKeys(earlySecret, true, true);
    }

    /** Derives handshake keys from the handshake traffic secret matching {@code nodeRole}. */
    public synchronized void computeHandshakeKeys(TrafficSecrets secrets) {
        if (this.nodeRole == Role.Client) {
            this.trafficSecret = secrets.getClientHandshakeTrafficSecret();
            this.log.secret("ClientHandshakeTrafficSecret: ", this.trafficSecret);
            this.computeKeys(this.trafficSecret, true, true);
        }
        if (this.nodeRole == Role.Server) {
            this.trafficSecret = secrets.getServerHandshakeTrafficSecret();
            this.log.secret("ServerHandshakeTrafficSecret: ", this.trafficSecret);
            this.computeKeys(this.trafficSecret, true, true);
        }
    }

    /** Derives application keys from the application traffic secret matching {@code nodeRole}. */
    public synchronized void computeApplicationKeys(TrafficSecrets secrets) {
        if (this.nodeRole == Role.Client) {
            this.trafficSecret = secrets.getClientApplicationTrafficSecret();
            this.log.secret("ClientApplicationTrafficSecret: ", this.trafficSecret);
            this.computeKeys(this.trafficSecret, true, true);
        }
        if (this.nodeRole == Role.Server) {
            this.trafficSecret = secrets.getServerApplicationTrafficSecret();
            this.log.secret("Got new serverApplicationTrafficSecret from TLS (recomputing secrets): ", this.trafficSecret);
            this.computeKeys(this.trafficSecret, true, true);
        }
    }

    /**
     * Derives the next key-update generation from the current traffic secret ("quic ku" label).
     * The header-protection key is not changed.
     *
     * @param selfInitiated if true, the new key/IV replace the current ones immediately and the
     *                      update counter is advanced; if false, they are stored tentatively in
     *                      {@code newKey}/{@code newIV} until confirmed or cancelled
     */
    public synchronized void computeKeyUpdate(boolean selfInitiated) {
        this.newApplicationTrafficSecret = Keys.hkdfExpandLabel(this.quicVersion, this.trafficSecret, "quic ku", "", (short)32);
        this.log.secret("Updated ApplicationTrafficSecret (" + (selfInitiated ? "self" : "peer") + "): ", this.newApplicationTrafficSecret);
        this.computeKeys(this.newApplicationTrafficSecret, false, selfInitiated);
        if (selfInitiated) {
            this.trafficSecret = this.newApplicationTrafficSecret;
            ++this.keyUpdateCounter;
            this.newApplicationTrafficSecret = null;
        }
    }

    /**
     * Installs the tentatively computed keys of a peer-initiated key update, if one is pending.
     * Afterwards the registered peer keys (if any) are brought to the same update generation.
     */
    public synchronized void confirmKeyUpdateIfInProgress() {
        if (this.possibleKeyUpdateInProgresss) {
            this.log.info("Installing updated keys (initiated by peer)");
            this.trafficSecret = this.newApplicationTrafficSecret;
            this.writeKey = this.newKey;
            this.writeKeySpec = null;
            this.writeIV = this.newIV;
            ++this.keyUpdateCounter;
            this.newApplicationTrafficSecret = null;
            this.possibleKeyUpdateInProgresss = false;
            this.newKey = null;
            this.newWriteKeySpec = null;  // don't keep a stale reference to superseded key material
            this.newIV = null;
            this.checkPeerKeys();
        }
    }

    /** Advances the registered peer keys (if any) when they lag behind this instance's update count. */
    private void checkPeerKeys() {
        Keys peer = this.peerKeys;  // read the volatile field once
        if (peer == null) {
            return;  // no counterpart registered; nothing to keep in sync
        }
        if (peer.keyUpdateCounter < this.keyUpdateCounter) {
            this.log.debug("Keys out of sync; updating keys for peer");
            peer.computeKeyUpdate(true);
        }
    }

    /** Discards the tentatively computed keys of a peer-initiated key update, if one is pending. */
    public synchronized void cancelKeyUpdateIfInProgress() {
        if (this.possibleKeyUpdateInProgresss) {
            this.log.info("Discarding updated keys (initiated by peer)");
            this.newApplicationTrafficSecret = null;
            this.possibleKeyUpdateInProgresss = false;
            this.newKey = null;
            this.newWriteKeySpec = null;
            this.newIV = null;
        }
    }

    /**
     * Derives key, IV and (optionally) header-protection key from {@code secret}.
     *
     * @param includeHP   whether to (re)derive the header-protection key
     * @param replaceKeys true to overwrite the current key/IV, false to store them as the
     *                    tentative next generation ({@code newKey}/{@code newIV})
     */
    private void computeKeys(byte[] secret, boolean includeHP, boolean replaceKeys) {
        String prefix = "quic ";
        byte[] key = Keys.hkdfExpandLabel(this.quicVersion, secret, prefix + "key", "", this.getKeyLength());
        if (replaceKeys) {
            this.writeKey = key;
            this.writeKeySpec = null;
        } else {
            this.newKey = key;
            this.newWriteKeySpec = null;
        }
        this.log.secret(this.nodeRole + " key", key);
        byte[] iv = Keys.hkdfExpandLabel(this.quicVersion, secret, prefix + "iv", "", (short)12);
        if (replaceKeys) {
            this.writeIV = iv;
        } else {
            this.newIV = iv;
        }
        this.log.secret(this.nodeRole + " iv", iv);
        if (includeHP) {
            this.hp = Keys.hkdfExpandLabel(this.quicVersion, secret, prefix + "hp", "", this.getKeyLength());
            // The cached header-protection cipher was initialised with the old hp key.
            this.hpCipher = null;
            this.log.secret(this.nodeRole + " hp", this.hp);
        }
    }

    /** Returns the AEAD/header-protection key length in bytes; subclasses may override. */
    protected short getKeyLength() {
        return 16;
    }

    /**
     * HKDF-Expand-Label (HMAC-SHA256) with the "tls13 " label prefix.
     *
     * <p>The encoded HkdfLabel is: 2-byte output length, 1-byte length of ("tls13 " + label),
     * the prefixed label, 1-byte context length, context. Strings are encoded as ISO-8859-1.
     *
     * @param quicVersion not used by this implementation
     * @param length      number of output bytes
     */
    static byte[] hkdfExpandLabel(Version quicVersion, byte[] secret, String label, String context, short length) {
        byte[] prefix = "tls13 ".getBytes(ISO_8859_1);
        // Encode once with the same charset for both buffer sizing and length fields.
        byte[] labelBytes = label.getBytes(ISO_8859_1);
        byte[] contextBytes = context.getBytes(ISO_8859_1);
        ByteBuffer hkdfLabel = ByteBuffer.allocate(2 + 1 + prefix.length + labelBytes.length + 1 + contextBytes.length);
        hkdfLabel.putShort(length);
        hkdfLabel.put((byte)(prefix.length + labelBytes.length));
        hkdfLabel.put(prefix);
        hkdfLabel.put(labelBytes);
        hkdfLabel.put((byte)contextBytes.length);
        hkdfLabel.put(contextBytes);
        HKDF hkdf = HKDF.fromHmacSha256();
        return hkdf.expand(secret, hkdfLabel.array(), length);
    }

    public byte[] getTrafficSecret() {
        return this.trafficSecret;
    }

    /** Returns the tentative key while a possible peer key update is in progress, else the current key. */
    public byte[] getWriteKey() {
        if (this.possibleKeyUpdateInProgresss) {
            return this.newKey;
        }
        return this.writeKey;
    }

    /** Returns the tentative IV while a possible peer key update is in progress, else the current IV. */
    public byte[] getWriteIV() {
        if (this.possibleKeyUpdateInProgresss) {
            return this.newIV;
        }
        return this.writeIV;
    }

    public byte[] getHp() {
        return this.hp;
    }

    /**
     * Returns the (lazily created) AES/ECB cipher initialised for encryption with the
     * header-protection key.
     */
    public Cipher getHeaderProtectionCipher() {
        if (this.hpCipher == null) {
            try {
                // Only publish the cipher once it is fully initialised, so a failed init
                // does not leave an unusable cipher cached.
                Cipher cipher = Cipher.getInstance("AES/ECB/NoPadding");
                SecretKeySpec keySpec = new SecretKeySpec(this.getHp(), "AES");
                cipher.init(Cipher.ENCRYPT_MODE, keySpec);
                this.hpCipher = cipher;
            }
            catch (NoSuchAlgorithmException | NoSuchPaddingException e) {
                throw new QuicRuntimeException(e);
            }
            catch (InvalidKeyException e) {
                throw new RuntimeException(e);
            }
        }
        return this.hpCipher;
    }

    /** Returns the AES key spec matching {@link #getWriteKey()}, creating and caching it on demand. */
    public SecretKeySpec getWriteKeySpec() {
        if (this.possibleKeyUpdateInProgresss) {
            if (this.newWriteKeySpec == null) {
                this.newWriteKeySpec = new SecretKeySpec(this.newKey, "AES");
            }
            return this.newWriteKeySpec;
        }
        if (this.writeKeySpec == null) {
            this.writeKeySpec = new SecretKeySpec(this.writeKey, "AES");
        }
        return this.writeKeySpec;
    }

    /** Returns the lazily created AES/GCM cipher; it is (re)initialised on every use. */
    public Cipher getWriteCipher() {
        if (this.writeCipher == null) {
            try {
                String AES_GCM_NOPADDING = "AES/GCM/NoPadding";
                this.writeCipher = Cipher.getInstance(AES_GCM_NOPADDING);
            }
            catch (NoSuchAlgorithmException | NoSuchPaddingException e) {
                throw new QuicRuntimeException(e);
            }
        }
        return this.writeCipher;
    }

    /**
     * AES-GCM encrypts {@code message} with a 128-bit tag.
     *
     * @return ciphertext followed by the authentication tag
     */
    public byte[] aeadEncrypt(byte[] associatedData, byte[] message, byte[] nonce) {
        Cipher aeadCipher = this.getWriteCipher();
        SecretKeySpec secretKey = this.getWriteKeySpec();
        try {
            GCMParameterSpec parameterSpec = new GCMParameterSpec(128, nonce);
            aeadCipher.init(Cipher.ENCRYPT_MODE, secretKey, parameterSpec);
            aeadCipher.updateAAD(associatedData);
            return aeadCipher.doFinal(message);
        }
        catch (InvalidAlgorithmParameterException | InvalidKeyException | BadPaddingException | IllegalBlockSizeException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * AES-GCM decrypts and authenticates {@code message} (ciphertext plus 128-bit tag).
     *
     * @throws DecryptionException if authentication fails
     */
    public byte[] aeadDecrypt(byte[] associatedData, byte[] message, byte[] nonce) throws DecryptionException {
        SecretKeySpec secretKey = this.getWriteKeySpec();
        Cipher aeadCipher = this.getWriteCipher();
        try {
            GCMParameterSpec parameterSpec = new GCMParameterSpec(128, nonce);
            aeadCipher.init(Cipher.DECRYPT_MODE, secretKey, parameterSpec);
            aeadCipher.updateAAD(associatedData);
            return aeadCipher.doFinal(message);
        }
        catch (AEADBadTagException decryptError) {
            throw new DecryptionException();
        }
        catch (InvalidAlgorithmParameterException | InvalidKeyException | BadPaddingException | IllegalBlockSizeException e) {
            throw new RuntimeException(e);
        }
    }

    /** Encrypts {@code sample} with the header-protection cipher to produce the mask. */
    public byte[] createHeaderProtectionMask(byte[] sample) {
        Cipher hpCipher = this.getHeaderProtectionCipher();
        try {
            return hpCipher.doFinal(sample);
        }
        catch (BadPaddingException | IllegalBlockSizeException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the current key phase (low bit of the update counter). */
    public short getKeyPhase() {
        return (short)(this.keyUpdateCounter % 2);
    }

    /**
     * Compares a received key-phase bit with the current phase. On mismatch, the next key
     * generation is computed (once) and used tentatively until the update is confirmed or
     * cancelled. Synchronized because it mutates the same state as the confirm/cancel methods.
     */
    public synchronized void checkKeyPhase(short keyPhaseBit) {
        if (this.keyUpdateCounter % 2 != keyPhaseBit) {
            if (this.newKey == null) {
                this.computeKeyUpdate(false);
                this.log.secret("Computed new (updated) key", this.newKey);
                this.log.secret("Computed new (updated) iv", this.newIV);
            }
            this.log.info("Received key phase does not match current => possible key update in progress");
            this.possibleKeyUpdateInProgresss = true;
        }
    }

    /** Registers the counterpart keys to keep in sync after a confirmed peer-initiated update. */
    void setPeerKeys(Keys peerKeys) {
        this.peerKeys = peerKeys;
    }
}

