/*
 * Decompiled with CFR 0.152.
 */
package net.luminis.quic.server;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import net.luminis.quic.EncryptionLevel;
import net.luminis.quic.RawPacket;
import net.luminis.quic.Receiver;
import net.luminis.quic.UnknownVersionException;
import net.luminis.quic.Version;
import net.luminis.quic.log.FileLogger;
import net.luminis.quic.log.Logger;
import net.luminis.quic.log.SysOutLogger;
import net.luminis.quic.packet.VersionNegotiationPacket;
import net.luminis.quic.run.KwikVersion;
import net.luminis.quic.server.ApplicationProtocolConnectionFactory;
import net.luminis.quic.server.ApplicationProtocolRegistry;
import net.luminis.quic.server.ConnectionSource;
import net.luminis.quic.server.ServerConnectionCandidate;
import net.luminis.quic.server.ServerConnectionFactory;
import net.luminis.quic.server.ServerConnectionProxy;
import net.luminis.quic.server.ServerConnectionRegistry;
import net.luminis.quic.server.h09.Http09ApplicationProtocolFactory;
import net.luminis.tls.handshake.TlsServerEngineFactory;
import net.luminis.tls.util.ByteUtils;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;

/**
 * Kwik QUIC server.
 * <p>
 * Receives UDP datagrams on a single {@link DatagramSocket} (via a {@link Receiver}) and:
 * <ul>
 *   <li>dispatches each packet to the {@link ServerConnectionProxy} registered for its destination connection id;</li>
 *   <li>creates a new {@link ServerConnectionCandidate} for long-header packets that may start a connection
 *       with a supported version;</li>
 *   <li>answers Initial packets carrying an unsupported version with a Version Negotiation packet.</li>
 * </ul>
 * Connections are tracked in a {@link ConcurrentHashMap} keyed by {@link ConnectionSource}; registration and
 * deregistration happen through the {@link ServerConnectionRegistry} callbacks.
 */
