/*
 * Decompiled with CFR 0.152.
 */
package net.luminis.quic.send;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import net.luminis.quic.AckGenerator;
import net.luminis.quic.EncryptionLevel;
import net.luminis.quic.Version;
import net.luminis.quic.frame.AckFrame;
import net.luminis.quic.frame.PingFrame;
import net.luminis.quic.frame.QuicFrame;
import net.luminis.quic.packet.HandshakePacket;
import net.luminis.quic.packet.QuicPacket;
import net.luminis.quic.packet.ShortHeaderPacket;
import net.luminis.quic.packet.ZeroRttPacket;
import net.luminis.quic.send.PacketNumberGenerator;
import net.luminis.quic.send.SendItem;
import net.luminis.quic.send.SendRequest;
import net.luminis.quic.send.SendRequestQueue;

/**
 * Assembles QUIC packets for a single encryption level from the requests queued in a {@link SendRequestQueue}.
 *
 * <p>Each call to {@link #assemble} builds at most one packet. It considers, in this order:
 * <ol>
 *   <li>an ACK the queue says must be sent;</li>
 *   <li>probe data;</li>
 *   <li>queued send requests, piggy-backing an optional ACK when there is room;</li>
 *   <li>if nothing else was assembled, a PING probe.</li>
 * </ol>
 */
public class PacketAssembler {
    /** Marker callback for frames that need no action when the packet carrying them is lost (ACK and PING frames). */
    protected static final Consumer<QuicFrame> EMPTY_CALLBACK = f -> {};
    protected final Version quicVersion;
    protected final EncryptionLevel level;
    protected final SendRequestQueue requestQueue;
    protected final AckGenerator ackGenerator;
    private final PacketNumberGenerator packetNumberGenerator;
    // NOTE(review): not read or written in this class (packet numbers come from packetNumberGenerator);
    // kept because it is protected and subclasses may reference it.
    protected long nextPacketNumber;

    /**
     * Creates an assembler that uses its own, fresh {@link PacketNumberGenerator}.
     */
    public PacketAssembler(Version version, EncryptionLevel level, SendRequestQueue requestQueue, AckGenerator ackGenerator) {
        this(version, level, requestQueue, ackGenerator, new PacketNumberGenerator());
    }

    /**
     * Creates an assembler that takes packet numbers from the given generator.
     *
     * @param version      QUIC version put into the created packets
     * @param level        encryption level that determines the packet type created
     * @param requestQueue queue from which frames, probes and ACK requests are taken
     * @param ackGenerator source of ACK frames
     * @param pnGenerator  generator that hands out (and can take back) packet numbers
     */
    public PacketAssembler(Version version, EncryptionLevel level, SendRequestQueue requestQueue, AckGenerator ackGenerator, PacketNumberGenerator pnGenerator) {
        this.quicVersion = version;
        this.level = level;
        this.requestQueue = requestQueue;
        this.ackGenerator = ackGenerator;
        this.packetNumberGenerator = pnGenerator;
    }

