/*
 * Decompiled with CFR 0.152.
 */
package net.luminis.tls;

import at.favre.lib.crypto.HKDF;
import at.favre.lib.crypto.HkdfMacFactory;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.XECPublicKey;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.KeyAgreement;
import javax.crypto.Mac;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import net.luminis.tls.Logger;
import net.luminis.tls.TlsConstants;
import net.luminis.tls.TranscriptHash;
import net.luminis.tls.util.ByteUtils;

/**
 * Client-side TLS 1.3 key schedule and record protection state.
 *
 * <p>Derives the early, handshake and application traffic secrets (plus the binder key and the
 * resumption master secret) using HKDF over HmacSHA256/SHA-256. It expands the traffic secrets into
 * 16-byte AES-GCM keys and 12-byte IVs and uses them to encrypt outgoing and decrypt incoming record
 * payloads. Read and write record sequence numbers are tracked here and XOR-ed into the traffic IVs
 * to build the per-record nonces.
 *
 * <p>Instances are not synchronized: the {@link MessageDigest} instance and the record counters are
 * plain mutable fields.
 *
 * <p>NOTE(review): most derivation steps log secret key material via {@code Logger.debug}; make sure
 * debug logging is never enabled where secrets must stay confidential.
 */
public class TlsState {
    private static final Charset ISO_8859_1 = Charset.forName("ISO-8859-1");
    /** Prefix prepended to every label passed to HKDF-Expand-Label. */
    private static final String LABEL_PREFIX = "tls13 ";
    /** JCA transformation used for record protection. */
    private static final String AES_GCM_NOPADDING = "AES/GCM/NoPadding";
    /** JCA MAC algorithm backing HKDF and the PSK binder. */
    private static final String MAC_ALGORITHM = "HmacSHA256";
    /** Length in bytes of the AES-GCM authentication tag at the end of each protected record. */
    private static final short AUTHENTICATION_TAG_LENGTH = 16;
    /** Length in bytes of a traffic key. */
    private static final short KEY_LENGTH = 16;
    /** Output length in bytes of SHA-256, and therefore of every secret in the key schedule. */
    private static final short HASH_LENGTH = 32;
    /** Length in bytes of a traffic IV and of the per-record AEAD nonce. */
    private static final short IV_LENGTH = 12;

    private final MessageDigest hashFunction;
    private final HKDF hkdf;
    /** SHA-256 of the empty input; context for the "derived" and "res binder" derivations. */
    private final byte[] emptyHash;
    private boolean pskSelected;
    private PublicKey serverSharedKey;
    private PrivateKey clientPrivateKey;
    /** Pre-shared key given at construction, or {@code null} when none was supplied. */
    private final byte[] psk;
    private byte[] earlySecret;
    private byte[] binderKey;
    private byte[] resumptionMasterSecret;
    private byte[] serverHandshakeTrafficSecret;
    private byte[] clientEarlyTrafficSecret;
    private byte[] clientHandshakeTrafficSecret;
    private byte[] handshakeSecret;
    private byte[] clientApplicationTrafficSecret;
    private byte[] serverApplicationTrafficSecret;
    /** Current key/IV used to decrypt incoming (server) records. */
    private byte[] serverKey;
    private byte[] serverIv;
    /** Current key/IV used to encrypt outgoing (client) records. */
    private byte[] clientKey;
    private byte[] clientIv;
    /** Sequence number of the next incoming record to decrypt. */
    private int serverRecordCount = 0;
    /** Sequence number of the next outgoing record to encrypt. */
    private int clientRecordCount = 0;
    private final TranscriptHash transcriptHash;
    /** (EC)DHE shared secret; set by {@link #computeSharedSecret()}. */
    private byte[] sharedSecret;

