package io.dyte.media.hive

import io.dyte.webrtc.*
import io.dyte.media.hive.*
import io.dyte.media.hive.handlers.*
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.serialization.json.Json

/** Reason handed to `close()` of every producer and consumer when their owning transport is closed. */
const val REASON_TRANSPORT_CLOSED: String = "transport closed"

/**
 * Caller-facing options for creating a transport. Every field is optional and defaults to `null`.
 *
 * @property iceServers ICE servers for the underlying peer connection.
 * @property iceTransportPolicy ICE transport policy for the underlying peer connection.
 * @property additionalSettings extra handler settings, forwarded unchanged.
 * @property proprietaryConstraints implementation-specific constraints, forwarded unchanged.
 * @property appData custom application data attached to the transport.
 */
open class HiveTransportOptions(
  open val iceServers: List<IceServer>? = null,
  open val iceTransportPolicy: IceTransportPolicy? = null,
  open val additionalSettings: MutableMap<String, Any>? = null,
  open val proprietaryConstraints: Any? = null,
  open val appData: Map<String, Any>? = null,
)

/**
 * Fully resolved transport options used internally to build a `HiveTransport`.
 *
 * Unlike [HiveTransportOptions], [iceServers] and [iceTransportPolicy] are mandatory here.
 *
 * @property id local transport id, if any.
 * @property direction whether the transport sends or receives media.
 * @property handlerFactory factory for the RTC handler, if supplied.
 */
class HiveInternalTransportOptions(
  val id: String? = null,
  val direction: RtpTransceiverDirection,
  val handlerFactory: HiveHandlerFactory? = null,

  override val iceServers: List<IceServer>,
  override val iceTransportPolicy: IceTransportPolicy,
  override val additionalSettings: MutableMap<String, Any>?,
  override val proprietaryConstraints: Any?,
  override val appData: Map<String, Any>?,
) : HiveTransportOptions(
  iceServers,
  iceTransportPolicy,
  additionalSettings,
  proprietaryConstraints,
  appData,
)

/** Connection state of a transport, as delivered by the RTC handler's `@connectionstatechange` events. */
enum class HiveConnectionState { New, Connecting, Connected, Failed, Disconnected, Closed }

/**
 * Description of a consumer to be created, as carried by the `returnConsume` /
 * `returnConsumePeer` responses that `HiveTransport` reads.
 *
 * @property consumerId id of the consumer to create.
 * @property trackId id of the remote track.
 * @property streamId id of the remote stream; combined with [kind] to match incoming tracks.
 * @property screenShare whether the media is a screen share.
 * @property paused whether the consumer starts paused.
 * @property kind media kind of the track.
 */
open class HiveConsumerStateObject(
  open val consumerId: String,
  open val trackId: String,
  open val streamId: String,
  open val screenShare: Boolean,
  open val paused: Boolean,
  open val kind: MediaStreamTrackKind,
)

/**
 * Everything a consumer-creation task needs. It is a [HiveConsumerStateObject] plus the
 * producer it consumes and the app data to attach to the resulting consumer.
 *
 * @property producerId id of the producer being consumed.
 * @property producingPeerId id of the peer that owns the producer.
 * @property appData app data for the consumer; mutable because the task writes `"screenShare"` into it.
 */
class HiveConsumerCreationTaskOptions(
  override val consumerId: String,
  override val trackId: String,
  override val streamId: String,
  override val screenShare: Boolean,
  override val paused: Boolean,
  override val kind: MediaStreamTrackKind,

  val producerId: String,
  val producingPeerId: String,
  val appData: MutableMap<String, Any>,
) : HiveConsumerStateObject(
  consumerId,
  trackId,
  streamId,
  screenShare,
  paused,
  kind,
)

/**
 * Failure of a consumer-creation task. The original [options] are kept, so the task can be retried
 * (see `HiveTransport.retryFailedConsumerCreationTasks`).
 *
 * @property options the options the failed task was started with.
 * @property isTimedOut set to `true` when the task failed because its track never arrived in time.
 * @property name human-readable exception name.
 */
