/*
 * Decompiled with CFR 0.152.
 */
package net.luminis.tls.extension;

import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.security.AlgorithmParameters;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.XECPublicKey;
import java.security.spec.ECGenParameterSpec;
import java.security.spec.ECParameterSpec;
import java.security.spec.ECPoint;
import java.security.spec.ECPublicKeySpec;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.InvalidParameterSpecException;
import java.security.spec.NamedParameterSpec;
import java.security.spec.XECPublicKeySpec;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import net.luminis.tls.TlsConstants;
import net.luminis.tls.TlsProtocolException;
import net.luminis.tls.alert.DecodeErrorException;
import net.luminis.tls.extension.Extension;
import net.luminis.tls.util.ByteUtils;

public class KeyShareExtension
extends Extension {
    public static final Map<TlsConstants.NamedGroup, Integer> CURVE_KEY_LENGTHS = Map.of(TlsConstants.NamedGroup.secp256r1, 65, TlsConstants.NamedGroup.x25519, 32, TlsConstants.NamedGroup.x448, 56);
    public static final List<TlsConstants.NamedGroup> supportedCurves = List.of(TlsConstants.NamedGroup.secp256r1, TlsConstants.NamedGroup.x25519, TlsConstants.NamedGroup.x448);
    private TlsConstants.HandshakeType handshakeType;
    private List<KeyShareEntry> keyShareEntries = new ArrayList<KeyShareEntry>();

    public KeyShareExtension(ECPublicKey publicKey, TlsConstants.NamedGroup ecCurve, TlsConstants.HandshakeType handshakeType) {
        this.handshakeType = handshakeType;
        if (!supportedCurves.contains((Object)ecCurve)) {
            throw new RuntimeException("Only curves supported: " + supportedCurves);
        }
        this.keyShareEntries.add(new ECKeyShareEntry(ecCurve, publicKey));
    }

    public KeyShareExtension(PublicKey publicKey, TlsConstants.NamedGroup ecCurve, TlsConstants.HandshakeType handshakeType) {
        this.handshakeType = handshakeType;
        if (!supportedCurves.contains((Object)ecCurve)) {
            throw new RuntimeException("Only curves supported: " + supportedCurves);
        }
        this.keyShareEntries.add(new KeyShareEntry(ecCurve, publicKey));
    }

    public KeyShareExtension(ByteBuffer buffer, TlsConstants.HandshakeType handshakeType) throws TlsProtocolException {
        this(buffer, handshakeType, false);
    }

    public KeyShareExtension(ByteBuffer buffer, TlsConstants.HandshakeType handshakeType, boolean helloRetryRequestType) throws TlsProtocolException {
        int extensionDataLength = this.parseExtensionHeader(buffer, TlsConstants.ExtensionType.key_share, 1);
        if (extensionDataLength < 2) {
            throw new DecodeErrorException("extension underflow");
        }
        if (handshakeType == TlsConstants.HandshakeType.client_hello) {
            int remaining;
            int keyShareEntriesSize = buffer.getShort();
            if (extensionDataLength != 2 + keyShareEntriesSize) {
                throw new DecodeErrorException("inconsistent length");
            }
            for (remaining = keyShareEntriesSize; remaining > 0; remaining -= this.parseKeyShareEntry(buffer, helloRetryRequestType)) {
            }
            if (remaining != 0) {
                throw new DecodeErrorException("inconsistent length");
            }
        } else if (handshakeType == TlsConstants.HandshakeType.server_hello) {
            int remaining = extensionDataLength;
            if ((remaining -= this.parseKeyShareEntry(buffer, helloRetryRequestType)) != 0) {
                throw new DecodeErrorException("inconsistent length");
            }
        } else {
            throw new IllegalArgumentException();
        }
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    protected int parseKeyShareEntry(ByteBuffer buffer, boolean namedGroupOnly) throws TlsProtocolException {
        int startPosition = buffer.position();
        if (namedGroupOnly && buffer.remaining() < 2 || !namedGroupOnly && buffer.remaining() < 4) {
            throw new DecodeErrorException("extension underflow");
        }
        short namedGroupValue = buffer.getShort();
        TlsConstants.NamedGroup namedGroup = Stream.of(TlsConstants.NamedGroup.values()).filter(it -> it.value == namedGroupValue).findAny().orElseThrow(() -> new DecodeErrorException("Invalid named group"));
        if (!supportedCurves.contains((Object)namedGroup)) {
            throw new RuntimeException("Curve '" + namedGroup + "' not supported");
        }
        if (namedGroupOnly) {
            this.keyShareEntries.add(new ECKeyShareEntry(namedGroup, null));
            return buffer.position() - startPosition;
        } else {
            short keyLength = buffer.getShort();
            if (buffer.remaining() < keyLength) {
                throw new DecodeErrorException("extension underflow");
            }
            if (keyLength != CURVE_KEY_LENGTHS.get((Object)namedGroup)) {
                throw new DecodeErrorException("Invalid " + namedGroup.name() + " key length: " + keyLength);
            }
            if (namedGroup == TlsConstants.NamedGroup.secp256r1) {
                byte headerByte = buffer.get();
                if (headerByte != 4) throw new DecodeErrorException("EC keys must be in legacy form");
                byte[] keyData = new byte[keyLength - 1];
                buffer.get(keyData);
                ECPublicKey ecPublicKey = KeyShareExtension.rawToEncodedECPublicKey(namedGroup, keyData);
                this.keyShareEntries.add(new ECKeyShareEntry(namedGroup, ecPublicKey));
                return buffer.position() - startPosition;
            } else {
                if (namedGroup != TlsConstants.NamedGroup.x25519 && namedGroup != TlsConstants.NamedGroup.x448) return buffer.position() - startPosition;
                byte[] keyData = new byte[keyLength];
                buffer.get(keyData);
                PublicKey publicKey = KeyShareExtension.rawToEncodedXDHPublicKey(namedGroup, keyData);
                this.keyShareEntries.add(new KeyShareEntry(namedGroup, publicKey));
            }
        }
        return buffer.position() - startPosition;
    }

    @Override
    public byte[] getBytes() {
        short keyShareEntryLength;
        short extensionLength = keyShareEntryLength = (short)this.keyShareEntries.stream().map(ks -> ks.getNamedGroup()).mapToInt(g -> CURVE_KEY_LENGTHS.get(g)).map(s -> 4 + s).sum();
        if (this.handshakeType == TlsConstants.HandshakeType.client_hello) {
            extensionLength = (short)(extensionLength + 2);
        }
        ByteBuffer buffer = ByteBuffer.allocate(4 + extensionLength);
        buffer.putShort(TlsConstants.ExtensionType.key_share.value);
        buffer.putShort(extensionLength);
        if (this.handshakeType == TlsConstants.HandshakeType.client_hello) {
            buffer.putShort(keyShareEntryLength);
        }
        for (KeyShareEntry keyShare : this.keyShareEntries) {
            buffer.putShort(keyShare.getNamedGroup().value);
            buffer.putShort(CURVE_KEY_LENGTHS.get((Object)keyShare.getNamedGroup()).shortValue());
            if (keyShare.getNamedGroup() == TlsConstants.NamedGroup.secp256r1) {
                buffer.put((byte)4);
                byte[] affineX = ((ECPublicKey)keyShare.getKey()).getW().getAffineX().toByteArray();
                this.writeAffine(buffer, affineX);
                byte[] affineY = ((ECPublicKey)keyShare.getKey()).getW().getAffineY().toByteArray();
                this.writeAffine(buffer, affineY);
                continue;
            }
            if (keyShare.getNamedGroup() == TlsConstants.NamedGroup.x25519 || keyShare.getNamedGroup() == TlsConstants.NamedGroup.x448) {
                byte[] raw = ((XECPublicKey)keyShare.getKey()).getU().toByteArray();
                if (raw.length > CURVE_KEY_LENGTHS.get((Object)keyShare.getNamedGroup())) {
                    throw new RuntimeException("Invalid " + keyShare.getNamedGroup() + " key length: " + raw.length);
                }
                if (raw.length < CURVE_KEY_LENGTHS.get((Object)keyShare.getNamedGroup())) {
                    KeyShareExtension.reverse(raw);
                    byte[] padded = Arrays.copyOf(raw, (int)CURVE_KEY_LENGTHS.get((Object)keyShare.getNamedGroup()));
                    raw = padded;
                } else {
                    KeyShareExtension.reverse(raw);
                }
                buffer.put(raw);
                continue;
            }
            throw new RuntimeException();
        }
        return buffer.array();
    }

    public List<KeyShareEntry> getKeyShareEntries() {
        return this.keyShareEntries;
    }

    private void writeAffine(ByteBuffer buffer, byte[] affine) {
        if (affine.length == 32) {
            buffer.put(affine);
        } else if (affine.length < 32) {
            for (int i = 0; i < 32 - affine.length; ++i) {
                buffer.put((byte)0);
            }
            buffer.put(affine, 0, affine.length);
        } else if (affine.length > 32) {
            for (int i = 0; i < affine.length - 32; ++i) {
                if (affine[i] == 0) continue;
                throw new RuntimeException("W Affine more then 32 bytes, leading bytes not 0 " + ByteUtils.bytesToHex(affine));
            }
            buffer.put(affine, affine.length - 32, 32);
        }
    }

    static ECPublicKey rawToEncodedECPublicKey(TlsConstants.NamedGroup curveName, byte[] rawBytes) {
        try {
            KeyFactory kf = KeyFactory.getInstance("EC");
            byte[] x = Arrays.copyOfRange(rawBytes, 0, rawBytes.length / 2);
            byte[] y = Arrays.copyOfRange(rawBytes, rawBytes.length / 2, rawBytes.length);
            ECPoint w = new ECPoint(new BigInteger(1, x), new BigInteger(1, y));
            return (ECPublicKey)kf.generatePublic(new ECPublicKeySpec(w, KeyShareExtension.ecParameterSpecForCurve(curveName.name())));
        }
        catch (NoSuchAlgorithmException e) {
            throw new RuntimeException("Missing support for EC algorithm");
        }
        catch (InvalidKeySpecException e) {
            throw new RuntimeException("Inappropriate parameter specification");
        }
    }

    static ECParameterSpec ecParameterSpecForCurve(String curveName) {
        try {
            AlgorithmParameters params = AlgorithmParameters.getInstance("EC");
            params.init(new ECGenParameterSpec(curveName));
            return params.getParameterSpec(ECParameterSpec.class);
        }
        catch (NoSuchAlgorithmException e) {
            throw new RuntimeException("Missing support for EC algorithm");
        }
        catch (InvalidParameterSpecException e) {
            throw new RuntimeException("Inappropriate parameter specification");
        }
    }

    static PublicKey rawToEncodedXDHPublicKey(TlsConstants.NamedGroup curve, byte[] keyData) {
        try {
            KeyShareExtension.reverse(keyData);
            BigInteger u = new BigInteger(keyData);
            KeyFactory kf = KeyFactory.getInstance("XDH");
            NamedParameterSpec paramSpec = new NamedParameterSpec(curve.name().toUpperCase());
            XECPublicKeySpec pubSpec = new XECPublicKeySpec(paramSpec, u);
            return kf.generatePublic(pubSpec);
        }
        catch (NoSuchAlgorithmException e) {
            throw new RuntimeException("Missing support for EC algorithm");
        }
        catch (InvalidKeySpecException e) {
            throw new RuntimeException("Inappropriate parameter specification");
        }
    }

    public static void reverse(byte[] array) {
        if (array == null) {
            return;
        }
        int i = 0;
        for (int j = array.length - 1; j > i; --j, ++i) {
            byte tmp = array[j];
            array[j] = array[i];
            array[i] = tmp;
        }
    }

    public static class ECKeyShareEntry
    extends KeyShareEntry {
        private final ECPublicKey key;

        public ECKeyShareEntry(TlsConstants.NamedGroup namedGroup, ECPublicKey key) {
            super(namedGroup, key);
            this.namedGroup = namedGroup;
            this.key = key;
        }

        @Override
        public ECPublicKey getKey() {
            return this.key;
        }
    }

    public static class KeyShareEntry {
        protected TlsConstants.NamedGroup namedGroup;
        protected final PublicKey key;

        public KeyShareEntry(TlsConstants.NamedGroup namedGroup, PublicKey key) {
            this.namedGroup = namedGroup;
            this.key = key;
        }

        public TlsConstants.NamedGroup getNamedGroup() {
            return this.namedGroup;
        }

        public PublicKey getKey() {
            return this.key;
        }
    }
}