    /**
     * Creates the state and immediately computes the early secret and binder key.
     *
     * @param transcriptHash source of the handshake transcript hashes used for secret derivation
     * @param psk pre-shared key used as input keying material for the early secret, or
     *     {@code null} to use an all-zero value of hash length instead
     * @throws RuntimeException if SHA-256 is not available on this platform
     */
    public TlsState(TranscriptHash transcriptHash, byte[] psk) {
        this.psk = psk;
        this.transcriptHash = transcriptHash;
        String hashAlgorithm = "SHA-256";
        try {
            this.hashFunction = MessageDigest.getInstance(hashAlgorithm);
        }
        catch (NoSuchAlgorithmException e) {
            throw new RuntimeException("Missing " + hashAlgorithm + " support", e);
        }
        this.hkdf = HKDF.from(new HkdfMacFactory.Default(MAC_ALGORITHM, null));
        this.emptyHash = this.hashFunction.digest(new byte[0]);
        Logger.debug("Empty hash: " + ByteUtils.bytesToHex(this.emptyHash));
        // Without a PSK the early secret is extracted from a zero-filled input of hash length.
        this.computeEarlySecret(psk != null ? psk : new byte[HASH_LENGTH]);
    }

    /**
     * Creates the state without a pre-shared key.
     *
     * @param transcriptHash source of the handshake transcript hashes used for secret derivation
     */
    public TlsState(TranscriptHash transcriptHash) {
        this(transcriptHash, null);
    }

    /**
     * Extracts the early secret from {@code ikm} using an all-zero salt. It then derives the binder
     * key ("res binder") from that secret.
     *
     * @param ikm input keying material: the PSK, or zeros when no PSK is in use
     * @return the new early secret, which is also stored in {@code earlySecret}
     */
    private byte[] computeEarlySecret(byte[] ikm) {
        byte[] zeroSalt = new byte[HASH_LENGTH];
        this.earlySecret = this.hkdf.extract(zeroSalt, ikm);
        Logger.debug("Early secret: " + ByteUtils.bytesToHex(this.earlySecret));
        this.binderKey = this.hkdfExpandLabel(this.earlySecret, "res binder", this.emptyHash, HASH_LENGTH);
        Logger.debug("Binder key: " + ByteUtils.bytesToHex(this.binderKey));
        return this.earlySecret;
    }

    /**
     * Computes the PSK binder value over a partial ClientHello.
     *
     * <p>The method hashes {@code partialClientHello} with SHA-256. It derives a finished key from
     * the binder key and returns the HMAC of the hash under that key.
     *
     * @param partialClientHello the ClientHello bytes the binder must cover
     * @return the binder value (HASH_LENGTH bytes)
     * @throws RuntimeException if the HMAC algorithm is unavailable or rejects the derived key
     */
    public byte[] computePskBinder(byte[] partialClientHello) {
        try {
            this.hashFunction.reset();
            byte[] hash = this.hashFunction.digest(partialClientHello);
            byte[] finishedKey = this.hkdfExpandLabel(this.binderKey, "finished", "", HASH_LENGTH);
            Mac hmacAlgorithm = Mac.getInstance(MAC_ALGORITHM);
            hmacAlgorithm.init(new SecretKeySpec(finishedKey, MAC_ALGORITHM));
            return hmacAlgorithm.doFinal(hash);
        }
        catch (NoSuchAlgorithmException e) {
            throw new RuntimeException("Missing " + MAC_ALGORITHM + " support", e);
        }
        catch (InvalidKeyException e) {
            throw new RuntimeException("Invalid binder finished key: " + e, e);
        }
    }

    /**
     * Runs the key agreement between the own private key and the peer's public key and stores the
     * resulting shared secret.
     *
     * <p>ECDH is used for {@link ECPublicKey} peers and XDH for {@link XECPublicKey} peers. Any
     * other peer key type, including an unset one, is rejected.
     *
     * @throws RuntimeException if the key type is unsupported or the key agreement fails
     */
    public void computeSharedSecret() {
        try {
            KeyAgreement keyAgreement;
            if (this.serverSharedKey instanceof ECPublicKey) {
                keyAgreement = KeyAgreement.getInstance("ECDH");
            } else if (this.serverSharedKey instanceof XECPublicKey) {
                keyAgreement = KeyAgreement.getInstance("XDH");
            } else {
                throw new RuntimeException("Unsupported key type");
            }
            keyAgreement.init(this.clientPrivateKey);
            keyAgreement.doPhase(this.serverSharedKey, true);
            this.sharedSecret = keyAgreement.generateSecret();
            Logger.debug("Shared key: " + ByteUtils.bytesToHex(this.sharedSecret));
        }
        catch (InvalidKeyException | NoSuchAlgorithmException e) {
            throw new RuntimeException("Unsupported crypto: " + e, e);
        }
    }