class HiveConsumerCreationTaskException(
  val options: HiveConsumerCreationTaskOptions,
  var isTimedOut: Boolean = false,
  val name: String = "Consumer Creation Task Exception",
  override val message: String = "Consumer Creation Failed",
) : Exception(message)

/**
 * One WebRTC transport, backed by a [HiveUnifiedPlan] handler. It either sends media
 * ([RtpTransceiverDirection.SendOnly]) or receives it ([RtpTransceiverDirection.RecvOnly]).
 *
 * The owner of the transport does the signalling. Requests are emitted on [observer]
 * (`"connect"`, `"produce"`, `"consume"`, `"consumePeer"`), and the matching answers are expected on
 * [externalObserver] (`"returnConnect"`, `"returnProduce"`, `"returnConsume"`, `"returnConsumePeer"`).
 * Lifecycle notifications (`"connected"`, `"disconnected"`, `"close"`, `"connectionstatechange"`,
 * `"icecandidate"`, `"datachannel"`, `"newproducer"`, `"newconsumer"`) are also emitted on [observer].
 */
class HiveTransport (options: HiveInternalTransportOptions) {
  /** Id */
  private val _id = options.id

  /** Server-side id; only available after [setServerId] has been called. */
  private lateinit var _serverId: String

  /** Closed flag */
  private var _closed = false

  /** Direction */
  private val _direction = options.direction

  // SCTP max message size if enabled, null otherwise. [Unused]
//  private val _maxSctpMessageSize: Long

//  private val _handler = options.handlerFactory // check
  /**
   * RTC handler instance.
   *
   * FIXME(review): nothing in this class assigns this property, so `_handler.init(...)` in `init`
   * throws UninitializedPropertyAccessException. It presumably has to be created from
   * `options.handlerFactory`; wire that up once the factory API is settled.
   */
  private lateinit var _handler: HiveUnifiedPlan

  /** Transport connection state */
  private var _connectionState = HiveConnectionState.New

  /** Producers, keyed by producer id */
  private val _producers = mutableMapOf<String, HiveProducer>()

  /** Consumers, keyed by consumer id */
  private val _consumers = mutableMapOf<String, HiveConsumer>()

  private var _connected = false

  /** Completed with `true` on the first "connected" event, and with `false` on "disconnect"/"close" or [close]. */
  private val _transportConnection = CompletableDeferred<Boolean>()

  /**
   * Single-threaded dispatcher. It is no longer used internally, because [produceMutex] now
   * serialises negotiation. It is kept because it is part of the public surface.
   */
  @OptIn(ExperimentalCoroutinesApi::class)
  val dispatcher = Dispatchers.Default.limitedParallelism(1)

  val observer = MutableSharedFlow<HiveEmitData>()

  val externalObserver = MutableSharedFlow<HiveEmitData>()

  /** Pending consumer-creation handlers, keyed by "streamId:kind", each waiting for its remote track. */
  private val consumerTrackEvents = mutableMapOf<String, consumerTrackEvent>()

  /** Remote tracks that arrived before any consumer-creation task asked for them (same keys). */
  private val unknownTracksMap = mutableMapOf<String, TrackEvent>()

  /** App custom data */
  private val _appData: Map<String, Any> = options.appData ?: emptyMap()

  /** Partially received data-channel messages, keyed by message id. */
  private val dataChannelCache = mutableMapOf<String, MutableList<DCMessageChunked>>()

  /**
   * Owns the long-lived background coroutines of this transport: the event collectors and the
   * producer/consumer close watchers. [close] cancels it, so none of them outlive the transport.
   * Uncaught failures are logged instead of being silently lost in an unused `Deferred`.
   */
  private val scope = CoroutineScope(
    SupervisorJob() + Dispatchers.Default + CoroutineExceptionHandler { _, e ->
      println("DyteMediaClient: Transport: Uncaught error in background task - $e")
    }
  )

