package io.dyte.sockrates.client

import io.dyte.sockrates.client.backoff.Backoff
import io.dyte.sockrates.client.backoff.ExponentialBackoff
import io.ktor.client.*
import io.ktor.client.plugins.ResponseException as KtorResponseException
import io.ktor.client.plugins.websocket.*
import io.ktor.http.*
import io.ktor.websocket.*
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.awaitCancellation
import kotlinx.coroutines.cancelChildren
import kotlinx.coroutines.delay
import kotlinx.coroutines.isActive
import kotlinx.coroutines.job
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeout
import message.v1.SocketMessage
import okio.ByteString.Companion.toByteString
import kotlin.DeprecationLevel.HIDDEN

/*
All public methods of Sockrates call Ktor's methods directly instead of switching the
coroutine context (Dispatcher), because Ktor's methods are main-safe and take care of
switching threads internally. The exception is send(event, messageId, payload), which
offloads message serialization to a background dispatcher.
*/
/**
 * WebSocket client built on Ktor's [HttpClient] with the [WebSockets] plugin.
 *
 * What this class does:
 * - opens/closes a single WebSocket session to [url] ([connect] / [disconnect]);
 * - decodes binary frames into [SocketMessage]s and dispatches them to the
 *   [SockratesWSListener] and to per-event [SockratesEventListener]s ([subscribe]);
 * - answers the server's custom text ping with a custom pong and, when
 *   `config.disconnectOnPingTimeout` is set, treats a missing ping as a dead connection
 *   and starts reconnecting;
 * - after a read failure, reconnects with [backoff] when `config.autoReconnect` is set.
 *
 * [SockratesWSListener] callbacks and [SockratesEventListener.onEvent] are invoked via
 * `Dispatchers.Main`.
 *
 * @param url WebSocket URL to connect to.
 * @param config timeout and reconnection settings.
 * @param backoff delay policy between reconnection attempts.
 */