    /**
     * Assembles the next packet for this encryption level, if there is anything to send.
     *
     * <p>Mandatory ACKs and probes are only checked against {@code availablePacketSize}. Regular send requests are
     * limited by both {@code remainingCwndSize} and {@code availablePacketSize}.
     *
     * @param remainingCwndSize       presumably the number of bytes the congestion window still permits
     * @param availablePacketSize     maximum size of the packet to assemble
     * @param sourceConnectionId      source connection id (not used for App level, i.e. short header, packets)
     * @param destinationConnectionId destination connection id
     * @return the assembled packet with its lost-callback, or empty when there is nothing (that fits) to send
     */
    Optional<SendItem> assemble(int remainingCwndSize, int availablePacketSize, byte[] sourceConnectionId, byte[] destinationConnectionId) {
        // NOTE(review): 3 bytes are held back from the packet size limit; the reason is not visible here - confirm.
        final int available = Integer.min(remainingCwndSize, availablePacketSize - 3);
        QuicPacket packet = null;
        List<Consumer<QuicFrame>> callbacks = new ArrayList<>();
        AckFrame ackFrame = null;

        if (requestQueue.mustAndWillSendAck() && ackGenerator.hasNewAckToSend()) {
            packet = createPacketIfAbsent(packet, sourceConnectionId, destinationConnectionId);
            ackFrame = ackGenerator.generateAck().get();
            if (packet.estimateLength(ackFrame.getBytes().length) > availablePacketSize) {
                // Not even the mandatory ACK fits: re-queue the ACK request and give back the packet number that the
                // (never to be sent) packet consumed, so no gap is created in the packet number sequence.
                requestQueue.addAckRequest();
                restorePacketNumber();
                return Optional.empty();
            }
            packet.addFrame(ackFrame);
            callbacks.add(EMPTY_CALLBACK);
            ackGenerator.registerAckSendWithPacket(ackFrame, packet.getPacketNumber());
        }

        // An ACK that is not mandatory is only added when it fits next to the queued frames; reserve room for it.
        int optionalAckSize = 0;
        if (ackFrame == null && requestQueue.hasRequests() && ackGenerator.hasAckToSend()) {
            packet = createPacketIfAbsent(packet, sourceConnectionId, destinationConnectionId);
            ackFrame = ackGenerator.generateAck().orElse(null);
            if (ackFrame != null) {
                optionalAckSize = ackFrame.getBytes().length;
            }
        }

        if (requestQueue.hasProbeWithData()) {
            List<QuicFrame> probeData = requestQueue.getProbe();
            packet = createPacketIfAbsent(packet, sourceConnectionId, destinationConnectionId);
            int probeDataSize = probeData.stream().mapToInt(f -> f.getBytes().length).sum();
            if (packet.estimateLength(probeDataSize) > availablePacketSize) {
                // Probe data does not fit: fall back to a PING probe.
                PingFrame ping = new PingFrame();
                if (packet.estimateLength(ping.getBytes().length) > availablePacketSize) {
                    if (packet.getFrames().isEmpty()) {
                        // The packet is discarded unused: give its packet number back.
                        restorePacketNumber();
                    }
                    // NOTE(review): if the packet already carries a registered mandatory ACK, that packet number
                    // is never sent; consider sending the ACK-only packet instead.
                    return Optional.empty();
                }
                probeData = List.of(ping);
            }
            packet.setIsProbe(true);
            packet.addFrames(probeData);
            return Optional.of(new SendItem(packet));
        }

        if (requestQueue.hasRequests()) {
            packet = createPacketIfAbsent(packet, sourceConnectionId, destinationConnectionId);
            // Length with a large hypothetical frame minus that frame's size gives an upper bound for the
            // header overhead (plus the frames already in the packet).
            int estimatedSize = packet.estimateLength(1000) - 1000;
            while (estimatedSize < available) {
                // First try to find a request that still leaves room for the optional ACK (if any)...
                int proposedSize = available - estimatedSize - optionalAckSize;
                Optional<SendRequest> next = requestQueue.next(proposedSize);
                if (next.isEmpty() && optionalAckSize > 0) {
                    // ...otherwise use all remaining space, without reserving room for the ACK.
                    proposedSize = available - estimatedSize;
                    next = requestQueue.next(proposedSize);
                }
                if (next.isEmpty()) {
                    break;  // Nothing fits in the remaining space.
                }
                SendRequest request = next.get();
                QuicFrame nextFrame = request.getFrameSupplier().apply(proposedSize);
                if (nextFrame == null) {
                    continue;
                }
                int frameSize = nextFrame.getBytes().length;
                if (frameSize > proposedSize) {
                    throw new RuntimeException("supplier does not produce frame of right (max) size: " + frameSize + " > " + proposedSize + " frame: " + nextFrame);
                }
                packet.addFrame(nextFrame);
                callbacks.add(request.getLostCallback());
                // Always account for the added frame. Otherwise later requests are offered space that is
                // already used, and the packet overflows.
                estimatedSize += frameSize;
                if (optionalAckSize > 0 && estimatedSize + optionalAckSize <= available) {
                    packet.addFrame(ackFrame);
                    callbacks.add(EMPTY_CALLBACK);
                    ackGenerator.registerAckSendWithPacket(ackFrame, packet.getPacketNumber());
                    estimatedSize += ackFrame.getBytes().length;
                    optionalAckSize = 0;
                }
            }
            if (packet.getFrames().isEmpty()) {
                packet = null;
                restorePacketNumber();
            }
        }

        if (requestQueue.hasProbe() && packet == null) {
            packet = createPacket(sourceConnectionId, destinationConnectionId, null);
            requestQueue.getProbe();
            packet.setIsProbe(true);
            packet.addFrame(new PingFrame());
            callbacks.add(EMPTY_CALLBACK);
        }

        if (packet == null) {
            return Optional.empty();
        }
        return Optional.of(new SendItem(packet, createPacketLostCallback(packet, callbacks)));
    }