  /**
   * Serialises whole offer/answer rounds in [produce]. A single-threaded dispatcher alone would
   * not prevent interleaving across suspension points.
   */
  private val produceMutex = Mutex()

  init {
    runBlocking {
      _handler.init(HiveHandlerRunOptions(
        direction = options.direction,
        iceServers = options.iceServers,
        iceTransportPolicy = options.iceTransportPolicy,
        additionalSettings = options.additionalSettings,
        proprietaryConstraints = options.proprietaryConstraints,
        onTrackHandler = ::_onTrack
      ))
    }

    // Resolve the connection deferred from our own lifecycle events.
    scope.launch {
      observer.collect {
        when (it.eventName) {
          "connected" -> _transportConnection.complete(true)
          // NOTE(review): the state handler emits "disconnected", not "disconnect", so only an
          // external emitter can reach this branch. Confirm which name is intended.
          "disconnect" -> _transportConnection.complete(false)
          "close" -> _transportConnection.complete(false)
        }
      }
    }

    // Translate raw handler events into transport events.
    scope.launch {
      _handler.observer.collect { event ->
        when (event.eventName) {
          "@connectionstatechange" -> _onConnectionStateChange(event.data as HiveConnectionState)

          "@icecandidate" -> {
            if (!_closed) {
              observer.emit(HiveEmitData(eventName = "icecandidate", data = event.data as IceCandidate))
            }
          }

          "datachannel" -> _onDataChannelMessage(event.data)
        }
      }
    }
  }

  /** Applies a handler connection-state change and re-emits it as transport events. */
  private suspend fun _onConnectionStateChange(connectionState: HiveConnectionState) {
    if (connectionState == _connectionState) return

    println("DyteMediaClient: Transport: Connection state changed to ${connectionState.name}")

    _connectionState = connectionState

    when (connectionState) {
      HiveConnectionState.Connected -> {
        _connected = true
        observer.emit(HiveEmitData("connected"))
      }
      HiveConnectionState.Disconnected -> {
        _connected = false
        observer.emit(HiveEmitData("disconnected"))
      }
      HiveConnectionState.Failed -> {
        _connected = false
        observer.emit(HiveEmitData("close"))
      }
      else -> {}
    }

    if (!_closed) {
      observer.emit(HiveEmitData(eventName = "connectionstatechange", data = connectionState))
    }
  }

  /**
   * Buffers one data-channel chunk. Once all `count` chunks of a message id have arrived, it
   * reassembles them in arrival order, parses the result as a [DCMessage] and emits "datachannel".
   * Malformed input is logged and dropped, so it cannot kill the handler-event collector.
   */
  private suspend fun _onDataChannelMessage(payload: Any?) {
    try {
      @Suppress("UNCHECKED_CAST")
      val data = payload as Map<String, Any>

      val chunkJson = (data["msg"] as ByteArray).decodeToString()
      val chunk = Json.decodeFromString(DCMessageChunked.serializer(), chunkJson)

      // Create the cache entry on the first chunk of a message and append every chunk.
      val chunks = dataChannelCache.getOrPut(chunk.id) { mutableListOf() }
      chunks.add(chunk)

      // Wait until every chunk of this message has arrived.
      if (chunks.size != chunk.count) return

      dataChannelCache.remove(chunk.id)
      val message = chunks.joinToString(separator = "") { it.chunk.toString() }

      // The reassembled message is itself a JSON object.
      val parsedMessage = Json.decodeFromString(DCMessage.serializer(), message)
      val channel = data["channel"] as DataChannel

      observer.emit(HiveEmitData(
        eventName = "datachannel",
        data = mapOf(
          "channel" to channel, // get label from channel.label
          "parsedMessage" to parsedMessage,
        )
      ))
    } catch (e: CancellationException) {
      throw e
    } catch (e: Exception) {
      println("DyteMediaClient: Transport: Error parsing message - $e")
    }
  }