class Sockrates(
  private val url: String,
  private val config: SockratesConfiguration = SockratesConfiguration(),
  private val backoff: Backoff = ExponentialBackoff(addJitter = true)
) {
  private val wsClient = HttpClient {
    install(WebSockets)
  }

  private var wsListener: SockratesWSListener? = null

  private var wsSession: DefaultClientWebSocketSession? = null

  /*
  * note: According to the Websockets spec, the initial value must be CONNECTING but we have explicitly made it CLOSED.
  * https://websockets.spec.whatwg.org/#websocket-ready-state
  * */
  private var readyState: WebSocketReadyState = WebSocketReadyState.CLOSED // Need to expose this as read-only property

  private var connectionState: WebSocketConnectionState = WebSocketConnectionState.Disconnected(
    CloseReason.Codes.NORMAL.code.toInt(),
    "Client has not started connection yet"
  )

  /** True while [tryReconnection] runs its retry loop; [connect] is rejected meanwhile. */
  private var reconnecting: Boolean = false

  /** Whether the WebSocket session is currently open. */
  val connected: Boolean
    get() = readyState == WebSocketReadyState.OPEN

  /** Event id -> listeners registered through [subscribe], in insertion order. */
  private val activeEventListeners: MutableMap<Int, MutableSet<SockratesEventListener>> = HashMap()

  // Background work runs on Dispatchers.Default. SupervisorJob keeps one failing child from
  // cancelling its siblings. disconnect() cancels every child of this scope.
  private val scope: CoroutineScope = CoroutineScope(Dispatchers.Default + SupervisorJob())

  private var readMessagesJob: Job? = null
  private var pingTimeoutJob: Job? = null
  private var reconnectionJob: Job? = null

  /**
   * Opens the WebSocket session and starts reading incoming frames in the background.
   *
   * On success, [SockratesWSListener.onOpen] is called and the connection state becomes
   * [WebSocketConnectionState.Connected]. When `config.disconnectOnPingTimeout` is set, a ping
   * timeout is armed.
   *
   * @param wsListener receives lifecycle callbacks, messages and errors for this client.
   * @return [SockratesResult.Success] once connected, or a [ConnectFailureReason] when the socket
   *   is already open, connecting/reconnecting or closing, the URL is invalid, or connecting fails.
   * @throws CancellationException if the calling coroutine is cancelled.
   */
  suspend fun connect(wsListener: SockratesWSListener): SockratesResult<Unit, ConnectFailureReason> {
    if (readyState == WebSocketReadyState.OPEN) {
      return SockratesResult.Failure(ConnectFailureReason.SocketAlreadyConnected(url))
    }

    if (readyState == WebSocketReadyState.CONNECTING || reconnecting) {
      return SockratesResult.Failure(ConnectFailureReason.SocketAlreadyConnecting(url))
    }

    if (readyState == WebSocketReadyState.CLOSING) {
      return SockratesResult.Failure(ConnectFailureReason.SocketClosing(url))
    }

    this.wsListener = wsListener

    try {
      println("Sockrates: connecting..")
      wsSession = connectInternal()

      clearPingTimeout()
      if (shouldSetPingTimeout(config)) {
        setPingTimeout(config.pingTimeoutInMillis)
      }

      wsSession?.let { session ->
        readMessagesJob = readIncomingMessages(session, wsListener)
      }

      println("Sockrates: connected to ${wsSession?.call}")

      runOnMain { this.wsListener?.onOpen(this) }
      setAndEmitConnectionState(WebSocketConnectionState.Connected)
      return SockratesResult.Success(Unit)
    } catch (e: CancellationException) {
      println("Sockrates: error while connecting: ${e.message}")
      throw e // this is a coroutine upstream exception
    } catch (e: KtorResponseException) {
      println("Sockrates: error while connecting: ${e.response.status.value}; ${e.message}")
      return SockratesResult.Failure(ConnectFailureReason.ConnectionError(url, e))
    } catch (e: URLParserException) {
      println("Sockrates: error while connecting: Invalid WS URL ${e.message}")
      return SockratesResult.Failure(ConnectFailureReason.InvalidWebSocketUrl(url))
    } catch (e: Exception) {
      println("Sockrates: error while connecting: " + e.message)
      return SockratesResult.Failure(ConnectFailureReason.ConnectionError(url, e))
    }
  }

  /**
   * Opens a Ktor WebSocket session to [url] and tracks [readyState] around it:
   * CONNECTING while opening, then OPEN on success, or CLOSED when opening throws.
   * The exception is rethrown.
   */
  private suspend fun connectInternal(): DefaultClientWebSocketSession {
    readyState = WebSocketReadyState.CONNECTING
    try {
      val session = wsClient.webSocketSession(url)
      readyState = WebSocketReadyState.OPEN
      return session
    } catch (e: Exception) {
      readyState = WebSocketReadyState.CLOSED
      throw e
    }
  }

  /**
   * Launches the read loop for [wsSession] in [scope].
   *
   * Binary, text and close frames go to their handlers; a close frame ends the loop.
   * A decoding failure (`okio.IOException`) is reported through `onError` and the loop keeps
   * going. Any other exception ends the loop and is handled by [onExceptionWhileReadingSocket].
   */
  private suspend fun readIncomingMessages(
    wsSession: DefaultClientWebSocketSession,
    wsListener: SockratesWSListener?
  ) = scope.launch {
    try {
      while (isActive) {
        val frame = wsSession.incoming.receive()
        println("Sockrates: Received $frame")
        try {
          when (frame) {
            is Frame.Binary -> {
              handleBinaryFrame(frame, wsListener)
            }

            is Frame.Text -> {
              handleTextFrame(frame, wsListener)
            }

            is Frame.Close -> {
              handleCloseFrame(frame, wsListener)
              break
            }

            else -> {
              println("Sockrates: other frame received -> $frame")
            }
          }
        } catch (e: okio.IOException) {
          // Exception thrown while decoding a message. Notify the library user.
          println("Sockrates: error while decoding message -> ${e.message}")
          // Dispatch on Main like every other listener callback in this class.
          runOnMain { wsListener?.onError(this@Sockrates, e) }
        }
      }
    } catch (e: Exception) {
      println("Sockrates: error while reading socket -> ${e.message}")
      onExceptionWhileReadingSocket(e, wsListener)
    }
  }

  /**
   * Sends [message] as a text frame on the current session.
   * Does nothing when no session has been created. Exceptions from the session propagate.
   */
  suspend fun send(message: String) {
    wsSession?.send(Frame.Text(message))
  }

  /**
   * Sends [message] as a single (fin) binary frame on the current session.
   * Does nothing when no session has been created. Exceptions from the session propagate.
   */
  suspend fun send(message: ByteArray) {
    wsSession?.send(Frame.Binary(fin = true, data = message))
  }

  /**
   * Encodes a [SocketMessage] from [event], [messageId] and [payload] and sends it as a
   * binary frame.
   *
   * @return [SockratesResult.Success] once the frame was handed to the session, or a
   *   [SendFailureReason] when not connected or when encoding/sending throws.
   */
  suspend fun send(
    event: Int,
    messageId: String? = null,
    payload: ByteArray? = null
  ): SockratesResult<Unit, SendFailureReason> {
    if (!connected) {
      return SockratesResult.Failure(SendFailureReason.SocketNotConnected(event, messageId))
    }

    // Offload serialization to a worker thread. Switch ONLY the dispatcher: passing
    // scope.coroutineContext would also swap in the scope's SupervisorJob as the Job, which
    // detaches this block from the caller's cancellation (breaks structured concurrency).
    return withContext(Dispatchers.Default) {
      try {
        val socketMessage = SocketMessage(
          event = event,
          id = messageId,
          payload = payload?.toByteString(),
        )

        val binaryMessage = encode(socketMessage)
        send(binaryMessage)
        return@withContext SockratesResult.Success(Unit)
      } catch (e: Exception) {
        println("Sockrates: failed to send {event: $event, messageId: $messageId} -> ${e.message}")
        return@withContext SockratesResult.Failure(SendFailureReason.Other(event, messageId, e))
      }
    }
  }

  /**
   * Sends a message for [event] and suspends until an incoming message with the same event
   * and the same id arrives, or until `config.responseTimeoutInMillis` elapses.
   *
   * @param messageId id used to match the response. When null, a placeholder id is used
   *   (currently the fixed string "abcd"; see TODO).
   * @return the response message, or a [RequestResponseFailureReason] in these cases:
   *   not connected; the request could not be sent; the timeout elapsed; the wait was
   *   aborted (e.g. by [disconnect]).
   * @throws CancellationException if the calling coroutine itself is cancelled.
   */
  suspend fun requestResponse(
    event: Int,
    messageId: String? = null,
    payload: ByteArray? = null
  ): SockratesResult<SocketMessage, RequestResponseFailureReason> {
    if (!connected) {
      return SockratesResult.Failure(
        RequestResponseFailureReason.SocketNotConnected(event, messageId)
      )
    }

    // Parented to the client scope so that disconnect() (which cancels the scope's children)
    // also aborts requests that are still waiting for their response.
    val completableDeferred = CompletableDeferred<SocketMessage>(scope.coroutineContext.job)

    val requestId = messageId ?: "abcd" // TODO: replace abcd with a generate method
    val responseListener = object : SockratesEventListener {
      override fun onEvent(client: Sockrates?, event: Int, message: SocketMessage) {
        if (requestId == message.id) {
          unsubscribe(event, this)
          completableDeferred.complete(message)
        }
      }
    }

    subscribe(event, responseListener)

    try {
      val sendResult = send(event, requestId, payload)
      if (sendResult is SockratesResult.Failure) {
        // A request that never went out cannot get a response. Fail now instead of
        // waiting out the whole response timeout.
        println("Sockrates: requestResponse could not send request for event $event")
        return SockratesResult.Failure(
          RequestResponseFailureReason.Other(
            event,
            messageId,
            IllegalStateException("Failed to send request for event $event")
          )
        )
      }

      return withTimeout(config.responseTimeoutInMillis) {
        val response = completableDeferred.await()
        SockratesResult.Success(response)
      }
    } catch (e: TimeoutCancellationException) {
      println("Sockrates: requestResponse timeout after ${config.responseTimeoutInMillis}ms!!")
      return SockratesResult.Failure(
        RequestResponseFailureReason.ResponseTimeout(
          event,
          messageId,
          config.responseTimeoutInMillis
        )
      )
    } catch (e: CancellationException) {
      // The deferred is a child of the client scope, not of the caller. If it is not
      // cancelled, this cancellation belongs to the caller and must propagate.
      if (!completableDeferred.isCancelled) throw e
      println("Sockrates: exception in requestResponse -> ${e.message}")
      return SockratesResult.Failure(
        RequestResponseFailureReason.Other(event, messageId, e)
      )
    } catch (e: Exception) {
      println("Sockrates: exception in requestResponse -> ${e.message}")
      return SockratesResult.Failure(
        RequestResponseFailureReason.Other(event, messageId, e)
      )
    } finally {
      unsubscribe(event, responseListener)
      // A deferred that was never completed (timeout, send failure, caller cancellation) would
      // otherwise stay attached to the scope's job forever. No-op if already completed.
      completableDeferred.cancel()
    }
  }

  /**
   * Registers [eventListener] for incoming messages whose event id equals [event].
   *
   * @return true if the listener was added, false if it was already registered for [event].
   */
  fun subscribe(event: Int, eventListener: SockratesEventListener): Boolean =
    activeEventListeners.getOrPut(event) { LinkedHashSet() }.add(eventListener)

  /**
   * Removes [eventListener] from the listeners of [event].
   *
   * @return true if the listener was registered and has been removed.
   */
  fun unsubscribe(event: Int, eventListener: SockratesEventListener): Boolean {
    val existingEventListeners = activeEventListeners[event]
    return existingEventListeners?.remove(eventListener) ?: false
  }

  @Deprecated(
    message = "Use subscribe() to listen to a particular event or override the SockratesWSListener.onMessage() to listen to all incoming messages",
    level = HIDDEN
  )
  suspend fun receive(onReceive: (input: SocketMessage) -> Unit) {
    // NOTE(review): when wsSession is null this loop never suspends and spins forever.
    // It is kept unchanged because the API is hidden and only reachable from old binaries.
    while (true) {
      val frame = wsSession?.incoming?.receive()

      if (frame is Frame.Binary) {
        val socketMessage = Proto.decode(frame.data)
        onReceive(socketMessage)
      }
    }
  }

  /**
   * Closes the session with a NORMAL close code.
   *
   * It also:
   * - cancels all background work in [scope]: read loop, ping timeout, reconnection, and
   *   pending [requestResponse] waits;
   * - notifies `onClose` and emits a Disconnected state;
   * - removes every event subscription.
   *
   * Does nothing when already closing, or already closed and not reconnecting.
   */
  suspend fun disconnect() {
    if (readyState == WebSocketReadyState.CLOSED && !reconnecting) return

    if (readyState == WebSocketReadyState.CLOSING) return

    readyState = WebSocketReadyState.CLOSING
    println("Sockrates: Closing websocket session...")
    scope.coroutineContext.cancelChildren()
    // Cancellation is asynchronous, so a cancelled tryReconnection() may not have reset this
    // flag yet. Clear it now so a following connect() is not rejected as "already connecting".
    reconnecting = false
    val closeReason = CloseReason(CloseReason.Codes.NORMAL, "Connection closed by client")
    wsSession?.close(closeReason)

    /**
     * note: wsClient.close() closes the Ktor engine. In case of Sockrates it is Websockets plugin
     * which we install while initialising the wsClient.
     * After closing, we need to install the plugin again in order to establish a WS session.
     * That's why no need to close it here.
     * A good place to call this would be when we want to dispose the Sockrates client.
     */
    // wsClient.close()

    readyState = WebSocketReadyState.CLOSED
    println("Sockrates: Websocket session closed")
    runOnMain { wsListener?.onClose(this, closeReason.code.toInt(), closeReason.message) }
    setAndEmitConnectionState(
      WebSocketConnectionState.Disconnected(
        closeReason.code.toInt(),
        closeReason.message
      )
    )
    clear()
  }

  /** Removes every event subscription registered through [subscribe]. */
  fun clear() {
    activeEventListeners.clear()
  }

  /**
   * Decodes a binary frame into a [SocketMessage]. On Main, passes both the raw bytes and the
   * decoded message to [wsListener], then notifies event subscribers.
   */
  private suspend fun handleBinaryFrame(frame: Frame.Binary, wsListener: SockratesWSListener?) {
    val socketMessage = decode(frame.readBytes())
    runOnMain {
      wsListener?.onMessage(this@Sockrates, WebSocketMessage.Binary(frame.readBytes()))
      wsListener?.onMessage(this@Sockrates, socketMessage)
      notifyEventSubscribers(socketMessage)
    }
  }

  /** Answers the custom ping; forwards any other text frame to [wsListener] on Main. */
  private suspend fun handleTextFrame(frame: Frame.Text, wsListener: SockratesWSListener?) {
    if (frame.isCustomPing()) {
      println("Sockrates: Received custom ping")
      handleCustomPing()
    } else {
      runOnMain {
        wsListener?.onMessage(this@Sockrates, WebSocketMessage.Text(frame.readText()))
      }
    }
  }

  /*
  * Note: The doc for Close frame states the following:
  * "Usually there is no need to send/handle it unless you have a RAW web socket session."
  *
  * So we are just notifying the library users that the connection has been closed.
  *
  * We will need to test with a dev server if we need to explicitly cancel/close wsSession.
  * */
  private suspend fun handleCloseFrame(frame: Frame.Close, wsListener: SockratesWSListener?) {
    /*wsSession?.let {
      if (it.isActive) {
        it.cancel(CancellationException("Received Close frame from server. Closing WS session."))
      }
    }*/

    readyState = WebSocketReadyState.CLOSED
    // Use the server's close reason when the frame carries one.
    val closeReason =
      frame.readReason() ?: CloseReason(CloseReason.Codes.NORMAL, "Connection closed by server.")
    runOnMain { wsListener?.onClose(this, closeReason.code.toInt(), closeReason.message) }
    setAndEmitConnectionState(
      WebSocketConnectionState.Disconnected(
        closeReason.code.toInt(),
        closeReason.message
      )
    )
  }

  /** Restarts the ping timeout window and replies with the custom pong. */
  private suspend fun handleCustomPing() {
    clearPingTimeout()
    sendCustomPongMessage()
    if (shouldSetPingTimeout(config)) {
      setPingTimeout(config.pingTimeoutInMillis)
    }
  }

  private fun shouldSetPingTimeout(config: SockratesConfiguration): Boolean {
    return config.disconnectOnPingTimeout
  }

  /**
   * Arms a timer. If no ping clears it within [timeoutInMillis], the session is closed and
   * the reconnection flow starts after one backoff delay.
   */
  private fun setPingTimeout(timeoutInMillis: Long) {
    try {
      pingTimeoutJob = setTimeout(timeoutInMillis) {
        reconnectionJob = scope.launch {
          disconnectOnPingTimeout()
          backoff.reset() // Resetting the backoff attempts before starting reconnection flow
          delay(backoff.nextBackoffMillis()) // 1st attempt
          tryReconnection()
        }
      }
    } catch (e: Exception) {
      // Unknown internal exception in onTimeout() block.
      // TODO: Add recovery mechanism
      println("Sockrates: Ping timeout exception -> ${e.message}")
    }
  }

  /**
   * Stops the read loop and closes the session with the ping-timeout close code.
   * Emits a Disconnected state but does not call `onClose`. No-op when already closing/closed.
   */
  private suspend fun disconnectOnPingTimeout() {
    if (readyState == WebSocketReadyState.CLOSED || readyState == WebSocketReadyState.CLOSING) return

    readyState = WebSocketReadyState.CLOSING
    println("Sockrates: Closing websocket session on Ping timeout...")

    readMessagesJob?.cancel()
    wsSession?.close(CloseReason(CLOSE_CODE_PING_TIMEOUT, CLOSE_REASON_PING_TIMEOUT))

    readyState = WebSocketReadyState.CLOSED
    println("Sockrates: Websocket session closed on Ping timeout!")
    setAndEmitConnectionState(
      WebSocketConnectionState.Disconnected(
        CLOSE_CODE_PING_TIMEOUT.toInt(),
        CLOSE_REASON_PING_TIMEOUT
      )
    )
  }

  /**
   * Retries [reconnectInternal] with [backoff] delays until connected or until
   * `config.maxReconnectionAttempts` attempts have been made.
   *
   * On success, re-arms the ping timeout, restarts the read loop and emits Reconnected.
   * Otherwise emits ReconnectFailed. Does nothing if already connected or reconnecting.
   */
  private suspend fun tryReconnection() {
    if (connected || reconnecting) {
      return
    }

    reconnecting = true
    try {
      setAndEmitConnectionState(WebSocketConnectionState.Reconnecting)

      while (shouldTryReconnection()) {
        try {
          wsSession = reconnectInternal()
        } catch (e: CancellationException) {
          throw e
        } catch (e: Exception) {
          delay(backoff.nextBackoffMillis())
        }
      }
    } finally {
      // Must also run when this coroutine is cancelled (e.g. by disconnect()). Otherwise the
      // flag stays set and connect() keeps answering SocketAlreadyConnecting.
      reconnecting = false
    }

    if (connected) {
      backoff.reset()

      clearPingTimeout()
      if (shouldSetPingTimeout(config)) {
        setPingTimeout(config.pingTimeoutInMillis)
      }

      wsSession?.let { session ->
        readMessagesJob = readIncomingMessages(session, wsListener)
      }

      setAndEmitConnectionState(WebSocketConnectionState.Reconnected)
    } else {
      setAndEmitConnectionState(WebSocketConnectionState.ReconnectFailed)
    }
  }

  private fun shouldTryReconnection(): Boolean {
    return (backoff.attemptsTillNow < config.maxReconnectionAttempts) && !connected
  }

  /**
   * Makes one connection attempt.
   * Emits AttemptingReconnect before it, and ReconnectAttemptFailed (then rethrows) on failure.
   */
  private suspend fun reconnectInternal(): DefaultClientWebSocketSession {
    try {
      setAndEmitConnectionState(WebSocketConnectionState.AttemptingReconnect(backoff.attemptsTillNow))
      return connectInternal()
    } catch (e: CancellationException) {
      throw e
    } catch (e: Exception) {
      println("Sockrates: reconnection exception -> ${e.message}")
      setAndEmitConnectionState(WebSocketConnectionState.ReconnectAttemptFailed(backoff.attemptsTillNow))
      throw e
    }
  }

  private fun clearPingTimeout() {
    pingTimeoutJob?.let {
      if (it.isActive) {
        it.cancel()
        println("Sockrates: Ping timeout cleared!")
      }
    }
  }

  private suspend fun sendCustomPongMessage() {
    send(CUSTOM_PONG_MESSAGE)
    println("Sockrates: Sent custom pong")
  }

  /** Calls [SockratesEventListener.onEvent] on every listener subscribed to the message's event. */
  private fun notifyEventSubscribers(socketMessage: SocketMessage) {
    // Iterate over a snapshot. A listener may unsubscribe itself from inside onEvent
    // (requestResponse's listener does), and removing from the live set mid-iteration
    // throws ConcurrentModificationException when more than one listener is registered.
    val eventListeners = activeEventListeners[socketMessage.event]?.toList() ?: return
    for (listener in eventListeners) {
      listener.onEvent(this, socketMessage.event, socketMessage)
    }
  }

  private fun encode(socketMessage: SocketMessage) = Proto.encode(socketMessage)

  private fun decode(message: ByteArray) = Proto.decode(message)

  /** Runs [block] on `Dispatchers.Main`, suspending until it finishes. */
  private suspend fun runOnMain(block: () -> Unit) {
    withContext(Dispatchers.Main) {
      block()
    }
  }

  /**
   * Launches a job in [scope] that calls [onTimeout] after [timeoutInMillis], unless the job
   * is cancelled first. Cancelling the returned job disarms the timer.
   */
  @Throws(CancellationException::class)
  private fun setTimeout(timeoutInMillis: Long, onTimeout: () -> Unit): Job {
    return scope.launch {
      try {
        withTimeout(timeoutInMillis) {
          awaitCancellation()
        }
      } catch (e: TimeoutCancellationException) {
        println("Sockrates: Timeout in setTimeout()")
        onTimeout()
      } catch (e: CancellationException) {
        // The timeout job was cancelled
        println("Sockrates: Timeout cancelled!")
      } catch (e: Exception) {
        // Currently there is no need to process other exception that may happen in timeout logic
        println("Sockrates: setTimeout exception -> $e")
      }
    }
  }

  /** Stores [connectionState] and reports it to the listener on Main. */
  private suspend fun setAndEmitConnectionState(connectionState: WebSocketConnectionState) {
    this.connectionState = connectionState
    // Emit the value we were given. Re-reading the field on Main could report a state that a
    // concurrent update has already overwritten.
    runOnMain {
      wsListener?.onConnectionStateChanged(this, connectionState)
    }
  }

  /**
   * Handles an unexpected failure of the read loop.
   *
   * Closes the session with GOING_AWAY, reports [exception] via `onError`, emits Disconnected,
   * and reconnects when `config.autoReconnect` is set. Ignored when the client is already
   * closing or closed (e.g. the loop was cancelled on purpose).
   */
  private suspend fun onExceptionWhileReadingSocket(
    exception: Exception,
    wsListener: SockratesWSListener?
  ) {
    // if we are already closing or have closed the connection then don't do anything
    if (readyState == WebSocketReadyState.CLOSING || readyState == WebSocketReadyState.CLOSED) {
      return
    }

    println("Sockrates: Detected disconnection! Trying to reconnect..")

    /*
    * Clearing the ping timeout to avoid additional reconnection flow
    * */
    clearPingTimeout()

    // our socket has disconnected due to an exception, can be a network issue, try reconnection
    val closeReason = CloseReason(
      CloseReason.Codes.GOING_AWAY,
      "Detected disconnection! Trying to reconnect..."
    )

    readyState = WebSocketReadyState.CLOSING

    wsSession?.close(closeReason)

    readyState = WebSocketReadyState.CLOSED

    runOnMain { wsListener?.onError(this, exception) }
    setAndEmitConnectionState(
      WebSocketConnectionState.Disconnected(
        closeReason.code.toInt(),
        closeReason.message
      )
    )

    if (config.autoReconnect) {
      tryReconnection()
    }
  }

  companion object {
    private const val CUSTOM_PING_MESSAGE = "2"
    private const val CUSTOM_PONG_MESSAGE = "3"

    private const val CLOSE_CODE_PING_TIMEOUT: Short = 3002
    private const val CLOSE_REASON_PING_TIMEOUT = "Ping timeout"

    private fun Frame.Text.isCustomPing(): Boolean {
      return readText() == CUSTOM_PING_MESSAGE
    }
  }
}