    /**
     * Derives the client early traffic secret ("c e traffic") from the early secret and the
     * transcript hash up to the ClientHello.
     */
    public void computeEarlyTrafficSecret() {
        byte[] clientHelloHash = this.transcriptHash.getHash(TlsConstants.HandshakeType.client_hello);
        this.clientEarlyTrafficSecret = this.hkdfExpandLabel(this.earlySecret, "c e traffic", clientHelloHash, HASH_LENGTH);
    }

    /**
     * Derives the handshake secret and both handshake traffic secrets, then installs the derived
     * handshake keys and IVs for record protection in both directions.
     *
     * <p>The handshake secret is extracted from the shared secret, so {@link #computeSharedSecret()}
     * must have run first. The traffic secrets use the transcript hash up to the ServerHello.
     * Record sequence numbers are not reset here.
     */
    public void computeHandshakeSecrets() {
        byte[] derivedSecret = this.hkdfExpandLabel(this.earlySecret, "derived", this.emptyHash, HASH_LENGTH);
        Logger.debug("Derived secret: " + ByteUtils.bytesToHex(derivedSecret));
        this.handshakeSecret = this.hkdf.extract(derivedSecret, this.sharedSecret);
        Logger.debug("Handshake secret: " + ByteUtils.bytesToHex(this.handshakeSecret));
        byte[] handshakeHash = this.transcriptHash.getHash(TlsConstants.HandshakeType.server_hello);
        this.clientHandshakeTrafficSecret = this.hkdfExpandLabel(this.handshakeSecret, "c hs traffic", handshakeHash, HASH_LENGTH);
        Logger.debug("Client handshake traffic secret: " + ByteUtils.bytesToHex(this.clientHandshakeTrafficSecret));
        this.serverHandshakeTrafficSecret = this.hkdfExpandLabel(this.handshakeSecret, "s hs traffic", handshakeHash, HASH_LENGTH);
        Logger.debug("Server handshake traffic secret: " + ByteUtils.bytesToHex(this.serverHandshakeTrafficSecret));
        byte[] clientHandshakeKey = this.hkdfExpandLabel(this.clientHandshakeTrafficSecret, "key", "", KEY_LENGTH);
        Logger.debug("Client handshake key: " + ByteUtils.bytesToHex(clientHandshakeKey));
        this.clientKey = clientHandshakeKey;
        byte[] serverHandshakeKey = this.hkdfExpandLabel(this.serverHandshakeTrafficSecret, "key", "", KEY_LENGTH);
        Logger.debug("Server handshake key: " + ByteUtils.bytesToHex(serverHandshakeKey));
        this.serverKey = serverHandshakeKey;
        byte[] clientHandshakeIV = this.hkdfExpandLabel(this.clientHandshakeTrafficSecret, "iv", "", IV_LENGTH);
        Logger.debug("Client handshake iv: " + ByteUtils.bytesToHex(clientHandshakeIV));
        this.clientIv = clientHandshakeIV;
        byte[] serverHandshakeIV = this.hkdfExpandLabel(this.serverHandshakeTrafficSecret, "iv", "", IV_LENGTH);
        Logger.debug("Server handshake iv: " + ByteUtils.bytesToHex(serverHandshakeIV));
        this.serverIv = serverHandshakeIV;
    }

    /**
     * Derives and installs the application traffic keys from the stored handshake secret. Both
     * record sequence numbers are then reset to zero, since new keys start a new sequence.
     */
    public void computeApplicationSecrets() {
        this.computeApplicationSecrets(this.handshakeSecret);
        this.serverRecordCount = 0;
        this.clientRecordCount = 0;
    }