  /** Transport Id */
  fun getId() = this._id

  fun getServerId() = this._serverId

  fun getConnected() = this._connected

  fun getIsConnected() = this._transportConnection

  /** Whether the Transport is closed. */
  fun getClosed() = this._closed

  /** Transport direction **/
  fun getDirection() = this._direction

  /** RTC Handler instance*/
  fun getHandler() = this._handler

  /** Connection state */
  fun getConnectionState() = this._connectionState

  /** Custom data */
  fun getAppData() = this._appData

//  fun setAppData() = throw Error("Cannot override appData object")

  fun setServerId(id: String) {
    this._serverId = id
  }

  /**
   * Closes the transport: the handler first, then every producer and consumer (with
   * [REASON_TRANSPORT_CLOSED]). It then emits "close" and stops all background collectors.
   * Calling it again is a no-op.
   */
  suspend fun close() {
    if (_closed) return

    println("DyteMediaClient: Transport: close()")

    _connected = false
    _closed = true

    // Close the handler
    _handler.close()

    // Snapshot and clear the maps before closing anything. The background watchers remove entries
    // from these same maps when a producer/consumer emits "close", so iterating them live could race.
    val producers = _producers.values.toList()
    _producers.clear()
    producers.forEach { it.close(REASON_TRANSPORT_CLOSED) }

    val consumers = _consumers.values.toList()
    _consumers.clear()
    consumers.forEach { it.close(REASON_TRANSPORT_CLOSED) }

    // Anyone still awaiting getIsConnected() must not hang once the transport is gone.
    _transportConnection.complete(false)

    observer.emit(HiveEmitData("close"))

    // Stop every collector/watcher owned by this transport.
    scope.cancel()
  }

  /** Get associated Transport (RTCPeerConnection) stats */
  suspend fun getStats(): RtcStatsReport {
    if (_closed) throw IllegalStateException("closed")

    return _handler.getTransportStats()
  }

  /**
   * Emits [request] on [observer] and suspends until an event named [responseEventName] arrives on
   * [externalObserver], then returns that event's `data`.
   *
   * The subscription to [externalObserver] is registered *before* the request is emitted. The flow
   * has no replay, so a response delivered straight away would otherwise be lost.
   */
  private suspend fun requestExternal(request: HiveEmitData, responseEventName: String): Any? =
    coroutineScope {
      val response = async(start = CoroutineStart.UNDISPATCHED) {
        externalObserver.first { it.eventName == responseEventName }
      }
      observer.emit(request)
      response.await().data
    }

  /**
   * Creates the local offer, sends it out as "connect", and applies the answer received as
   * "returnConnect". Failures of type [Error] are logged and swallowed.
   */
  suspend fun connect() {
    println("DyteMediaClient: Connecting transport: ${_id}")

    try {
      val connectResult = _handler.connect()

      val answer = requestExternal(HiveEmitData("connect", connectResult.offerSdp), "returnConnect")

      // call callback on answer
      connectResult.callback.invoke(answer as SessionDescription)
    } catch (e: Error) {
      println("DyteMediaClient: Transport: Failed to connect - $e")
    }
  }

  /** Restart ICE connection */
  suspend fun restartIce(): HiveGenericHandlerResult {
    println("DyteMediaClient: Transport: restartIce()")

    if (_closed) throw IllegalStateException("closed")

    return _handler.restartIce()
  }

  /** Update ICE servers */
  suspend fun updateIceServers(iceServers: List<IceServer>) {
    println("DyteMediaClient: Transport: updateIceServers()")

    if (_closed) throw IllegalStateException("closed")

    _handler.updateIceServers(iceServers)
  }

