/*
 * Decompiled with CFR 0.152.
 */
package net.luminis.quic.server;

import java.io.IOException;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;
import net.luminis.quic.CryptoStream;
import net.luminis.quic.DecryptionException;
import net.luminis.quic.EncryptionLevel;
import net.luminis.quic.FrameProcessor2;
import net.luminis.quic.GlobalAckGenerator;
import net.luminis.quic.HandshakeState;
import net.luminis.quic.IdleTimer;
import net.luminis.quic.InvalidPacketException;
import net.luminis.quic.MissingKeysException;
import net.luminis.quic.PacketProcessor;
import net.luminis.quic.PnSpace;
import net.luminis.quic.QuicConnectionImpl;
import net.luminis.quic.QuicConstants;
import net.luminis.quic.QuicStream;
import net.luminis.quic.Role;
import net.luminis.quic.TransportError;
import net.luminis.quic.TransportParameters;
import net.luminis.quic.Version;
import net.luminis.quic.frame.AckFrame;
import net.luminis.quic.frame.HandshakeDoneFrame;
import net.luminis.quic.frame.NewConnectionIdFrame;
import net.luminis.quic.frame.NewTokenFrame;
import net.luminis.quic.frame.PathChallengeFrame;
import net.luminis.quic.frame.QuicFrame;
import net.luminis.quic.frame.RetireConnectionIdFrame;
import net.luminis.quic.log.LogProxy;
import net.luminis.quic.log.Logger;
import net.luminis.quic.packet.HandshakePacket;
import net.luminis.quic.packet.InitialPacket;
import net.luminis.quic.packet.QuicPacket;
import net.luminis.quic.packet.RetryPacket;
import net.luminis.quic.packet.ShortHeaderPacket;
import net.luminis.quic.packet.VersionNegotiationPacket;
import net.luminis.quic.packet.ZeroRttPacket;
import net.luminis.quic.send.SenderImpl;
import net.luminis.quic.server.ApplicationProtocolRegistry;
import net.luminis.quic.server.ServerConnection;
import net.luminis.quic.stream.FlowControl;
import net.luminis.quic.stream.StreamManager;
import net.luminis.quic.tls.QuicTransportParametersExtension;
import net.luminis.tls.NewSessionTicket;
import net.luminis.tls.TlsProtocolException;
import net.luminis.tls.alert.MissingExtensionAlert;
import net.luminis.tls.alert.NoApplicationProtocolAlert;
import net.luminis.tls.extension.ApplicationLayerProtocolNegotiationExtension;
import net.luminis.tls.extension.Extension;
import net.luminis.tls.handshake.CertificateMessage;
import net.luminis.tls.handshake.CertificateVerifyMessage;
import net.luminis.tls.handshake.EncryptedExtensions;
import net.luminis.tls.handshake.FinishedMessage;
import net.luminis.tls.handshake.ServerHello;
import net.luminis.tls.handshake.ServerMessageSender;
import net.luminis.tls.handshake.TlsEngine;
import net.luminis.tls.handshake.TlsServerEngine;
import net.luminis.tls.handshake.TlsServerEngineFactory;
import net.luminis.tls.handshake.TlsStatusEventHandler;
import net.luminis.tls.util.ByteUtils;

/**
 * Server side of a QUIC connection.
 *
 * <p>On top of the generic behaviour inherited from {@link QuicConnectionImpl}, this class handles
 * the server-specific parts:
 * <ul>
 *   <li>optional Retry-based address validation;</li>
 *   <li>the anti-amplification limit that applies until the client address is validated;</li>
 *   <li>the TLS server callbacks (ALPN selection, transport parameter exchange, handshake completion);</li>
 *   <li>sending HANDSHAKE_DONE once the handshake has finished.</li>
 * </ul>
 */