    /**
     * Derives the master secret from {@code handshakeSecret}. From it, this derives the client and
     * server application traffic secrets (transcript up to server Finished) and the resumption
     * master secret (transcript up to client Finished). It then installs the resulting keys and IVs.
     * Record sequence numbers are left unchanged.
     *
     * @param handshakeSecret the handshake secret to derive from
     */
    void computeApplicationSecrets(byte[] handshakeSecret) {
        byte[] serverFinishedHash = this.transcriptHash.getServerHash(TlsConstants.HandshakeType.finished);
        byte[] clientFinishedHash = this.transcriptHash.getClientHash(TlsConstants.HandshakeType.finished);
        byte[] derivedSecret = this.hkdfExpandLabel(handshakeSecret, "derived", this.emptyHash, HASH_LENGTH);
        Logger.debug("Derived secret: " + ByteUtils.bytesToHex(derivedSecret));
        byte[] zeroKey = new byte[HASH_LENGTH];
        byte[] masterSecret = this.hkdf.extract(derivedSecret, zeroKey);
        Logger.debug("Master secret: " + ByteUtils.bytesToHex(masterSecret));
        // Both application traffic secrets use the transcript through the server Finished message.
        this.clientApplicationTrafficSecret = this.hkdfExpandLabel(masterSecret, "c ap traffic", serverFinishedHash, HASH_LENGTH);
        Logger.debug("Client application traffic secret: " + ByteUtils.bytesToHex(this.clientApplicationTrafficSecret));
        this.serverApplicationTrafficSecret = this.hkdfExpandLabel(masterSecret, "s ap traffic", serverFinishedHash, HASH_LENGTH);
        Logger.debug("Server application traffic secret: " + ByteUtils.bytesToHex(this.serverApplicationTrafficSecret));
        this.resumptionMasterSecret = this.hkdfExpandLabel(masterSecret, "res master", clientFinishedHash, HASH_LENGTH);
        Logger.debug("Resumption master secret: " + ByteUtils.bytesToHex(this.resumptionMasterSecret));
        byte[] clientApplicationKey = this.hkdfExpandLabel(this.clientApplicationTrafficSecret, "key", "", KEY_LENGTH);
        Logger.debug("Client application key: " + ByteUtils.bytesToHex(clientApplicationKey));
        this.clientKey = clientApplicationKey;
        byte[] serverApplicationKey = this.hkdfExpandLabel(this.serverApplicationTrafficSecret, "key", "", KEY_LENGTH);
        Logger.debug("Server application key: " + ByteUtils.bytesToHex(serverApplicationKey));
        this.serverKey = serverApplicationKey;
        byte[] clientApplicationIv = this.hkdfExpandLabel(this.clientApplicationTrafficSecret, "iv", "", IV_LENGTH);
        Logger.debug("Client application iv: " + ByteUtils.bytesToHex(clientApplicationIv));
        this.clientIv = clientApplicationIv;
        byte[] serverApplicationIv = this.hkdfExpandLabel(this.serverApplicationTrafficSecret, "iv", "", IV_LENGTH);
        Logger.debug("Server application iv: " + ByteUtils.bytesToHex(serverApplicationIv));
        this.serverIv = serverApplicationIv;
    }

    /**
     * Derives a resumption PSK from the resumption master secret.
     *
     * @param ticketNonce the ticket nonce, used as HKDF-Expand-Label context (at most 255 bytes)
     * @return the PSK (HASH_LENGTH bytes)
     */
    byte[] computePSK(byte[] ticketNonce) {
        return this.hkdfExpandLabel(this.resumptionMasterSecret, "resumption", ticketNonce, HASH_LENGTH);
    }

    /**
     * HKDF-Expand-Label with a string context, encoded as ISO-8859-1.
     *
     * @see #hkdfExpandLabel(byte[], String, byte[], short)
     */
    public byte[] hkdfExpandLabel(byte[] secret, String label, String context, short length) {
        return this.hkdfExpandLabel(secret, label, context.getBytes(ISO_8859_1), length);
    }