  /** Suspends for as long as [producer] is alive, and drops it from the map once it reports "close". */
  private suspend fun _handleProducer(producer: HiveProducer) {
    producer.observer.collect {
      if (it.eventName == "close") _producers.remove(producer.getId())
    }
  }

  /**
   * Creates a producer for `options.track` on this (sending) transport.
   *
   * Negotiation is serialised with other [produce] calls: create the offer, emit "produce", await
   * "returnProduce" (`answer` + `producerId`), then apply the answer.
   *
   * @throws IllegalStateException if the transport is closed or the track has ended.
   * @throws UnsupportedOperationException if this is not a sending transport.
   * @throws Error if the track is missing or the transport never connected.
   */
  suspend fun produce(options: HiveProducerOptions): HiveProducer {
    println("DyteMediaClient: Transport: produce() with track = ${options.track?.id}")

    if (_closed) throw IllegalStateException("closed")
    val track = options.track ?: throw Error("TypeError: Missing Track")
    if (_direction != RtpTransceiverDirection.SendOnly)
      throw UnsupportedOperationException("Not a sending transport")
    if (track.readyState is MediaStreamTrackState.Ended) throw IllegalStateException("Track ended")

    // appData is required for the producer; resolve it before any negotiation starts.
    val appData = options.appData!!

    if (!getIsConnected().await()) throw Error("Transport not connected")

    val (producerId, localId) = produceMutex.withLock {
      // First we generate the offer SDP
      val sendResult = _handler.send(HiveHandlerSendOptions(
        track = track,
        encodings = listOf(options.encodings!!),
        codecOptions = options.codecOptions,
        // A missing "screenShare" flag means a regular (camera/mic) track.
        screenShare = appData["screenShare"] as? Boolean ?: false
      ))

      // Then we send this offer to the server and wait for its answer
      @Suppress("UNCHECKED_CAST")
      val response = requestExternal(
        HiveEmitData(
          eventName = "produce",
          data = mapOf(
            "offer" to sendResult.offerSdp,
            "kind" to track.kind,
            "paused" to if (options.disableTrackOnPause!!) !track.enabled else false,
            "appData" to appData
          )
        ),
        "returnProduce"
      ) as Map<String, Any>

      val answer = response["answer"] as SessionDescription
      val serverProducerId = response["producerId"] as String

      // Finally set the answer on remote; the callback yields the local id
      serverProducerId to (sendResult.callback(answer) as String)
    }

    val producer = HiveProducer(HiveInternalProducerOptions(
      id = producerId,
      localId = localId,
      track = track,
      stopTracks = options.stopTracks!!,
      disableTrackOnPause = options.disableTrackOnPause!!,
      zeroRtpOnPause = options.zeroRtpOnPause!!,
      appData = appData,
      handler = _handler
    ))

    _producers[producerId] = producer

    // Watch for the producer closing in the background; collecting inline would never return.
    scope.launch { _handleProducer(producer) }

    observer.emit(HiveEmitData("newproducer", producer))

    return producer
  }

  // Note(anunaym14): Use deferredResult.getCompleted() for successful consumer creations
  // And, deferredResult.getCompletionExceptionOrNull() for unsuccessful ones to retry
  // the consumer creation task if required
  /**
   * Asks for every producer of [producingPeerId] ("consumePeer" / "returnConsumePeer") and starts a
   * consumer-creation task for each one.
   *
   * @return one deferred per producer; each fails with [HiveConsumerCreationTaskException].
   */
  suspend fun consumePeer(producingPeerId: String, appData: Map<String, Any>): List<CompletableDeferred<HiveConsumer>> {
    println("DyteMediaClient: Transport: consumePeer() with producingPeerId = $producingPeerId")

    if (_closed) throw IllegalStateException("closed")
    else if (_direction != RtpTransceiverDirection.RecvOnly)
      throw UnsupportedOperationException("Not a receiving transport")

    if (!getIsConnected().await()) throw Error("Transport not connected")

    @Suppress("UNCHECKED_CAST")
    val consumersMap = requestExternal(
      HiveEmitData(eventName = "consumePeer", data = producingPeerId),
      "returnConsumePeer"
    ) as Map<String, HiveConsumerStateObject>

    // Keys are producer ids; each value describes the consumer to create for that producer.
    val deferredResults = consumersMap.map { (producerId, state) ->
      _consumerCreationTask(HiveConsumerCreationTaskOptions(
        consumerId = state.consumerId,
        trackId = state.trackId,
        streamId = state.streamId,
        kind = state.kind,
        producerId = producerId,
        producingPeerId = producingPeerId,
        paused = state.paused,
        screenShare = state.screenShare,
        appData = appData.toMutableMap()
      ))
    }

    println("DyteMediaClient: Transport: consumePeer(): DeferredResults = $deferredResults")

    return deferredResults
  }

