package io.dyte.callstats.tests

import io.dyte.callstats.models.ParsedIceCandidate
import io.dyte.callstats.models.TestResult
import io.dyte.callstats.utils.DependencyProvider
import io.dyte.webrtc.*
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onEach

/** ICE candidate filter that accepts every candidate unconditionally. */
val noFilter: (IceCandidate) -> Boolean = { _ -> true }

/** ICE candidate filter that accepts only candidates whose parsed type is `host`. */
val isHost: (IceCandidate) -> Boolean = { iceCandidate ->
  parseIceCandidate(iceCandidate).type == "host"
}

/** ICE candidate filter that accepts every candidate whose parsed type is not `host`. */
val isNotHostCandidate: (IceCandidate) -> Boolean = { iceCandidate ->
  parseIceCandidate(iceCandidate).type != "host"
}

/** ICE candidate filter that accepts only candidates whose parsed type is `relay`. */
val isRelay: (IceCandidate) -> Boolean = { iceCandidate ->
  parseIceCandidate(iceCandidate).type == "relay"
}

/** ICE candidate filter that accepts only candidates whose parsed type is `srflx`. */
val isReflexive: (IceCandidate) -> Boolean = { candidate ->
  val candidateType = parseIceCandidate(candidate).type
  candidateType == "srflx"
}

/** Separator between fields of an ICE candidate attribute; tolerates repeated whitespace. */
private val ICE_CANDIDATE_FIELD_SEPARATOR = Regex("\\s+")

/**
 * Extracts the candidate type, transport protocol and connection address from [iceCandidate].
 *
 * The attribute value is expected to look like
 * `candidate:<foundation> <component> <protocol> <priority> <address> <port> typ <type> ...`.
 * Everything up to and including the first `candidate:` is skipped; when that prefix is absent
 * the whole string is parsed as the attribute value.
 *
 * @return the parsed `type` (field 7), `protocol` (field 2) and `address` (field 4).
 * @throws IllegalArgumentException if the candidate has fewer than 8 whitespace-separated fields.
 */
fun parseIceCandidate(iceCandidate: IceCandidate): ParsedIceCandidate {
  val sdp = iceCandidate.candidate
  // substringAfter keeps the whole string when the prefix is missing; the previous
  // indexOf(...) + length arithmetic turned -1 into 9 and silently dropped 9 characters.
  val fields = sdp.substringAfter("candidate:").trim().split(ICE_CANDIDATE_FIELD_SEPARATOR)
  require(fields.size > 7) { "Malformed ICE candidate: \"$sdp\"" }

  return ParsedIceCandidate(type = fields[7], protocol = fields[2], address = fields[4])
}

/**
 * Base class for a connectivity test between two locally created peer connections.
 *
 * [start] creates [pc1] and [pc2] from the same configuration and opens a data channel on [pc1].
 * It then negotiates offer/answer directly between the two peers and forwards each gathered ICE
 * candidate to the other peer only when it passes `filter`. The default flow sends
 * "Namastey Duniyaa" over [ch1] and expects [ch2] to reply "Duniyaa", then reports success. An
 * unexpected message, or the timeout passed to [start] elapsing first, reports failure instead.
 *
 * At most one of `done` / `failed` is invoked per run, and both peer connections are closed
 * before that callback runs.
 *
 * @param config configuration used for both peer connections.
 * @param provider source of [logger].
 * @param filter decides which ICE candidates are forwarded to the other peer.
 * @param done invoked on success (with a [TestResult] in the default flow).
 * @param failed invoked with a human-readable reason on failure.
 * @param testType suffix that makes the data channel label unique per test type.
 * @param coroutineScope scope in which all event collectors and the timeout watchdog run.
 */