    /**
     * HKDF-Expand-Label: expands {@code secret} using an HkdfLabel structure as HKDF info.
     *
     * <p>The info layout is: 2-byte output length, 1-byte label length, "tls13 " + label, 1-byte
     * context length, context.
     *
     * @param secret pseudo-random key to expand
     * @param label label without the "tls13 " prefix; encoded as ISO-8859-1
     * @param context context bytes, at most 255 bytes
     * @param length requested output length in bytes
     * @return {@code length} bytes of derived key material
     * @throws IllegalArgumentException if the prefixed label or the context exceeds 255 bytes,
     *     which their one-byte length fields cannot represent
     */
    byte[] hkdfExpandLabel(byte[] secret, String label, byte[] context, short length) {
        // Encode once with a single charset so the length byte always matches the written bytes.
        byte[] fullLabel = (LABEL_PREFIX + label).getBytes(ISO_8859_1);
        if (fullLabel.length > 255) {
            throw new IllegalArgumentException("HKDF label too long: " + fullLabel.length + " bytes");
        }
        if (context.length > 255) {
            throw new IllegalArgumentException("HKDF context too long: " + context.length + " bytes");
        }
        ByteBuffer hkdfLabel = ByteBuffer.allocate(2 + 1 + fullLabel.length + 1 + context.length);
        hkdfLabel.putShort(length);
        hkdfLabel.put((byte) fullLabel.length);
        hkdfLabel.put(fullLabel);
        hkdfLabel.put((byte) context.length);
        hkdfLabel.put(context);
        return this.hkdf.expand(secret, hkdfLabel.array(), length);
    }

    /**
     * Decrypts one incoming record with the current server key and IV. The sequence number advances
     * only after successful authentication.
     *
     * @param recordHeader the record header; bytes 3-4 hold the big-endian record length, and the
     *     whole header is authenticated as additional data
     * @param payload the protected record payload (ciphertext followed by the authentication tag)
     * @return the decrypted payload
     * @throws IllegalArgumentException if the header length is smaller than the tag length or
     *     larger than {@code payload}
     * @throws RuntimeException if decryption or authentication fails
     */
    public byte[] decrypt(byte[] recordHeader, byte[] payload) {
        int recordSize = (recordHeader[3] & 0xFF) << 8 | recordHeader[4] & 0xFF;
        Logger.debug("Payload length: " + payload.length + " bytes, size in record: " + recordSize);
        if (recordSize < AUTHENTICATION_TAG_LENGTH || payload.length < recordSize) {
            throw new IllegalArgumentException("Invalid record: size in record is " + recordSize
                    + " bytes, payload length is " + payload.length + " bytes");
        }
        // The ciphertext/tag split is only used for logging; the cipher consumes the full payload.
        byte[] encryptedData = new byte[recordSize - AUTHENTICATION_TAG_LENGTH];
        byte[] authTag = new byte[AUTHENTICATION_TAG_LENGTH];
        System.arraycopy(payload, 0, encryptedData, 0, encryptedData.length);
        System.arraycopy(payload, encryptedData.length, authTag, 0, authTag.length);
        Logger.debug("Record data: " + ByteUtils.bytesToHex(recordHeader));
        Logger.debug("Encrypted data: " + abbreviatedHex(encryptedData));
        Logger.debug("Auth tag: " + ByteUtils.bytesToHex(authTag));
        byte[] wrapped = this.decryptPayload(payload, recordHeader, this.serverRecordCount);
        // Reached only when the tag verified; a failed record does not consume a sequence number.
        ++this.serverRecordCount;
        Logger.debug("Decrypted data (" + wrapped.length + "): " + abbreviatedHex(wrapped));
        return wrapped;
    }

    /**
     * Decrypts {@code message} with the current server key. The nonce is built from the server IV
     * and {@code recordNumber}. The server record counter is not modified.
     *
     * @param message ciphertext followed by the authentication tag
     * @param associatedData additional authenticated data
     * @param recordNumber sequence number of the record
     * @return the plaintext
     * @throws RuntimeException if decryption or authentication fails
     */
    byte[] decryptPayload(byte[] message, byte[] associatedData, int recordNumber) {
        byte[] nonce = computeNonce(this.serverIv, recordNumber);
        return aesGcm(Cipher.DECRYPT_MODE, this.serverKey, nonce, associatedData, message);
    }

