/*
 * Decompiled with CFR 0.152.
 */
package net.luminis.quic.recovery;

import java.time.Duration;
import java.time.Instant;
import java.time.LocalTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Delayed;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import net.luminis.quic.EncryptionLevel;
import net.luminis.quic.FrameProcessor2;
import net.luminis.quic.FrameProcessorRegistry;
import net.luminis.quic.HandshakeState;
import net.luminis.quic.HandshakeStateListener;
import net.luminis.quic.PnSpace;
import net.luminis.quic.Role;
import net.luminis.quic.cc.CongestionController;
import net.luminis.quic.concurrent.DaemonThreadFactory;
import net.luminis.quic.frame.AckFrame;
import net.luminis.quic.frame.Padding;
import net.luminis.quic.frame.PingFrame;
import net.luminis.quic.frame.QuicFrame;
import net.luminis.quic.log.Logger;
import net.luminis.quic.packet.QuicPacket;
import net.luminis.quic.recovery.LossDetector;
import net.luminis.quic.recovery.RttEstimator;
import net.luminis.quic.send.Sender;

/**
 * Loss recovery for a QUIC connection.
 *
 * <p>Keeps one {@link LossDetector} per packet number space and drives a single loss detection timer.
 * The timer is armed either for the earliest pending loss time (time-threshold loss detection) or for
 * the probe timeout (PTO). When the PTO fires, one or two ack-eliciting probe packets are sent.
 *
 * <p>This class is registered as the processor for received ACK frames and listens to handshake state
 * changes, because the PTO computation depends on whether handshake keys are available and whether the
 * handshake has been confirmed.
 *
 * <p>Timer tasks run on a one-thread scheduler created in the constructor; {@link #stopRecovery()}
 * shuts it down.
 */