  /**
   * Asks for a single consumer of `options.producerId` ("consume" / "returnConsume") and starts its
   * creation task. Failures are logged and rethrown.
   */
  suspend fun consume(options: HiveConsumerOptions): CompletableDeferred<HiveConsumer> {
    if (_closed) throw IllegalStateException("closed")
    else if (_direction != RtpTransceiverDirection.RecvOnly)
      throw UnsupportedOperationException("Not a receiving transport")

    if (!getIsConnected().await()) throw Error("Transport not connected")

    try {
      val consumerObject = requestExternal(
        HiveEmitData(
          eventName = "consume",
          data = mapOf(
            "producerId" to options.producerId,
            "producingPeerId" to options.producingPeerId
          )
        ),
        "returnConsume"
      ) as HiveConsumerStateObject

      return _consumerCreationTask(HiveConsumerCreationTaskOptions(
        consumerId = consumerObject.consumerId,
        screenShare = consumerObject.screenShare,
        trackId = consumerObject.trackId,
        streamId = consumerObject.streamId,
        kind = consumerObject.kind,
        paused = consumerObject.paused,
        producerId = options.producerId,
        producingPeerId = options.producingPeerId,
        appData = options.appData!!.toMutableMap()
      ))
    } catch (e: Exception) {
      println("DyteMediaClient: Transport: Consume failed with error = $e")

      throw e
    }
  }

  /** Suspends for as long as [consumer] is alive; once it reports "close", drops it and its transceiver mapping. */
  private suspend fun _handleConsumer(consumer: HiveConsumer) {
    consumer.observer.collect {
      if (it.eventName == "close") {
        _consumers.remove(consumer.getId())
        _handler.mapMidTransceiver.remove(consumer.getLocalId()) // transceiver.mid
      }
    }
  }