public class ServerConnectionImpl
extends QuicConnectionImpl
implements ServerConnection,
TlsStatusEventHandler {
    /** Length in bytes of the random token sent in Retry packets. */
    private static final int TOKEN_SIZE = 37;
    /** Random source for the retry token; {@code null} when no retry is required. */
    private final Random random;
    private final SenderImpl sender;
    /** Address the client's first packet came from. */
    private final InetSocketAddress initialClientAddress;
    /** Connection id chosen by this server (used as source connection id). */
    private final byte[] connectionId;
    /** The client's connection id (used as destination connection id). */
    private final byte[] peerConnectionId;
    /** Whether the client must echo a Retry token before its Initial packets are processed. */
    private final boolean retryRequired;
    private final GlobalAckGenerator ackGenerator;
    private final List<FrameProcessor2<AckFrame>> ackProcessors = new CopyOnWriteArrayList<FrameProcessor2<AckFrame>>();
    private final TlsServerEngine tlsEngine;
    /** Destination connection id of the client's first Initial packet; initial keys are derived from it. */
    private final byte[] originalDcid;
    private final ApplicationProtocolRegistry applicationProtocolRegistry;
    /** Invoked with {@link #connectionId} when this connection is terminated or aborted. */
    private final Consumer<byte[]> closeCallback;
    private final StreamManager streamManager;
    private final int initialMaxStreamData;
    private final int maxOpenStreamsUni;
    private final int maxOpenStreamsBidi;
    /** Secret retry token the client must echo; {@code null} when no retry is required. */
    private final byte[] token;
    /** ALPN protocol selected in {@link #extensionsReceived(List)}. */
    private volatile String negotiatedApplicationProtocol;
    private int maxIdleTimeoutInSeconds;
    /** Total bytes received in accepted datagrams; drives the anti-amplification limit. */
    private volatile long bytesReceived;
    /** Set once a valid retry token or a Handshake packet has been received. */
    private volatile boolean addressValidated;

    /**
     * Creates a server connection and starts its sender.
     *
     * @param quicVersion                 QUIC version used for this connection
     * @param serverSocket                socket used to send packets
     * @param initialClientAddress        address the client's first packet came from
     * @param connectionId                connection id chosen by this server
     * @param dcid                        the client's connection id, used as destination id
     * @param originalDcid                destination connection id of the client's first Initial packet
     * @param tlsServerEngineFactory      factory for the TLS server engine
     * @param retryRequired               whether a Retry must be sent to validate the client address
     * @param applicationProtocolRegistry used for ALPN selection and to start the application protocol
     * @param initialRtt                  passed to the sender (presumably an initial RTT estimate)
     * @param closeCallback               invoked with {@code connectionId} on termination or abort
     * @param log                         logger
     */
    protected ServerConnectionImpl(Version quicVersion, DatagramSocket serverSocket, InetSocketAddress initialClientAddress, byte[] connectionId, byte[] dcid, byte[] originalDcid, TlsServerEngineFactory tlsServerEngineFactory, boolean retryRequired, ApplicationProtocolRegistry applicationProtocolRegistry, Integer initialRtt, Consumer<byte[]> closeCallback, Logger log) {
        super(quicVersion, Role.Server, null, new LogProxy(log, originalDcid));
        this.initialClientAddress = initialClientAddress;
        this.connectionId = connectionId;
        this.peerConnectionId = dcid;
        this.originalDcid = originalDcid;
        this.retryRequired = retryRequired;
        this.applicationProtocolRegistry = applicationProtocolRegistry;
        this.closeCallback = closeCallback;
        this.tlsEngine = tlsServerEngineFactory.createServerEngine(new TlsMessageSender(), this);
        this.idleTimer = new IdleTimer(this, log);
        this.sender = new SenderImpl(quicVersion, ServerConnectionImpl.getMaxPacketSize(), serverSocket, initialClientAddress, this, initialRtt, this.log);
        if (!retryRequired) {
            // Nothing has been received yet, so nothing may be sent. The limit is raised in
            // parseAndProcessPackets as bytes arrive.
            this.sender.setAntiAmplificationLimit(0);
        }
        this.idleTimer.setPtoSupplier(this.sender::getPto);
        this.ackGenerator = this.sender.getGlobalAckGenerator();
        this.registerProcessor(this.ackGenerator);
        if (retryRequired) {
            this.random = new SecureRandom();
            this.token = new byte[TOKEN_SIZE];
            this.random.nextBytes(this.token);
        } else {
            this.random = null;
            this.token = null;
        }
        this.connectionSecrets.computeInitialKeys(originalDcid);
        this.sender.start(this.connectionSecrets);
        this.maxIdleTimeoutInSeconds = 30;
        this.initialMaxStreamData = 1000000;
        this.maxOpenStreamsUni = 10;
        this.maxOpenStreamsBidi = 100;
        this.streamManager = new StreamManager(this, Role.Server, log, this.maxOpenStreamsUni, this.maxOpenStreamsBidi);
        this.log.getQLog().emitConnectionCreatedEvent(Instant.now());
    }

    /**
     * Logs the error and reports this connection as closed through the close callback.
     *
     * @param error the internal error that caused the abort
     */
    @Override
    public void abortConnection(Throwable error) {
        this.log.error(this.toString() + " aborted due to internal error", error);
        this.closeCallback.accept(this.connectionId);
    }

    @Override
    protected SenderImpl getSender() {
        return this.sender;
    }

    @Override
    protected TlsEngine getTlsEngine() {
        return this.tlsEngine;
    }

    @Override
    protected GlobalAckGenerator getAckGenerator() {
        return this.ackGenerator;
    }

    @Override
    protected StreamManager getStreamManager() {
        return this.streamManager;
    }

    @Override
    public long getInitialMaxStreamData() {
        return this.initialMaxStreamData;
    }

    /**
     * Returns the maximum overhead of a short-header packet.
     *
     * <p>The overhead is 1 flags byte, the destination connection id, up to 4 packet-number bytes,
     * and 16 bytes (presumably the AEAD authentication tag).
     */
    @Override
    public int getMaxShortHeaderPacketOverhead() {
        return 1 + this.peerConnectionId.length + 4 + 16;
    }

    @Override
    protected int getSourceConnectionIdLength() {
        return this.connectionId.length;
    }

    /** Returns the connection id chosen by this server. */
    public byte[] getConnectionId() {
        return this.connectionId;
    }

    @Override
    public byte[] getSourceConnectionId() {
        return this.connectionId;
    }

    @Override
    public byte[] getDestinationConnectionId() {
        return this.peerConnectionId;
    }

    /** Adds a processor that will be notified of every received ACK frame. */
    @Override
    public void registerProcessor(FrameProcessor2<AckFrame> ackProcessor) {
        this.ackProcessors.add(ackProcessor);
    }

    /** TLS callback; 0-RTT is not supported, so early secrets are ignored. */
    @Override
    public void earlySecretsKnown() {
    }

    /** TLS callback: derives the Handshake-level secrets using the cipher selected by TLS. */
    @Override
    public void handshakeSecretsKnown() {
        this.connectionSecrets.computeHandshakeSecrets(this.tlsEngine, this.tlsEngine.getSelectedCipher());
    }

    /**
     * TLS callback invoked when the handshake has completed.
     *
     * <p>The steps, in order:
     * <ol>
     *   <li>derive the application secrets;</li>
     *   <li>discard the Handshake packet number space;</li>
     *   <li>send HANDSHAKE_DONE;</li>
     *   <li>mark the connection as connected and, if allowed, the handshake as confirmed;</li>
     *   <li>start the negotiated application protocol on this connection.</li>
     * </ol>
     */
    @Override
    public void handshakeFinished() {
        this.connectionSecrets.computeApplicationSecrets(this.tlsEngine);
        this.getSender().discard(PnSpace.Handshake, "tls handshake confirmed");
        this.sendHandshakeDone(new HandshakeDoneFrame(this.quicVersion));
        this.connectionState = QuicConnectionImpl.Status.Connected;
        // NOTE(review): this locks on the current HandshakeState value, and the field is reassigned
        // inside the block. It is kept as-is because other code may synchronize on the same object.
        HandshakeState handshakeState = this.handshakeState;
        synchronized (handshakeState) {
            if (this.handshakeState.transitionAllowed(HandshakeState.Confirmed)) {
                this.handshakeState = HandshakeState.Confirmed;
                this.handshakeStateListeners.forEach(l -> l.handshakeStateChangedEvent(this.handshakeState));
            } else {
                this.log.debug("Handshake state cannot be set to Confirmed");
            }
        }
        this.applicationProtocolRegistry.startApplicationProtocolConnection(this.negotiatedApplicationProtocol, this);
    }

    /**
     * Sends the given (HANDSHAKE_DONE) frame.
     *
     * <p>This method registers itself as the frame's callback, presumably so the frame is re-sent
     * when it is lost.
     */
    private void sendHandshakeDone(QuicFrame frame) {
        this.send(frame, this::sendHandshakeDone);
    }

    /** TLS callback; session tickets are not used by the server side. */
    @Override
    public void newSessionTicketReceived(NewSessionTicket ticket) {
    }

    /**
     * TLS callback with the ClientHello extensions.
     *
     * <p>It does four things:
     * <ul>
     *   <li>selects an application protocol (ALPN);</li>
     *   <li>validates and applies the client's QUIC transport parameters;</li>
     *   <li>adds the server's ALPN extension to the TLS response;</li>
     *   <li>adds the server's transport-parameters extension to the TLS response.</li>
     * </ul>
     *
     * @throws MissingExtensionAlert       if the ALPN or QUIC transport parameters extension is absent
     * @throws NoApplicationProtocolAlert  if none of the requested protocols is supported
     * @throws TlsProtocolException        if the client's transport parameters are invalid
     */
    @Override
    public void extensionsReceived(List<Extension> extensions) throws TlsProtocolException {
        Optional<Extension> alpnExtension = extensions.stream().filter(ext -> ext instanceof ApplicationLayerProtocolNegotiationExtension).findFirst();
        if (alpnExtension.isEmpty()) {
            throw new MissingExtensionAlert("missing application layer protocol negotiation extension");
        }
        List<String> requestedProtocols = ((ApplicationLayerProtocolNegotiationExtension) alpnExtension.get()).getProtocols();
        String selectedProtocol = this.applicationProtocolRegistry.selectSupportedApplicationProtocol(requestedProtocols)
                .orElseThrow(() -> new NoApplicationProtocolAlert(requestedProtocols));
        this.tlsEngine.addServerExtensions(new ApplicationLayerProtocolNegotiationExtension(selectedProtocol));
        this.negotiatedApplicationProtocol = selectedProtocol;

        Optional<Extension> tpExtension = extensions.stream().filter(ext -> ext instanceof QuicTransportParametersExtension).findFirst();
        if (tpExtension.isEmpty()) {
            throw new MissingExtensionAlert("missing quic transport parameters extension");
        }
        try {
            this.validateAndProcess(((QuicTransportParametersExtension) tpExtension.get()).getTransportParameters());
        }
        catch (TransportError transportParameterError) {
            throw new TlsProtocolException("transport parameter error", transportParameterError);
        }
        TransportParameters serverTransportParams = new TransportParameters(this.maxIdleTimeoutInSeconds, this.initialMaxStreamData, this.maxOpenStreamsBidi, this.maxOpenStreamsUni);
        serverTransportParams.setDisableMigration(true);
        serverTransportParams.setInitialSourceConnectionId(this.connectionId);
        serverTransportParams.setOriginalDestinationConnectionId(this.originalDcid);
        if (this.retryRequired) {
            serverTransportParams.setRetrySourceConnectionId(this.connectionId);
        }
        this.tlsEngine.addServerExtensions(new QuicTransportParametersExtension(this.quicVersion, serverTransportParams, Role.Server));
    }

    /**
     * Parses a packet. When retry is in use, a long-header Initial packet that fails to decrypt is
     * retried with the keys derived from the original destination connection id.
     *
     * <p>Such a packet may still be protected with those keys, for example one sent by the client
     * before it received the Retry. After the retry attempt the keys are always switched back to
     * those derived from this server's connection id.
     */
    @Override
    protected QuicPacket parsePacket(ByteBuffer data) throws MissingKeysException, DecryptionException, InvalidPacketException {
        try {
            return super.parsePacket(data);
        }
        catch (DecryptionException decryptionException) {
            // 0xC0: long header + fixed bit, packet type Initial (bits masked by 0xF0).
            if (this.retryRequired && (data.get(0) & 0xF0) == 0xC0) {
                try {
                    data.rewind();
                    this.connectionSecrets.computeInitialKeys(this.originalDcid);
                    return super.parsePacket(data);
                }
                finally {
                    this.connectionSecrets.computeInitialKeys(this.connectionId);
                }
            }
            throw decryptionException;
        }
    }

    /**
     * Accounts for a received datagram and hands it to the generic packet processing.
     *
     * <p>Initial packets in datagrams smaller than 1200 bytes are dropped. Until the client address
     * is validated, the anti-amplification limit is set to three times the bytes received so far.
     */
    @Override
    public void parseAndProcessPackets(int datagram, Instant timeReceived, ByteBuffer data, QuicPacket parsedPacket) {
        if (InitialPacket.isInitial(data) && data.limit() < 1200) {
            return;
        }
        this.bytesReceived += (long) data.remaining();
        if (!this.addressValidated) {
            // Compute in long and clamp: the int cast of bytesReceived would overflow to a negative
            // limit once more than ~715 MB were received from an unvalidated address.
            long limit = Math.min(3L * this.bytesReceived, Integer.MAX_VALUE);
            this.sender.setAntiAmplificationLimit((int) limit);
        }
        super.parseAndProcessPackets(datagram, timeReceived, data, parsedPacket);
    }

    /**
     * Processes an Initial packet.
     *
     * <p>When retry is required, there are three cases:
     * <ul>
     *   <li>a packet without a token triggers a Retry (and a switch of the initial keys);</li>
     *   <li>a packet with a wrong token closes the connection with INVALID_TOKEN;</li>
     *   <li>a packet with the correct token validates the client address.</li>
     * </ul>
     */
    @Override
    public PacketProcessor.ProcessResult process(InitialPacket packet, Instant time) {
        assert (Arrays.equals(packet.getDestinationConnectionId(), this.connectionId) || Arrays.equals(packet.getDestinationConnectionId(), this.originalDcid));
        if (this.retryRequired) {
            if (packet.getToken() == null) {
                this.sendRetry();
                this.connectionSecrets.computeInitialKeys(this.connectionId);
                return PacketProcessor.ProcessResult.Abort;
            }
            // Constant-time comparison, so response timing does not reveal how much of the secret
            // token matched.
            if (!MessageDigest.isEqual(packet.getToken(), this.token)) {
                this.immediateCloseWithError(EncryptionLevel.Initial, QuicConstants.TransportErrorCode.INVALID_TOKEN.value, null);
                return PacketProcessor.ProcessResult.Abort;
            }
            this.addressValidated = true;
            this.sender.unsetAntiAmplificationLimit();
        }
        this.processFrames(packet, time);
        return PacketProcessor.ProcessResult.Continue;
    }

    /** Sends a Retry packet carrying this connection's token. */
    private void sendRetry() {
        RetryPacket retry = new RetryPacket(this.quicVersion, this.connectionId, this.getDestinationConnectionId(), this.getOriginalDestinationConnectionId(), this.token);
        this.sender.send(retry);
    }

    @Override
    public PacketProcessor.ProcessResult process(ShortHeaderPacket packet, Instant time) {
        this.processFrames(packet, time);
        return PacketProcessor.ProcessResult.Continue;
    }

    /** Version negotiation packets are never valid for a server; processing is aborted. */
    @Override
    public PacketProcessor.ProcessResult process(VersionNegotiationPacket packet, Instant time) {
        return PacketProcessor.ProcessResult.Abort;
    }

    /**
     * Processes a Handshake packet.
     *
     * <p>Receiving one validates the client address, which lifts the anti-amplification limit, and
     * causes the Initial packet number space to be discarded.
     */
    @Override
    public PacketProcessor.ProcessResult process(HandshakePacket packet, Instant time) {
        if (!this.addressValidated) {
            this.addressValidated = true;
            this.sender.unsetAntiAmplificationLimit();
        }
        this.sender.discard(PnSpace.Initial, "first handshake packet received");
        this.processFrames(packet, time);
        return PacketProcessor.ProcessResult.Continue;
    }

    /** Retry packets are never valid for a server; processing is aborted. */
    @Override
    public PacketProcessor.ProcessResult process(RetryPacket packet, Instant time) {
        return PacketProcessor.ProcessResult.Abort;
    }

    /** 0-RTT packets are ignored (not processed), but processing of the datagram continues. */
    @Override
    public PacketProcessor.ProcessResult process(ZeroRttPacket packet, Instant time) {
        return PacketProcessor.ProcessResult.Continue;
    }

    @Override
    public void process(QuicFrame frame, QuicPacket packet, Instant timeReceived) {
    }

    /** Forwards the ACK frame to every registered ack processor. */
    @Override
    public void process(AckFrame ackFrame, QuicPacket packet, Instant timeReceived) {
        this.ackProcessors.forEach(processor -> processor.process(ackFrame, packet.getPnSpace(), timeReceived));
    }

    @Override
    public void process(HandshakeDoneFrame handshakeDoneFrame, QuicPacket packet, Instant timeReceived) {
    }

    @Override
    public void process(NewConnectionIdFrame newConnectionIdFrame, QuicPacket packet, Instant timeReceived) {
    }

    @Override
    public void process(NewTokenFrame newTokenFrame, QuicPacket packet, Instant timeReceived) {
    }

    @Override
    public void process(PathChallengeFrame pathChallengeFrame, QuicPacket packet, Instant timeReceived) {
    }

    @Override
    public void process(RetireConnectionIdFrame retireConnectionIdFrame, QuicPacket packet, Instant timeReceived) {
    }

    /** Terminates the connection, emits the qlog event and notifies the close callback. */
    @Override
    protected void terminate() {
        super.terminate();
        this.log.getQLog().emitConnectionTerminatedEvent();
        this.closeCallback.accept(this.connectionId);
    }

    /**
     * Validates the client's transport parameters (RFC 9000 §18.2) and applies them to this
     * connection.
     *
     * <p>Applying them means:
     * <ul>
     *   <li>setting the idle timeout;</li>
     *   <li>creating the flow controller;</li>
     *   <li>setting the stream limits.</li>
     * </ul>
     *
     * @throws TransportError with TRANSPORT_PARAMETER_ERROR if any parameter is invalid
     */
    private void validateAndProcess(TransportParameters transportParameters) throws TransportError {
        // Stream limits above 2^60 cannot be expressed as stream ids (RFC 9000 §4.6); this applies
        // to both the bidirectional and the unidirectional limit.
        long maxStreamsLimit = 1L << 60;
        checkTransportParameter(transportParameters.getInitialMaxStreamsBidi() <= maxStreamsLimit);
        checkTransportParameter(transportParameters.getInitialMaxStreamsUni() <= maxStreamsLimit);
        checkTransportParameter(transportParameters.getMaxUdpPayloadSize() >= 1200);
        checkTransportParameter(transportParameters.getAckDelayExponent() <= 20);
        // "Values of 2^14 or greater are invalid" for max_ack_delay.
        checkTransportParameter(transportParameters.getMaxAckDelay() < 16384);
        checkTransportParameter(transportParameters.getActiveConnectionIdLimit() >= 2);
        checkTransportParameter(Arrays.equals(transportParameters.getInitialSourceConnectionId(), this.peerConnectionId));

        this.determineIdleTimeout(this.maxIdleTimeoutInSeconds * 1000, transportParameters.getMaxIdleTimeout());
        this.flowController = new FlowControl(Role.Server, transportParameters.getInitialMaxData(), transportParameters.getInitialMaxStreamDataBidiLocal(), transportParameters.getInitialMaxStreamDataBidiRemote(), transportParameters.getInitialMaxStreamDataUni(), this.log);
        this.streamManager.setFlowController(this.flowController);
        this.streamManager.setInitialMaxStreamsBidi(transportParameters.getInitialMaxStreamsBidi());
        this.streamManager.setInitialMaxStreamsUni(transportParameters.getInitialMaxStreamsUni());
    }

    /** Throws TRANSPORT_PARAMETER_ERROR unless {@code valid} holds. */
    private static void checkTransportParameter(boolean valid) throws TransportError {
        if (!valid) {
            throw new TransportError(QuicConstants.TransportErrorCode.TRANSPORT_PARAMETER_ERROR);
        }
    }

    @Override
    public InetAddress getInitialClientAddress() {
        return this.initialClientAddress.getAddress();
    }

    /** Returns whether the connection state is {@code Closed}. */
    public boolean isClosed() {
        return this.connectionState == QuicConnectionImpl.Status.Closed;
    }

    /** Returns the destination connection id of the client's first Initial packet. */
    public byte[] getOriginalDestinationConnectionId() {
        return this.originalDcid;
    }

    /** Currently a no-op on the server side. */
    @Override
    public void setMaxAllowedBidirectionalStreams(int max) {
    }

    /** Currently a no-op on the server side. */
    @Override
    public void setMaxAllowedUnidirectionalStreams(int max) {
    }

    /** Currently a no-op on the server side. */
    @Override
    public void setDefaultStreamReceiveBufferSize(long size) {
    }

    @Override
    public void setPeerInitiatedStreamCallback(Consumer<QuicStream> streamConsumer) {
        this.streamManager.setPeerInitiatedStreamCallback(streamConsumer);
    }

    @Override
    public String toString() {
        return "ServerConnection[" + ByteUtils.bytesToHex(this.originalDcid) + "]";
    }

    /**
     * Writes the server's TLS handshake messages to the crypto stream of the matching level.
     *
     * <p>ServerHello goes to the Initial level and all other messages go to the Handshake level.
     */
    private class TlsMessageSender
    implements ServerMessageSender {
        private TlsMessageSender() {
        }

        @Override
        public void send(ServerHello sh) {
            CryptoStream cryptoStream = ServerConnectionImpl.this.getCryptoStream(EncryptionLevel.Initial);
            cryptoStream.write(sh, false);
            ServerConnectionImpl.this.log.sentPacketInfo(cryptoStream.toStringSent());
        }

        @Override
        public void send(EncryptedExtensions ee) {
            ServerConnectionImpl.this.getCryptoStream(EncryptionLevel.Handshake).write(ee, false);
        }

        @Override
        public void send(CertificateMessage cm) throws IOException {
            ServerConnectionImpl.this.getCryptoStream(EncryptionLevel.Handshake).write(cm, false);
        }

        @Override
        public void send(CertificateVerifyMessage cv) throws IOException {
            ServerConnectionImpl.this.getCryptoStream(EncryptionLevel.Handshake).write(cv, false);
        }

        @Override
        public void send(FinishedMessage finished) throws IOException {
            CryptoStream cryptoStream = ServerConnectionImpl.this.getCryptoStream(EncryptionLevel.Handshake);
            cryptoStream.write(finished, false);
            ServerConnectionImpl.this.log.sentPacketInfo(cryptoStream.toStringSent());
        }
    }
}