public class RecoveryManager
implements FrameProcessor2<AckFrame>,
HandshakeStateListener {
    /** Pattern used by {@link #timeNow()}; DateTimeFormatter is immutable and thread-safe, so one instance is shared. */
    private static final DateTimeFormatter TIME_FORMATTER = DateTimeFormatter.ofPattern("mm:ss.SSS");

    private final Role role;
    private final RttEstimator rttEstimater;
    /** One loss detector per packet number space, indexed by {@code PnSpace.ordinal()}. */
    private final LossDetector[] lossDetectors = new LossDetector[PnSpace.values().length];
    private final Sender sender;
    private final Logger log;
    /** Single-thread scheduler that runs the loss detection timer task. */
    private final ScheduledExecutorService scheduler;
    /**
     * Peer's max ack delay, added (as milliseconds) to the PTO of the App space.
     * volatile: written via {@link #setReceiverMaxAckDelay(int)}, read on the timer thread.
     */
    private volatile int receiverMaxAckDelay;
    private volatile ScheduledFuture<?> lossDetectionTimer;
    /** Number of consecutive PTO expirations; used as exponential backoff exponent. */
    private volatile int ptoCount;
    /** Wall-clock expiration of the currently scheduled timer, or null when the timer is cancelled. */
    private volatile Instant timerExpiration;
    private volatile HandshakeState handshakeState = HandshakeState.Initial;
    /** Set once by {@link #stopRecovery()}; after that, all events are ignored. */
    private volatile boolean hasBeenReset = false;

    /**
     * Creates a recovery manager and registers it as ACK frame processor.
     *
     * @param processorRegistry   registry this instance registers itself with
     * @param role                client or server; influences the address validation check
     * @param rttEstimater        source of smoothed RTT and RTT variance
     * @param congestionController passed to every per-space loss detector
     * @param sender              used to send probes and to flush after loss detection
     * @param logger              log target
     */
    public RecoveryManager(FrameProcessorRegistry processorRegistry, Role role, RttEstimator rttEstimater, CongestionController congestionController, Sender sender, Logger logger) {
        this.role = role;
        this.rttEstimater = rttEstimater;
        for (PnSpace pnSpace : PnSpace.values()) {
            this.lossDetectors[pnSpace.ordinal()] = new LossDetector(this, rttEstimater, congestionController);
        }
        this.sender = sender;
        this.log = logger;
        processorRegistry.registerProcessor(this);
        this.scheduler = Executors.newScheduledThreadPool(1, new DaemonThreadFactory("loss-detection"));
        this.lossDetectionTimer = new NullScheduledFuture();
    }

    /**
     * (Re)arms or cancels the loss detection timer based on the current recovery state.
     *
     * <p>The decision goes in this order:
     * <ul>
     *   <li>If any space has a pending loss time, the timer is set for the earliest one.</li>
     *   <li>Otherwise, if ack-eliciting packets are in flight or the peer awaits address
     *       validation, the timer is set for the PTO.</li>
     *   <li>Otherwise the timer is cancelled.</li>
     * </ul>
     */
    void setLossDetectionTimer() {
        PnSpaceTime earliestLossTime = this.getEarliestLossTime(LossDetector::getLossTime);
        if (earliestLossTime != null) {
            Instant lossTime = earliestLossTime.lossTime;
            int timeout = (int)Duration.between(Instant.now(), lossTime).toMillis();
            // reschedule() cancels the current timer itself; cancelling here as well made the
            // second cancel always fail and log a spurious "Cancelling ... failed" message.
            this.lossDetectionTimer = this.reschedule(() -> this.lossDetectionTimeout(), timeout);
            return;
        }
        boolean ackElicitingInFlight = this.ackElicitingInFlight();
        boolean peerAwaitingAddressValidation = this.peerAwaitingAddressValidation();
        if (!ackElicitingInFlight && !peerAwaitingAddressValidation) {
            this.log.recovery("cancelling loss detection timer (no loss time set, no ack eliciting in flight, peer not awaiting address validation)");
            this.unschedule();
            return;
        }
        PnSpaceTime ptoTimeAndSpace = this.getPtoTimeAndSpace();
        if (ptoTimeAndSpace.lossTime.equals(Instant.MAX)) {
            // No space qualifies for a PTO (e.g. only App data in flight before handshake confirmation).
            this.log.recovery("cancelling loss detection timer (no loss time set, no ack eliciting in flight for I/H, peer not awaiting address validation)");
            this.unschedule();
            return;
        }
        int timeout = (int)Duration.between(Instant.now(), ptoTimeAndSpace.lossTime).toMillis();
        if (timeout < 1) {
            timeout = 0;
        }
        this.log.recovery("reschedule loss detection timer for PTO over " + timeout + " millis, based on %s/" + ptoTimeAndSpace.pnSpace + ", because " + (peerAwaitingAddressValidation ? "peerAwaitingAddressValidation " : "") + (ackElicitingInFlight ? "ackElicitingInFlight " : "") + "| RTT:" + this.rttEstimater.getSmoothedRtt() + "/" + this.rttEstimater.getRttVar(), ptoTimeAndSpace.lossTime);
        this.lossDetectionTimer = this.reschedule(() -> this.lossDetectionTimeout(), timeout);
    }

    /**
     * Computes the PTO deadline and the packet number space it applies to.
     *
     * <p>The base PTO is {@code smoothedRtt + max(1, 4 * rttVar)}, multiplied by {@code 2^ptoCount}.
     * <ul>
     *   <li>If nothing ack-eliciting is in flight, the deadline is now + PTO. The space is Initial
     *       when no handshake keys are available, otherwise Handshake.</li>
     *   <li>Otherwise the deadline is the earliest {@code lastAckElicitingSent + PTO} over all
     *       spaces with ack-eliciting packets in flight.</li>
     *   <li>App space is skipped until the handshake is confirmed. When it is used, its PTO also
     *       includes the (backed-off) receiver max ack delay.</li>
     * </ul>
     *
     * @return the PTO time and space; the space is null and the time is {@link Instant#MAX} when no
     *         space qualifies
     */
    private PnSpaceTime getPtoTimeAndSpace() {
        int backoff = (int)Math.pow(2.0, this.ptoCount);
        int ptoDuration = (this.rttEstimater.getSmoothedRtt() + Integer.max(1, 4 * this.rttEstimater.getRttVar())) * backoff;
        if (!this.ackElicitingInFlight()) {
            if (this.handshakeState.hasNoHandshakeKeys()) {
                this.log.info("getPtoTimeAndSpace: no ack eliciting in flight and no handshake keys -> I");
                return new PnSpaceTime(PnSpace.Initial, Instant.now().plusMillis(ptoDuration));
            }
            this.log.info("getPtoTimeAndSpace: no ack eliciting in flight and but handshake keys -> H");
            return new PnSpaceTime(PnSpace.Handshake, Instant.now().plusMillis(ptoDuration));
        }
        Instant ptoTime = Instant.MAX;
        PnSpace ptoSpace = null;
        for (PnSpace pnSpace : PnSpace.values()) {
            LossDetector detector = this.lossDetectors[pnSpace.ordinal()];
            if (!detector.ackElicitingInFlight()) {
                continue;
            }
            if (pnSpace == PnSpace.App && this.handshakeState.isNotConfirmed()) {
                this.log.recovery("getPtoTimeAndSpace is skipping level App, because handshake not yet confirmed!");
                continue;
            }
            // Per-space duration: the max ack delay applies to App only and must not leak into other spaces.
            int spaceDuration = pnSpace == PnSpace.App ? ptoDuration + this.receiverMaxAckDelay * backoff : ptoDuration;
            Instant candidate = detector.getLastAckElicitingSent().plusMillis(spaceDuration);
            if (candidate.isBefore(ptoTime)) {
                ptoTime = candidate;
                ptoSpace = pnSpace;
            }
        }
        return new PnSpaceTime(ptoSpace, ptoTime);
    }

    /**
     * Returns true when this endpoint is a client whose handshake is not confirmed and that has not yet
     * received any acknowledgement in the Handshake space.
     */
    private boolean peerAwaitingAddressValidation() {
        return this.role == Role.Client && this.handshakeState.isNotConfirmed() && this.lossDetectors[PnSpace.Handshake.ordinal()].noAckedReceived();
    }

    /**
     * Timer task body. It does one of two things:
     * <ul>
     *   <li>If a loss time is pending, runs loss detection for the earliest space, flushes the
     *       sender and re-arms the timer.</li>
     *   <li>Otherwise sends a probe.</li>
     * </ul>
     *
     * <p>If the task runs before the recorded expiration, it waits for the remainder. It returns
     * early when the timer was cancelled or rescheduled in the meantime.
     */
    private void lossDetectionTimeout() {
        Instant expiration = this.timerExpiration;
        if (expiration == null) {
            this.log.warn("Loss detection timeout: Timer was cancelled.");
            return;
        }
        if (Instant.now().isBefore(expiration)) {
            this.log.warn("Scheduled task running early: " + Duration.between(Instant.now(), expiration) + "(" + expiration + ")");
            long remainingWaitTime = Duration.between(Instant.now(), expiration).toMillis() + 1L;
            if (remainingWaitTime > 0L) {
                try {
                    Thread.sleep(remainingWaitTime);
                }
                catch (InterruptedException interrupted) {
                    // The interrupt comes from unschedule() -> cancel(true) on this task's own future, so this
                    // task is cancelled; a later reschedule() runs as a separate scheduled task.
                    Thread.currentThread().interrupt();
                    this.log.warn("Delayed task interrupted, cancelled");
                    return;
                }
            }
            expiration = this.timerExpiration;
            if (expiration == null) {
                this.log.warn("Delayed task: timer expiration is now null, cancelled");
                return;
            }
            if (Instant.now().isBefore(expiration)) {
                // Use the local copy: the volatile field may have been nulled concurrently.
                this.log.warn("Delayed task is now still before timer expiration, probably rescheduled in the meantime; " + Duration.between(Instant.now(), expiration) + "(" + expiration + ")");
                return;
            }
            this.log.warn("Delayed task running now");
        } else {
            this.log.recovery("%s loss detection timeout handler running", Instant.now());
        }
        PnSpaceTime earliestLossTime = this.getEarliestLossTime(LossDetector::getLossTime);
        if (earliestLossTime != null) {
            this.lossDetectors[earliestLossTime.pnSpace.ordinal()].detectLostPackets();
            this.sender.flush();
            this.setLossDetectionTimer();
        } else {
            this.sendProbe();
        }
    }

    /**
     * Handles a PTO expiration: increments the PTO count and sends probe packets.
     *
     * <p>The first PTO sends one probe, later ones send two. If nothing ack-eliciting is in flight
     * (the peer is awaiting address validation), a single probe is sent. It goes in Initial space
     * when no handshake keys are available, otherwise in Handshake space.
     */
    private void sendProbe() {
        PnSpaceTime earliestLastAckElicitingSentTime = this.getEarliestLossTime(LossDetector::getLastAckElicitingSent);
        if (earliestLastAckElicitingSentTime != null) {
            this.log.recovery(String.format("Sending probe %d, because no ack since %%s. Current RTT: %d/%d.", this.ptoCount, this.rttEstimater.getSmoothedRtt(), this.rttEstimater.getRttVar()), earliestLastAckElicitingSentTime.lossTime);
        } else {
            this.log.recovery(String.format("Sending probe %d. Current RTT: %d/%d.", this.ptoCount, this.rttEstimater.getSmoothedRtt(), this.rttEstimater.getRttVar()));
        }
        ++this.ptoCount;
        int nrOfProbes = this.ptoCount > 1 ? 2 : 1;
        if (this.ackElicitingInFlight()) {
            PnSpaceTime ptoTimeAndSpace = this.getPtoTimeAndSpace();
            if (ptoTimeAndSpace.pnSpace == null) {
                // No space qualifies for a PTO (only App in flight before handshake confirmation, or acked
                // meanwhile). Previously this null reached pnSpace.relatedEncryptionLevel() and threw an NPE.
                this.log.recovery("Refraining from sending probe: no packet number space qualifies for PTO");
                return;
            }
            this.sendOneOrTwoAckElicitingPackets(ptoTimeAndSpace.pnSpace, nrOfProbes);
        } else {
            this.log.recovery("Sending probe because peer awaiting address validation");
            if (this.handshakeState.hasNoHandshakeKeys()) {
                this.sendOneOrTwoAckElicitingPackets(PnSpace.Initial, 1);
            } else {
                this.sendOneOrTwoAckElicitingPackets(PnSpace.Handshake, 1);
            }
        }
    }

    /**
     * Sends {@code numberOfPackets} probe packets in the given space.
     *
     * <p>Each probe retransmits the frames of the first unacked ack-eliciting packet of that space. If
     * there is none, it is a PING plus padding.
     *
     * @param pnSpace         space to probe; must not be null
     * @param numberOfPackets number of probe packets to send
     */
    private void sendOneOrTwoAckElicitingPackets(PnSpace pnSpace, int numberOfPackets) {
        if (pnSpace == PnSpace.Initial) {
            List<QuicFrame> framesToRetransmit = this.getFramesToRetransmit(PnSpace.Initial);
            if (!framesToRetransmit.isEmpty()) {
                this.log.recovery("(Probe is an initial retransmit)");
                this.repeatSend(numberOfPackets, () -> this.sender.sendProbe(framesToRetransmit, EncryptionLevel.Initial));
            } else {
                this.log.recovery("(Probe is Initial ping, because there is no Initial data to retransmit)");
                this.repeatSend(numberOfPackets, () -> this.sender.sendProbe(List.of(new PingFrame(), new Padding(2)), EncryptionLevel.Initial));
            }
        } else if (pnSpace == PnSpace.Handshake) {
            List<QuicFrame> framesToRetransmit = this.getFramesToRetransmit(PnSpace.Handshake);
            if (!framesToRetransmit.isEmpty()) {
                this.log.recovery("(Probe is a handshake retransmit)");
                this.repeatSend(numberOfPackets, () -> this.sender.sendProbe(framesToRetransmit, EncryptionLevel.Handshake));
            } else {
                this.log.recovery("(Probe is a handshake ping)");
                this.repeatSend(numberOfPackets, () -> this.sender.sendProbe(List.of(new PingFrame(), new Padding(2)), EncryptionLevel.Handshake));
            }
        } else {
            EncryptionLevel probeLevel = pnSpace.relatedEncryptionLevel();
            List<QuicFrame> framesToRetransmit = this.getFramesToRetransmit(pnSpace);
            if (!framesToRetransmit.isEmpty()) {
                this.log.recovery("(Probe is retransmit on level " + probeLevel + ")");
                this.repeatSend(numberOfPackets, () -> this.sender.sendProbe(framesToRetransmit, probeLevel));
            } else {
                this.log.recovery("(Probe is ping on level " + probeLevel + ")");
                this.repeatSend(numberOfPackets, () -> this.sender.sendProbe(List.of(new PingFrame(), new Padding(2)), probeLevel));
            }
        }
    }

    /**
     * Returns the frames to retransmit in a probe for the given space.
     *
     * <p>Frames are taken from the first unacked ack-eliciting packet that contains something other
     * than PING, padding and ACK frames. ACK frames are left out.
     *
     * @return the frames to resend, or an empty list if no such packet exists
     */
    List<QuicFrame> getFramesToRetransmit(PnSpace pnSpace) {
        List<QuicPacket> unAckedPackets = this.lossDetectors[pnSpace.ordinal()].unAcked();
        Optional<QuicPacket> ackEliciting = unAckedPackets.stream()
                .filter(QuicPacket::isAckEliciting)
                .filter(p -> !p.getFrames().stream().allMatch(frame -> frame instanceof PingFrame || frame instanceof Padding || frame instanceof AckFrame))
                .findFirst();
        return ackEliciting
                .map(packet -> packet.getFrames().stream().filter(frame -> !(frame instanceof AckFrame)).collect(Collectors.toList()))
                .orElse(Collections.emptyList());
    }

    /**
     * Finds the earliest non-null instant that {@code pnSpaceTimeFunction} yields over all loss
     * detectors. On ties, the later space in {@code PnSpace.values()} order wins.
     *
     * @return the earliest time with its space, or null when the function yields null for every space
     */
    PnSpaceTime getEarliestLossTime(Function<LossDetector, Instant> pnSpaceTimeFunction) {
        PnSpaceTime earliestLossTime = null;
        for (PnSpace pnSpace : PnSpace.values()) {
            Instant pnSpaceLossTime = pnSpaceTimeFunction.apply(this.lossDetectors[pnSpace.ordinal()]);
            if (pnSpaceLossTime == null) {
                continue;
            }
            if (earliestLossTime == null || !earliestLossTime.lossTime.isBefore(pnSpaceLossTime)) {
                earliestLossTime = new PnSpaceTime(pnSpace, pnSpaceLossTime);
            }
        }
        return earliestLossTime;
    }

    /**
     * Cancels the current timer without interrupting it and records the new expiration.
     * It then schedules {@code runnable} to run after {@code timeout} milliseconds.
     * Exceptions thrown by the runnable are logged.
     *
     * @return the new scheduled future; the caller stores it as the current timer
     */
    ScheduledFuture<?> reschedule(Runnable runnable, int timeout) {
        if (!this.lossDetectionTimer.cancel(false)) {
            this.log.debug("Cancelling loss detection timer failed");
        }
        this.timerExpiration = Instant.now().plusMillis(timeout);
        return this.scheduler.schedule(() -> {
            try {
                runnable.run();
            }
            catch (Exception error) {
                this.log.error("Runtime exception occurred while processing scheduled task", error);
            }
        }, (long)timeout, TimeUnit.MILLISECONDS);
    }

    /** Cancels the timer (interrupting it if running) and clears the recorded expiration. */
    void unschedule() {
        this.lossDetectionTimer.cancel(true);
        this.timerExpiration = null;
    }

    /**
     * Processes a received ACK frame for the given space.
     *
     * <p>The PTO count is reset, unless the peer is still awaiting address validation. Ignored after
     * {@link #stopRecovery()}.
     */
    public void onAckReceived(AckFrame ackFrame, PnSpace pnSpace, Instant timeReceived) {
        if (this.hasBeenReset) {
            return;
        }
        if (this.ptoCount > 0) {
            if (!this.peerAwaitingAddressValidation()) {
                this.ptoCount = 0;
            } else {
                this.log.recovery("probe count not reset on ack because handshake not yet confirmed");
            }
        }
        this.lossDetectors[pnSpace.ordinal()].onAckReceived(ackFrame, timeReceived);
    }

    /**
     * Registers a sent in-flight packet with the loss detector of its space and re-arms the timer.
     * Packets that are not in flight are ignored, and so is everything after {@link #stopRecovery()}.
     */
    public void packetSent(QuicPacket packet, Instant sent, Consumer<QuicPacket> packetLostCallback) {
        if (!this.hasBeenReset && packet.isInflightPacket()) {
            this.lossDetectors[packet.getPnSpace().ordinal()].packetSent(packet, sent, packetLostCallback);
            this.setLossDetectionTimer();
        }
    }

    /** Returns true if any packet number space has ack-eliciting packets in flight. */
    private boolean ackElicitingInFlight() {
        return Stream.of(this.lossDetectors).anyMatch(LossDetector::ackElicitingInFlight);
    }

    /** Sets the peer's max ack delay that is added to the App-space PTO. */
    public synchronized void setReceiverMaxAckDelay(int receiverMaxAckDelay) {
        this.receiverMaxAckDelay = receiverMaxAckDelay;
    }

    /**
     * Permanently stops recovery.
     *
     * <p>Cancels the timer, shuts down the scheduler and resets all loss detectors. Only the first
     * call has effect.
     */
    public void stopRecovery() {
        if (!this.hasBeenReset) {
            this.hasBeenReset = true;
            this.unschedule();
            this.scheduler.shutdown();
            for (PnSpace pnSpace : PnSpace.values()) {
                this.lossDetectors[pnSpace.ordinal()].reset();
            }
        }
    }

    /** Stops recovery for one space: resets its loss detector and the PTO count, then re-arms the timer. */
    public void stopRecovery(PnSpace pnSpace) {
        if (!this.hasBeenReset) {
            this.lossDetectors[pnSpace.ordinal()].reset();
            this.ptoCount = 0;
            this.setLossDetectionTimer();
        }
    }

    /** Returns the total lost-packet count summed over all loss detectors. */
    public long getLost() {
        return Stream.of(this.lossDetectors).mapToLong(LossDetector::getLost).sum();
    }

    /** Records the new handshake state; on the transition to Confirmed the timer is re-evaluated. */
    @Override
    public void handshakeStateChangedEvent(HandshakeState newState) {
        if (!this.hasBeenReset) {
            HandshakeState oldState = this.handshakeState;
            this.handshakeState = newState;
            if (newState == HandshakeState.Confirmed && oldState != HandshakeState.Confirmed) {
                this.log.recovery("State is set to " + newState);
                this.setLossDetectionTimer();
            }
        }
    }

    /** Frame processor entry point; delegates to {@link #onAckReceived(AckFrame, PnSpace, Instant)}. */
    @Override
    public void process(AckFrame frame, PnSpace pnSpace, Instant timeReceived) {
        this.onAckReceived(frame, pnSpace, timeReceived);
    }

    /**
     * Runs {@code task} {@code count} times, pausing about 1 ms after each run.
     * If interrupted, the interrupt status is restored. The remaining runs still execute; only their
     * pauses are skipped.
     */
    private void repeatSend(int count, Runnable task) {
        for (int i = 0; i < count; ++i) {
            task.run();
            try {
                Thread.sleep(1L);
            }
            catch (InterruptedException interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /** Returns the current local time formatted as {@code mm:ss.SSS}. */
    String timeNow() {
        LocalTime localTimeNow = LocalTime.from(Instant.now().atZone(ZoneId.systemDefault()));
        return TIME_FORMATTER.format(localTimeNow);
    }

    /** A point in time paired with the packet number space it belongs to. */
    static class PnSpaceTime {
        public PnSpace pnSpace;
        public Instant lossTime;

        public PnSpaceTime(PnSpace pnSpace, Instant pnSpaceLossTime) {
            this.pnSpace = pnSpace;
            this.lossTime = pnSpaceLossTime;
        }

        @Override
        public String toString() {
            return this.lossTime.toString() + " (in " + this.pnSpace + ")";
        }
    }

    /** Placeholder timer used before the first real schedule; cancel() always reports failure. */
    private static class NullScheduledFuture
    implements ScheduledFuture<Void> {
        private NullScheduledFuture() {
        }

        @Override
        public int compareTo(Delayed o) {
            return 0;
        }

        @Override
        public long getDelay(TimeUnit unit) {
            return 0L;
        }

        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            return false;
        }

        @Override
        public boolean isCancelled() {
            return false;
        }

        @Override
        public boolean isDone() {
            return false;
        }

        @Override
        public Void get() throws InterruptedException, ExecutionException {
            return null;
        }

        @Override
        public Void get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
            return null;
        }
    }
}