abstract class CallTest(
  private val config: RtcConfiguration,
  provider: DependencyProvider,
  private val filter: (IceCandidate) -> Boolean = noFilter,
  private val done: suspend (Any) -> Unit,
  private val failed: suspend (String) -> Unit,
  // We are using testType to create unique data channel labels for each one
  private val testType: String,
  private val coroutineScope: CoroutineScope,
) {
  lateinit var pc1: PeerConnection
  lateinit var pc2: PeerConnection

  lateinit var ch1: DataChannel
  lateinit var ch2: DataChannel

  val logger = provider.getLogger()

  private var iceCandidateFilter: (IceCandidate) -> Boolean = noFilter

  /** Watchdog that fails the test when the timeout elapses; null until it has been armed. */
  private var testTimerJob: Job? = null

  /** Set once a result has been reported, so `done` / `failed` can never fire twice. */
  private var isFinished = false

  private fun createDataChannels() {
    setIceCandidateFilter(filter)

    val channelLabel = "sendDataChannel_${testType}"
    ch1 =
      checkNotNull(pc1.createDataChannel(channelLabel)) {
        "PC1 could not create data channel $channelLabel"
      }
    ch1.handle(onOpenDC1, onMessageDC1)

    logger.logDebug("Data Channel 1 set!")
  }

  private val onIceCandidatePC1 = { data: IceCandidate ->
    logger.logDebug("Found ICE Candidate for PC1")
    logger.logDebug("Ice Candidate Filter: $iceCandidateFilter")

    if (iceCandidateFilter(data)) {
      pc2.addIceCandidate(data)
      logger.logDebug("Added ice candidate in PC2")
    }
  }

  private val onIceCandidatePC2 = { data: IceCandidate ->
    logger.logDebug("Found ICE Candidate for PC2")

    if (iceCandidateFilter(data)) {
      pc1.addIceCandidate(data)
      logger.logDebug("Added ice candidate in PC1")
    }
  }

  /** Default handler for messages arriving on [ch2]: echoes "Duniyaa" for the expected greeting. */
  open val onMessageDC2: suspend (ByteArray) -> Unit = { buffer ->
    val msg = buffer.decodeToString()
    logger.logDebug("Received message from Data Channel 1: $msg")

    if (msg == "Namastey Duniyaa") {
      val bb = "Duniyaa".encodeToByteArray()
      // The reply goes out over ch2, i.e. data channel 2.
      logger.logDebug("Sending Duniyaa from data channel 2")
      ch2.send(bb)
    } else {
      testFailed("Invalid data transmitted, expected: \"Namastey Duniyaa\", received: $msg")
    }
  }

  private val onDataChannelPC2 = { dataChannel: DataChannel ->
    ch2 = dataChannel
    ch2.handle(null, onMessageDC2)
    logger.logDebug("Data Channel 2 set!")
  }

  /**
   * Creates both peer connections and starts negotiation.
   *
   * @param timeout milliseconds after which the test is reported as failed; the clock starts
   *   before negotiation, so a stalled offer/answer exchange is also covered.
   */
  suspend fun start(timeout: Long) {
    pc1 = makePeerConnection(config, onIceCandidatePC1, null)
    logger.logDebug("Made Peer Connection 1")

    pc2 = makePeerConnection(config, onIceCandidatePC2, onDataChannelPC2)
    logger.logDebug("Made Peer Connection 2")

    establishConnection(timeout)
  }

  /** Default handler for [ch1] opening: sends the greeting that [onMessageDC2] expects. */
  open val onOpenDC1: () -> Unit = {
    logger.logDebug("Data Channel 1 Open: ${ch1.readyState}")
    val bb = "Namastey Duniyaa".encodeToByteArray()
    logger.logDebug("Sending Namastey Duniyaa from data channel 1")
    ch1.send(bb)
  }

  private val onMessageDC1: suspend (ByteArray) -> Unit = { buffer ->
    val msg = buffer.decodeToString()
    logger.logDebug("Received message from Data Channel 2: $msg")

    if (msg == "Duniyaa") {
      logger.logDebug("Data successfully transmitted between peers")
      testComplete(TestResult(connectivity = true))
    } else {
      testFailed("Invalid data transmitted, expected: Duniyaa, received: $msg")
    }
  }

  private val onCreateSuccessPC2: suspend (SessionDescription) -> Unit =
    { sessionsDesc: SessionDescription ->
      logger.logDebug("PC2 SDP on create succeeded!")
      pc1.setRemoteDescription(sessionsDesc)

      logger.logDebug("PC1 has set the remote description!")
      pc2.setLocalDescription(sessionsDesc)
      logger.logDebug("PC2 has set the local description!")
    }

  private val onCreateSuccessPC1: suspend (SessionDescription) -> Unit =
    { sessionDescription: SessionDescription ->
      // TODO: Add condition for constrainOfferToRemoveVideoFec

      logger.logDebug("PC1 SDP on create succeeded!")

      pc1.setLocalDescription(sessionDescription)
      logger.logDebug("PC1 has set the local description!")
      pc2.setRemoteDescription(sessionDescription)
      logger.logDebug("PC2 has set the remote description!")

      // Logged before the call so the log order matches what actually happens.
      logger.logDebug("PC2 Creating Answer")
      onCreateSuccessPC2.invoke(pc2.createAnswer(OfferAnswerOptions()))
    }

  private suspend fun establishConnection(timeout: Long = 10000) {
    createDataChannels()

    // Arm the watchdog before negotiating so a hung offer/answer exchange is bounded as well.
    // It runs in the injected scope (not GlobalScope), so it is cancelled together with it.
    testTimerJob =
      coroutineScope.launch {
        delay(timeout)
        // The watchdog must not cancel its own coroutine, or the suspending `failed` callback
        // would be aborted.
        testFailed("Time limit exceeded for the test!", cancelTimer = false)
      }

    try {
      logger.logDebug("PC1 Creating Offer")
      onCreateSuccessPC1.invoke(pc1.createOffer(OfferAnswerOptions()))
    } catch (e: Throwable) {
      // The error propagates to the caller of start(); don't additionally report a timeout
      // later. If the watchdog already reported, leave it running so its callback completes.
      if (!isFinished) testTimerJob?.cancel()
      throw e
    }
  }

  private fun makePeerConnection(
    config: RtcConfiguration,
    onIceCandidateCallback: (IceCandidate) -> Unit,
    onDataChannelCallback: ((DataChannel) -> Unit)?,
  ): PeerConnection {
    val pc = PeerConnection(config)
    pc.handle(onIceCandidateCallback, onDataChannelCallback)
    return pc
  }

  private fun setIceCandidateFilter(iceCandidateFilter: (IceCandidate) -> Boolean) {
    this.iceCandidateFilter = iceCandidateFilter
  }

  /**
   * Reports a failure: closes both peers and invokes `failed`, unless a result was already
   * reported.
   *
   * @param cancelTimer false only when called from the watchdog itself.
   */
  private suspend fun testFailed(error: String, cancelTimer: Boolean = true) {
    if (isFinished) {
      logger.logDebug("Ignoring failure reported after the test finished: $error")
      return
    }
    isFinished = true
    if (cancelTimer) testTimerJob?.cancel()

    logger.log("Test failed: $error")
    close()
    failed(error)
  }

  /**
   * Reports success: stops the watchdog, closes both peers and invokes `done` with [data].
   * Has no effect if a result was already reported.
   */
  suspend fun testComplete(data: Any) {
    if (isFinished) return
    isFinished = true

    testTimerJob?.cancel()
    close()
    done(data)
  }

  private fun close() {
    pc1.close()
    pc2.close()
  }

  /** Subscribes, in [coroutineScope], to every event flow of this peer connection. */
  private fun PeerConnection.handle(
    onIceCandidateCallback: (IceCandidate) -> Unit,
    onDataChannelCallback: ((DataChannel) -> Unit)?,
  ) {
    this.onIceCandidate.onEach { onIceCandidateCallback.invoke(it) }.launchIn(coroutineScope)

    this.onDataChannel.onEach { onDataChannelCallback?.invoke(it) }.launchIn(coroutineScope)

    this.onSignalingStateChange
      .onEach { signalingState -> logger.logDebug("On Signaling Change: $signalingState") }
      .launchIn(coroutineScope)

    this.onIceConnectionStateChange
      .onEach { iceConnectionState ->
        logger.logDebug("On Ice Connection Change: $iceConnectionState")
      }
      .launchIn(coroutineScope)

    this.onIceGatheringState
      .onEach { iceGatheringState ->
        logger.logDebug("On Ice Gathering Change: $iceGatheringState")
      }
      .launchIn(coroutineScope)

    this.onRemovedIceCandidates
      .onEach { removedIceCandidates ->
        logger.logDebug("On Removed Ice Candidates: $removedIceCandidates")
      }
      .launchIn(coroutineScope)
  }

  /** Subscribes, in [coroutineScope], to the message and open events of this data channel. */
  private fun DataChannel.handle(
    onOpenCallback: (() -> Unit)?,
    onMessageCallback: suspend (ByteArray) -> Unit,
  ) {
    this.onMessage.onEach { onMessageCallback.invoke(it) }.launchIn(coroutineScope)

    this.onOpen.onEach { onOpenCallback?.invoke() }.launchIn(coroutineScope)
  }
}