    /**
     * Encrypts one outgoing record with the current client key and IV, then advances the client
     * sequence number so the next record gets a fresh nonce.
     *
     * @param message plaintext to protect
     * @param associatedData additional authenticated data (the record header)
     * @return ciphertext followed by the authentication tag
     * @throws RuntimeException if encryption fails
     */
    public byte[] encryptPayload(byte[] message, byte[] associatedData) {
        byte[] nonce = computeNonce(this.clientIv, this.clientRecordCount);
        byte[] encrypted = aesGcm(Cipher.ENCRYPT_MODE, this.clientKey, nonce, associatedData, message);
        // Reusing a nonce under the same AES-GCM key breaks confidentiality and integrity.
        ++this.clientRecordCount;
        return encrypted;
    }

    /**
     * Builds a per-record nonce: the sequence number is left-padded to IV length and XOR-ed with
     * the traffic IV.
     */
    private static byte[] computeNonce(byte[] iv, long sequenceNumber) {
        ByteBuffer paddedSequenceNumber = ByteBuffer.allocate(IV_LENGTH);
        paddedSequenceNumber.putInt(0);
        paddedSequenceNumber.putLong(sequenceNumber);
        byte[] nonce = paddedSequenceNumber.array();
        for (int i = 0; i < nonce.length; i++) {
            nonce[i] ^= iv[i];
        }
        return nonce;
    }

    /** Runs one AES-GCM encryption or decryption; checked JCA exceptions are wrapped unchecked. */
    private static byte[] aesGcm(int mode, byte[] key, byte[] nonce, byte[] associatedData, byte[] message) {
        try {
            Cipher aeadCipher = Cipher.getInstance(AES_GCM_NOPADDING);
            GCMParameterSpec parameterSpec = new GCMParameterSpec(AUTHENTICATION_TAG_LENGTH * 8, nonce);
            aeadCipher.init(mode, new SecretKeySpec(key, "AES"), parameterSpec);
            aeadCipher.updateAAD(associatedData);
            return aeadCipher.doFinal(message);
        }
        catch (InvalidAlgorithmParameterException | InvalidKeyException | NoSuchAlgorithmException | BadPaddingException | IllegalBlockSizeException | NoSuchPaddingException e) {
            throw new RuntimeException("Crypto error: " + e, e);
        }
    }

    /** Formats at most the first and last 8 bytes of {@code data} as hex, for debug logging. */
    private static String abbreviatedHex(byte[] data) {
        return ByteUtils.bytesToHex(data, Math.min(8, data.length)) + "..."
                + ByteUtils.bytesToHex(data, Math.max(data.length - 8, 0), Math.min(8, data.length));
    }

    /** @return the hash output length in bytes used by this key schedule */
    public short getHashLength() {
        return HASH_LENGTH;
    }

    public byte[] getClientEarlyTrafficSecret() {
        return this.clientEarlyTrafficSecret;
    }

    public byte[] getClientHandshakeTrafficSecret() {
        return this.clientHandshakeTrafficSecret;
    }

    public byte[] getServerHandshakeTrafficSecret() {
        return this.serverHandshakeTrafficSecret;
    }

    public byte[] getClientApplicationTrafficSecret() {
        return this.clientApplicationTrafficSecret;
    }

    public byte[] getServerApplicationTrafficSecret() {
        return this.serverApplicationTrafficSecret;
    }

    /** Sets the own private key used in {@link #computeSharedSecret()}. */
    public void setOwnKey(PrivateKey clientPrivateKey) {
        this.clientPrivateKey = clientPrivateKey;
    }

    /**
     * Records that the peer accepted a PSK.
     *
     * @param selectedIdentity index of the selected identity (currently not used)
     */
    public void setPskSelected(int selectedIdentity) {
        this.pskSelected = true;
    }

    /**
     * Records that no PSK was accepted. If a PSK was offered, the early secret is recomputed from an
     * all-zero input, the same as when no PSK had been supplied.
     */
    public void setNoPskSelected() {
        if (this.psk != null && !this.pskSelected) {
            this.computeEarlySecret(new byte[HASH_LENGTH]);
        }
    }

    /** Sets the peer's public key used in {@link #computeSharedSecret()}. */
    public void setPeerKey(PublicKey serverSharedKey) {
        this.serverSharedKey = serverSharedKey;
    }
}