public class Server
implements ServerConnectionRegistry {
    /** Smallest parsable long header: flags (1) + version (4) + DCID length (1) + SCID length (1). */
    private static final int MINIMUM_LONG_HEADER_LENGTH = 7;
    /**
     * Connection id length used by this server (passed to {@link ServerConnectionFactory}); short-header packets
     * are parsed assuming their DCID has exactly this length.
     */
    private static final int CONNECTION_ID_LENGTH = 4;
    /** Maximum connection id length in QUIC version 1 (RFC 9000, section 17.2). */
    private static final int MAX_CONNECTION_ID_LENGTH = 20;
    /** Offset of the DCID bytes in a long-header packet (after flags, version and DCID length byte). */
    private static final int LONG_HEADER_DCID_OFFSET = 6;
    /** Minimum client Initial DCID length (RFC 9000, section 7.2). */
    private static final int MIN_CLIENT_INITIAL_DCID_LENGTH = 8;
    /** Minimum size of a datagram carrying a client Initial packet (RFC 9000, section 14.1). */
    private static final int MIN_INITIAL_DATAGRAM_SIZE = 1200;
    /** Initial RTT value handed to the connection factory. */
    private static final int DEFAULT_INITIAL_RTT = 100;

    private final Receiver receiver;
    private final Logger log;
    private final List<Version> supportedVersions;
    private final List<Integer> supportedVersionIds;
    private final DatagramSocket serverSocket;
    private final boolean requireRetry;
    private final Map<ConnectionSource, ServerConnectionProxy> currentConnections;
    private final TlsServerEngineFactory tlsEngineFactory;
    private final ServerConnectionFactory serverConnectionFactory;
    private final ApplicationProtocolRegistry applicationProtocolRegistry;

    private static void usageAndExit() {
        System.err.println("Usage: [--noRetry] cert file, cert key file, port number [www dir]");
        System.exit(1);
    }

    /**
     * Command line entry point.
     * <p>
     * Arguments: {@code [--noRetry] <cert file> <cert key file> <port> [<www dir>]}. Exits with status 1 on
     * invalid arguments, missing files, an unparsable port or an unreadable www dir.
     */
    public static void main(String[] rawArgs) throws Exception {
        Options cmdLineOptions = new Options();
        cmdLineOptions.addOption(null, "noRetry", false, "disable always use retry");

        CommandLine cmd;
        try {
            cmd = new DefaultParser().parse(cmdLineOptions, rawArgs);
        }
        catch (ParseException argError) {
            System.err.println("Invalid argument: " + argError.getMessage());
            usageAndExit();
            return;
        }

        List<String> args = cmd.getArgList();
        if (args.size() < 3) {
            usageAndExit();
            return;
        }
        boolean requireRetry = !cmd.hasOption("noRetry");

        File certificateFile = new File(args.get(0));
        if (!certificateFile.exists()) {
            System.err.println("Cannot open certificate file " + args.get(0));
            System.exit(1);
        }
        File certificateKeyFile = new File(args.get(1));
        if (!certificateKeyFile.exists()) {
            System.err.println("Cannot open certificate key file " + args.get(1));
            System.exit(1);
        }

        int port;
        try {
            port = Integer.parseInt(args.get(2));
        }
        catch (NumberFormatException invalidPort) {
            System.err.println("Invalid port number '" + args.get(2) + "'");
            usageAndExit();
            return;
        }

        // The www dir is optional; when given it must be a readable directory.
        File wwwDir = null;
        if (args.size() > 3) {
            wwwDir = new File(args.get(3));
            if (!(wwwDir.exists() && wwwDir.isDirectory() && wwwDir.canRead())) {
                System.err.println("Cannot read www dir '" + wwwDir + "'");
                System.exit(1);
            }
        }

        List<Version> supportedVersions = new ArrayList<>(List.of(
                Version.IETF_draft_29, Version.IETF_draft_30, Version.IETF_draft_31, Version.IETF_draft_32,
                Version.QUIC_version_1));
        new Server(port, new FileInputStream(certificateFile), new FileInputStream(certificateKeyFile),
                supportedVersions, requireRetry, wwwDir).start();
    }

    /**
     * Creates a server bound to a new {@link DatagramSocket} on the given port.
     *
     * @see #Server(DatagramSocket, InputStream, InputStream, List, boolean, File)
     */
    public Server(int port, InputStream certificateFile, InputStream certificateKeyFile, List<Version> supportedVersions, boolean requireRetry, File dir) throws Exception {
        this(new DatagramSocket(port), certificateFile, certificateKeyFile, supportedVersions, requireRetry, dir);
    }

    /**
     * Creates a server on an existing socket. The server does not process packets until it is started.
     *
     * @param socket             socket to receive from and send on
     * @param certificateFile    stream providing the server certificate (handed to {@link TlsServerEngineFactory})
     * @param certificateKeyFile stream providing the certificate's private key
     * @param supportedVersions  QUIC versions accepted for new connections and advertised in version negotiation
     * @param requireRetry       whether new connections must go through a Retry exchange
     * @param dir                www directory; when non-null, HTTP/0.9 (and, if the plugin is present, HTTP/3)
     *                           application protocols are registered
     */
    public Server(DatagramSocket socket, InputStream certificateFile, InputStream certificateKeyFile, List<Version> supportedVersions, boolean requireRetry, File dir) throws Exception {
        this.serverSocket = socket;
        this.supportedVersions = supportedVersions;
        this.requireRetry = requireRetry;
        this.currentConnections = new ConcurrentHashMap<>();

        // Log to /logs/kwikserver.log when that directory is writable, otherwise to stdout.
        File logDir = new File("/logs");
        this.log = logDir.exists() && logDir.isDirectory() && logDir.canWrite()
                ? new FileLogger(new File(logDir, "kwikserver.log"))
                : new SysOutLogger();
        this.log.timeFormat(Logger.TimeFormat.Long);
        this.log.logWarning(true);
        this.log.logInfo(true);

        this.tlsEngineFactory = new TlsServerEngineFactory(certificateFile, certificateKeyFile);
        this.applicationProtocolRegistry = new ApplicationProtocolRegistry();
        this.serverConnectionFactory = new ServerConnectionFactory(CONNECTION_ID_LENGTH, this.serverSocket,
                this.tlsEngineFactory, this.requireRetry, this.applicationProtocolRegistry, DEFAULT_INITIAL_RTT,
                this::removeConnection, this.log);
        this.supportedVersionIds = supportedVersions.stream().map(Version::getId).collect(Collectors.toList());
        if (dir != null) {
            registerApplicationLayerProtocols(dir);
        }
        this.receiver = new Receiver(this.serverSocket, this.log, exception -> System.exit(9));
        this.log.info("Kwik server " + KwikVersion.getVersion() + " started; supported application protcols: " + this.applicationProtocolRegistry.getRegisteredApplicationProtocols());
    }

    /** Starts the receiver and the thread that processes received packets. */
    private void start() {
        receiver.start();
        new Thread(this::receiveLoop, "server receive loop").start();
    }

    /**
     * Registers an HTTP/0.9 protocol ("hq-&lt;draft&gt;" or "hq-interop") for every supported version, plus the
     * matching "h3" variant when the optional Flupke HTTP/3 plugin can be loaded.
     */
    private void registerApplicationLayerProtocols(File wwwDir) {
        ApplicationProtocolConnectionFactory http3Factory = loadHttp3Factory(wwwDir);
        Http09ApplicationProtocolFactory http09Factory = new Http09ApplicationProtocolFactory(wwwDir);
        for (Version version : supportedVersions) {
            String draftVersion = version.getDraftVersion();
            String protocol = draftVersion.isBlank() ? "hq-interop" : "hq-" + draftVersion;
            applicationProtocolRegistry.registerApplicationProtocol(protocol, http09Factory);
            if (http3Factory != null) {
                String h3Protocol = protocol.replace("hq-interop", "h3").replace("hq", "h3");
                applicationProtocolRegistry.registerApplicationProtocol(h3Protocol, http3Factory);
            }
        }
    }

    /**
     * Tries to instantiate the optional Flupke HTTP/3 factory by reflection.
     *
     * @return the factory, or {@code null} when the plugin is absent or cannot be instantiated
     */
    private ApplicationProtocolConnectionFactory loadHttp3Factory(File wwwDir) {
        try {
            Class<?> factoryClass = getClass().getClassLoader().loadClass("net.luminis.http3.server.Http3ApplicationProtocolFactory");
            ApplicationProtocolConnectionFactory factory = (ApplicationProtocolConnectionFactory) factoryClass.getDeclaredConstructor(File.class).newInstance(wwwDir);
            log.info("Loading Flupke H3 server plugin");
            return factory;
        }
        catch (ClassNotFoundException ignored) {
            // Plugin not on the classpath: HTTP/3 is simply not offered.
            return null;
        }
        catch (ReflectiveOperationException | ClassCastException loadError) {
            // Plugin is present but unusable; do not fail the server, but make the problem visible.
            log.warn("Failed to load Flupke H3 server plugin: " + loadError);
            return null;
        }
    }

    /**
     * Takes packets from the receiver and processes them until the thread is interrupted. Any other exception
     * is logged and the loop continues with the next packet.
     */
    private void receiveLoop() {
        int timeoutSeconds = (int) Duration.ofDays(3650L).toSeconds();
        while (true) {
            try {
                RawPacket rawPacket = receiver.get(timeoutSeconds);
                if (rawPacket != null) {
                    process(rawPacket);
                }
            }
            catch (InterruptedException e) {
                log.error("receiver interrupted (ignoring)");
                Thread.currentThread().interrupt();
                break;
            }
            catch (Exception runtimeError) {
                log.error("Uncaught exception in server receive loop", runtimeError);
            }
        }
    }

    /**
     * Routes a received datagram by header form: long header (0b11......), short header (0b01......); anything
     * else, including an empty datagram, is discarded with a warning.
     */
    void process(RawPacket rawPacket) {
        ByteBuffer data = rawPacket.getData();
        if (!data.hasRemaining()) {
            log.warn("Empty Quic packet is discarded");
            return;
        }
        byte flags = data.get();
        data.rewind();
        InetSocketAddress clientAddress = new InetSocketAddress(rawPacket.getAddress(), rawPacket.getPort());
        int headerForm = flags & 0xC0;
        if (headerForm == 0xC0) {
            processLongHeaderPacket(clientAddress, data);
        }
        else if (headerForm == 0x40) {
            processShortHeaderPacket(clientAddress, data);
        }
        else {
            log.warn(String.format("Invalid Quic packet (flags: %02x) is discarded", flags));
        }
    }

    /**
     * Handles a long-header packet: dispatches it to an existing connection for its DCID, creates a new
     * connection candidate when the packet may start one, or sends a Version Negotiation packet for an Initial
     * with an unsupported version. Truncated headers are silently dropped.
     */
    private void processLongHeaderPacket(InetSocketAddress clientAddress, ByteBuffer data) {
        if (data.remaining() < MINIMUM_LONG_HEADER_LENGTH) {
            return;
        }
        data.position(1);
        int version = data.getInt();
        int dcidLength = data.get() & 0xFF;
        if (dcidLength > MAX_CONNECTION_ID_LENGTH) {
            // Too long for any version we support; the only useful reply is to advertise our versions.
            if (initialWithUnspportedVersion(data, version)) {
                sendVersionNegotiationPacket(clientAddress, data, dcidLength);
            }
            return;
        }
        if (data.remaining() < dcidLength + 1) {
            return;
        }
        byte[] dcid = new byte[dcidLength];
        data.get(dcid);
        int scidLength = data.get() & 0xFF;
        if (data.remaining() < scidLength) {
            return;
        }
        byte[] scid = new byte[scidLength];
        data.get(scid);
        data.rewind();

        Optional<ServerConnectionProxy> connection = isExistingConnection(clientAddress, dcid);
        if (connection.isEmpty()) {
            synchronized (this) {
                // Re-check under the lock so that a connection created meanwhile receives this packet.
                connection = isExistingConnection(clientAddress, dcid);
                if (connection.isEmpty()) {
                    if (mightStartNewConnection(data, version, dcid)) {
                        connection = Optional.of(createNewConnection(version, clientAddress, scid, dcid));
                    }
                    else if (initialWithUnspportedVersion(data, version)) {
                        log.received(Instant.now(), 0, EncryptionLevel.Initial, dcid, scid);
                        sendVersionNegotiationPacket(clientAddress, data, dcidLength);
                    }
                }
            }
        }
        connection.ifPresent(c -> c.parsePackets(0, Instant.now(), data));
    }

    /**
     * Dispatches a short-header packet to the connection registered for its DCID (assumed to be
     * {@link #CONNECTION_ID_LENGTH} bytes); packets that are too short or address an unknown connection are
     * discarded with a warning.
     */
    private void processShortHeaderPacket(InetSocketAddress clientAddress, ByteBuffer data) {
        if (data.remaining() < 1 + CONNECTION_ID_LENGTH) {
            log.warn("Discarding short header packet that is too short to contain a connection id");
            return;
        }
        byte[] dcid = new byte[CONNECTION_ID_LENGTH];
        data.position(1);
        data.get(dcid);
        data.rewind();
        isExistingConnection(clientAddress, dcid).ifPresentOrElse(
                c -> c.parsePackets(0, Instant.now(), data),
                () -> log.warn("Discarding short header packet addressing non existent connection " + ByteUtils.bytesToHex(dcid)));
    }

    /** A packet may start a new connection when its DCID is long enough for a client Initial and its version is supported. */
    private boolean mightStartNewConnection(ByteBuffer packetBytes, int version, byte[] dcid) {
        return dcid.length >= MIN_CLIENT_INITIAL_DCID_LENGTH && supportedVersionIds.contains(version);
    }

    /**
     * Returns whether the packet is a long-header packet with type bits 00 (Initial in QUIC v1), in a datagram of
     * at least 1200 bytes, with a version this server does not support. Does not modify the buffer position.
     */
    private boolean initialWithUnspportedVersion(ByteBuffer packetBytes, int version) {
        int flags = packetBytes.get(0) & 0xFF;
        boolean isInitial = (flags & 0xF0) == 0xC0;
        return isInitial
                && packetBytes.limit() >= MIN_INITIAL_DATAGRAM_SIZE
                && !supportedVersionIds.contains(version);
    }

    /**
     * Creates and registers (under the original DCID) a connection candidate for the given version.
     *
     * @throws IllegalStateException when the version value is unknown; callers only invoke this after checking
     *                               that the version is supported
     */
    private ServerConnectionProxy createNewConnection(int versionValue, InetSocketAddress clientAddress, byte[] scid, byte[] dcid) {
        try {
            Version version = Version.parse(versionValue);
            ServerConnectionCandidate connectionCandidate = new ServerConnectionCandidate(version, clientAddress, scid, dcid, serverConnectionFactory, this, log);
            currentConnections.put(new ConnectionSource(dcid), connectionCandidate);
            return connectionCandidate;
        }
        catch (UnknownVersionException e) {
            throw new IllegalStateException("Cannot create connection for unknown version " + versionValue, e);
        }
    }

    /**
     * Callback from the connection factory: removes the connection registered under {@code cid} and under its
     * original DCID, then terminates it. Logs an error when no connection is registered under {@code cid}.
     */
    private void removeConnection(byte[] cid) {
        ServerConnectionProxy removed = currentConnections.remove(new ConnectionSource(cid));
        if (removed == null) {
            log.error("Cannot remove connection with cid " + ByteUtils.bytesToHex(cid));
            return;
        }
        currentConnections.remove(new ConnectionSource(removed.getOriginalDestinationConnectionId()));
        if (!removed.isClosed()) {
            log.error("Removed connection with cid " + ByteUtils.bytesToHex(cid) + " that is not closed...");
        }
        removed.terminate();
    }

    /** Looks up the connection registered for the DCID; the client address is currently not used for the lookup. */
    private Optional<ServerConnectionProxy> isExistingConnection(InetSocketAddress clientAddress, byte[] dcid) {
        return Optional.ofNullable(currentConnections.get(new ConnectionSource(dcid)));
    }

    /**
     * Builds a Version Negotiation packet from the DCID and SCID of the received long-header packet and sends it
     * to the client. Truncated packets are ignored; send failures are logged.
     */
    private void sendVersionNegotiationPacket(InetSocketAddress clientAddress, ByteBuffer data, int dcidLength) {
        data.rewind();
        if (data.remaining() < LONG_HEADER_DCID_OFFSET + dcidLength + 1) {
            return;
        }
        byte[] dcid = new byte[dcidLength];
        data.position(LONG_HEADER_DCID_OFFSET);
        data.get(dcid);
        int scidLength = data.get() & 0xFF;
        if (data.remaining() < scidLength) {
            return;
        }
        byte[] scid = new byte[scidLength];
        data.get(scid);

        VersionNegotiationPacket versionNegotiationPacket = new VersionNegotiationPacket(supportedVersions, dcid, scid);
        byte[] packetBytes = versionNegotiationPacket.generatePacketBytes(null, null);
        DatagramPacket datagram = new DatagramPacket(packetBytes, packetBytes.length, clientAddress.getAddress(), clientAddress.getPort());
        try {
            serverSocket.send(datagram);
            log.sent(Instant.now(), versionNegotiationPacket);
        }
        catch (IOException e) {
            log.error("Sending version negotiation packet failed", e);
        }
    }

    @Override
    public void registerConnection(ServerConnectionProxy connection, byte[] connectionId) {
        currentConnections.put(new ConnectionSource(connectionId), connection);
    }

    /**
     * Removes {@code connection} for {@code connectionId} only if it is the one registered; logs an error when a
     * different connection is registered under that id.
     */
    @Override
    public void deregisterConnection(ServerConnectionProxy connection, byte[] connectionId) {
        ConnectionSource key = new ConnectionSource(connectionId);
        if (!currentConnections.remove(key, connection)) {
            ServerConnectionProxy registered = currentConnections.get(key);
            if (registered != null) {
                log.error("Connection " + connection + " not removed, because " + registered + " is registered for " + ByteUtils.bytesToHex(connectionId));
            }
        }
    }
}