  /**
   * Waits up to 5 seconds for the remote track keyed by "streamId:kind" and then builds the
   * consumer. A track that already arrived (parked in [unknownTracksMap]) is used immediately.
   *
   * @return a deferred that completes with the consumer, or fails with [HiveConsumerCreationTaskException]
   * (with `isTimedOut = true` on timeout).
   */
  private suspend fun _consumerCreationTask(options: HiveConsumerCreationTaskOptions):
          CompletableDeferred<HiveConsumer> {
    val key = "${options.streamId}:${options.kind}"
    val exception = HiveConsumerCreationTaskException(options)

    val deferredConsumer = CompletableDeferred<HiveConsumer>()

    // Deliberately not in [scope]: the timeout must still fail the deferred if the transport
    // is closed in the meantime, so awaiters never hang.
    val timeoutTimer = CoroutineScope(Dispatchers.Default).launch {
      delay(5000)
      consumerTrackEvents.remove(key)
      exception.isTimedOut = true
      deferredConsumer.completeExceptionally(exception)
    }

    val consumeHandler: suspend (TrackEvent) -> Unit = { event ->
      try {
        if (event.track?.readyState is MediaStreamTrackState.Ended) {
          timeoutTimer.cancel()
          deferredConsumer.completeExceptionally(exception)
        } else {
          val transceiver = event.transceiver

          _handler.mapMidTransceiver[transceiver.mid] = transceiver

          // TODO: Check screenShare stuff
          options.appData["screenShare"] = options.screenShare

          val consumer = HiveConsumer(HiveInternalConsumerOptions(
            id = options.consumerId,
            localId = event.transceiver.mid,
            track = event.track,
            kind = event.track?.kind!!,
            paused = options.paused,
            producerId = options.producerId,
            producingPeerId = options.producingPeerId,
            handler = _handler,
            appData = options.appData,
            //screenShare = options.screenShare
          ))

          _consumers[options.consumerId] = consumer

          // Watch for the consumer closing in the background; collecting inline would never return.
          scope.launch { _handleConsumer(consumer) }

          println("Consumer created for producerId = ${options.producerId} trackId = ${options.trackId} " +
                  "producingPeerId = ${options.producingPeerId}")

          observer.emit(HiveEmitData(
            eventName = "newconsumer",
            data = consumer
          ))

          timeoutTimer.cancel()
          deferredConsumer.complete(consumer)
        }
      } catch (e: CancellationException) {
        timeoutTimer.cancel()
        deferredConsumer.completeExceptionally(exception)
        throw e
      } catch (e: Exception) {
        println("Error while creating consumer: $e")
        timeoutTimer.cancel()
        deferredConsumer.completeExceptionally(exception)
      }
    }

    val existingTrack = unknownTracksMap.remove(key)

    if (existingTrack != null) {
      consumeHandler(existingTrack)
    } else {
      consumerTrackEvents[key] = consumeHandler
    }

    return deferredConsumer
  }

  /**
   * Handler `onTrack` callback. It routes the track to the waiting consumer-creation task with the
   * same "streamId:kind" key, or parks it in [unknownTracksMap] until such a task starts.
   */
  private suspend fun _onTrack(event: TrackEvent) {
    println("DyteMediaClient: Transport: Track event received: $event")

    val stream = event.streams.firstOrNull()
    if (stream == null) {
      println("DyteMediaClient: Transport: Ignoring track event without a stream: $event")
      return
    }

    val key = "${stream.id}:${event.track?.kind}"

    // Remove before invoking, so the handler is consumed even if it throws.
    val eventHandler = consumerTrackEvents.remove(key)

    if (eventHandler != null) {
      eventHandler(event)
    } else {
      println("DyteMediaClient: Transport: Track event handler not found for key = $key")

      unknownTracksMap[key] = event
    }
  }

  private suspend fun setRemoteDescription(sdp: SessionDescription) =
    _handler.getPc().setRemoteDescription(sdp)

  private suspend fun setLocalDescription(sdp: SessionDescription) {
    // _serverId is lateinit; logging it before setServerId() would throw.
    val serverId = if (::_serverId.isInitialized) _serverId else null

    println("DyteMediaClient: ${_direction}() {transportId: $serverId} | " +
            "calling pc.setLocalDescription() with offer = ${sdp.sdp}")

    _handler.getPc().setLocalDescription(sdp)
  }

  /** Applies a remote offer, then creates and applies the local answer, and returns that answer. */
  suspend fun setRemoteOffer(offer: SessionDescription): SessionDescription {
    setRemoteDescription(offer)

    val ans = _handler.getPc().createAnswer(OfferAnswerOptions())

    setLocalDescription(ans)

    return ans
  }

  /** Restarts the consumer-creation task of every failed task, reusing its original options. */
  suspend fun retryFailedConsumerCreationTasks(tasks: List<HiveConsumerCreationTaskException>):
          List<CompletableDeferred<HiveConsumer>> =
    tasks.map { _consumerCreationTask(it.options) }
}

/** Pending consumer-creation callback, invoked with the remote track event it was waiting for. */
typealias consumerTrackEvent = suspend (event: TrackEvent) -> Unit