    /**
     * Returns {@code packet} when it is non-null; otherwise creates a new frameless packet, which consumes a
     * packet number.
     */
    private QuicPacket createPacketIfAbsent(QuicPacket packet, byte[] sourceConnectionId, byte[] destinationConnectionId) {
        return packet != null ? packet : createPacket(sourceConnectionId, destinationConnectionId, null);
    }

    /** Takes the next packet number from the packet number generator. */
    protected long nextPacketNumber() {
        return this.packetNumberGenerator.nextPacketNumber();
    }

    /** Gives the most recently taken packet number back to the packet number generator. */
    protected void restorePacketNumber() {
        this.packetNumberGenerator.restorePacketNumber();
    }

    /**
     * Builds the callback to run when {@code packet} is declared lost. The callback passes each lost frame to the
     * callback registered at the same index. Entries equal to {@link #EMPTY_CALLBACK} are skipped.
     *
     * @throws IllegalStateException when the number of callbacks does not match the number of frames in the packet
     */
    private Consumer<QuicPacket> createPacketLostCallback(QuicPacket packet, List<Consumer<QuicFrame>> callbacks) {
        if (packet.getFrames().size() != callbacks.size()) {
            throw new IllegalStateException("number of frames (" + packet.getFrames().size()
                    + ") does not match number of lost-callbacks (" + callbacks.size() + ")");
        }
        return lostPacket -> {
            for (int i = 0; i < callbacks.size(); i++) {
                Consumer<QuicFrame> callback = callbacks.get(i);
                if (callback != EMPTY_CALLBACK) {
                    callback.accept(lostPacket.getFrames().get(i));
                }
            }
        };
    }

    /**
     * Creates a packet of the type that belongs to this assembler's encryption level and assigns it the next
     * packet number.
     *
     * @param frame initial frame for the packet (may be {@code null}, as every caller in this class passes)
     * @throws RuntimeException when the encryption level is not Handshake, App or ZeroRTT
     */
    protected QuicPacket createPacket(byte[] sourceConnectionId, byte[] destinationConnectionId, QuicFrame frame) {
        QuicPacket packet;
        switch (this.level) {
            case Handshake:
                packet = new HandshakePacket(this.quicVersion, sourceConnectionId, destinationConnectionId, frame);
                break;
            case App:
                packet = new ShortHeaderPacket(this.quicVersion, destinationConnectionId, frame);
                break;
            case ZeroRTT:
                packet = new ZeroRttPacket(this.quicVersion, sourceConnectionId, destinationConnectionId, frame);
                break;
            default:
                throw new RuntimeException("cannot create packet for encryption level " + this.level);
        }
        packet.setPacketNumber(this.nextPacketNumber());
        return packet;
    }
}