/** [CallTest] that forwards only ICE candidates accepted by [isHost] between the two peers. */
class HostConnectivityTest(
  config: RtcConfiguration,
  provider: DependencyProvider,
  done: suspend (Any) -> Unit,
  failed: suspend (String) -> Unit,
  coroutineScope: CoroutineScope,
) :
  CallTest(
    config = config,
    provider = provider,
    filter = isHost,
    done = done,
    failed = failed,
    testType = "HOST_CONNECTIVITY_TEST",
    coroutineScope = coroutineScope,
  )

/** [CallTest] that forwards only ICE candidates accepted by [isRelay] between the two peers. */
class RelayConnectivityTest(
  config: RtcConfiguration,
  provider: DependencyProvider,
  done: suspend (Any) -> Unit,
  failed: suspend (String) -> Unit,
  coroutineScope: CoroutineScope,
) :
  CallTest(
    config = config,
    provider = provider,
    filter = isRelay,
    done = done,
    failed = failed,
    testType = "RELAY_CONNECTIVITY_TEST",
    coroutineScope = coroutineScope,
  )

/** [CallTest] that forwards only ICE candidates accepted by [isReflexive] between the peers. */
class ReflexiveConnectivityTest(
  config: RtcConfiguration,
  provider: DependencyProvider,
  done: suspend (Any) -> Unit,
  failed: suspend (String) -> Unit,
  coroutineScope: CoroutineScope,
) :
  CallTest(
    config = config,
    provider = provider,
    filter = isReflexive,
    done = done,
    failed = failed,
    testType = "REFLEXIVE_CONNECTIVITY_TEST",
    coroutineScope = coroutineScope,
  )
