package io.dyte.sockrates.client

import io.dyte.sockrates.client.WebSocketReadyState.CLOSED
import io.dyte.sockrates.client.WebSocketReadyState.CLOSING
import io.dyte.sockrates.client.WebSocketReadyState.CONNECTING
import io.dyte.sockrates.client.WebSocketReadyState.OPEN
import io.dyte.sockrates.client.backoff.Backoff
import io.dyte.sockrates.client.backoff.ExponentialBackoff
import io.ktor.client.*
import io.ktor.client.plugins.ResponseException as KtorResponseException
import io.ktor.client.plugins.websocket.*
import io.ktor.http.*
import io.ktor.websocket.*
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.awaitCancellation
import kotlinx.coroutines.cancelChildren
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.delay
import kotlinx.coroutines.isActive
import kotlinx.coroutines.job
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeout
import message.v1.SocketMessage
import okio.ByteString.Companion.toByteString
import kotlin.DeprecationLevel.HIDDEN

/*
Most public methods of Sockrates call Ktor's methods directly instead of switching context
(Dispatcher), because Ktor's methods are main-safe and switch threads internally.
The exception is send(event, messageId, payload), which moves protobuf serialization onto
Sockrates's own dispatcher before sending.
*/
/**
 * WebSocket client built on Ktor's WebSockets plugin that exchanges protobuf-encoded
 * [SocketMessage]s (via `Proto.encode` / `Proto.decode`) with the server at [url].
 *
 * On top of raw text/binary messaging it offers:
 * - per-event subscriptions ([subscribe] / [unsubscribe]),
 * - request/response correlation by message id ([requestResponse]),
 * - an application-level ping/pong: a text frame "2" is answered with "3". When
 *   `config.disconnectOnPingTimeout` is set, not receiving a ping within
 *   `config.pingTimeoutInMillis` closes the session and starts a [backoff]-driven reconnection.
 *
 * Message, close, connection-state and event-subscriber callbacks are invoked on
 * [Dispatchers.Main]. `onOpen` / `onFailure` are invoked in the coroutine that called [connect].
 *
 * NOTE(review): the mutable state below (ready state, listener map, jobs) is not guarded by a
 * lock, yet it is touched from Dispatchers.Default, Dispatchers.Main and the callers' contexts.
 * Confirm that callers serialize access.
 */
class Sockrates(
  private val url: String,
  private val config: SockratesConfiguration = SockratesConfiguration(),
  private val backoff: Backoff = ExponentialBackoff(addJitter = true)
) {
  private val wsClient = HttpClient {
    install(WebSockets)
  }

  /** Listener passed to the latest [connect] call; also reused by automatic reconnection. */
  private var wsListener: SockratesWSListener? = null

  /** Session of the latest successful (re)connection; null until the first one succeeds. */
  private var wsSession: DefaultClientWebSocketSession? = null

  private var readyState: WebSocketReadyState = CONNECTING

  /** Last state reported through [SockratesWSListener.onConnectionStateChanged]. */
  private var connectionState: WebSocketConnectionState = WebSocketConnectionState.Disconnected(
    CloseReason.Codes.NORMAL.code.toInt(),
    "Client has not started connection yet"
  )

  /** True while a [reconnect] loop is running; prevents a second loop from starting. */
  private var reconnecting: Boolean = false

  /** Whether the WebSocket is currently open. */
  val connected: Boolean
    get() = readyState == OPEN

  /** Event id -> listeners subscribed to that event, kept in subscription order. */
  private val activeEventListeners: MutableMap<Int, MutableSet<SockratesEventListener>> = HashMap()

  // SupervisorJob: a failing child (read loop, timeout, reconnection) does not cancel its siblings.
  private val scope: CoroutineScope = CoroutineScope(Dispatchers.Default + SupervisorJob())

  private var readMessagesJob: Job? = null
  private var pingTimeoutJob: Job? = null
  private var reconnectionJob: Job? = null

  /**
   * Opens the WebSocket session to [url] and starts listening for incoming frames.
   *
   * Does nothing if the socket is already open or is closing. Connection failures are not
   * thrown; they are reported through [SockratesWSListener.onFailure]. Cancellation of the
   * calling coroutine is rethrown.
   *
   * @param wsListener receives connection, message and state callbacks. It replaces any listener
   *   from a previous call and is reused by automatic reconnection.
   */
  suspend fun connect(wsListener: SockratesWSListener? = null) {
    if (readyState == OPEN || readyState == CLOSING) return

    this.wsListener = wsListener

    try {
      println("Sockrates: connecting..")
      wsSession = connectInternal()

      wsListener?.onOpen(this@Sockrates)
      setAndEmitConnectionState(WebSocketConnectionState.Connected)

      clearPingTimeout()
      if (shouldSetPingTimeout(config)) {
        setPingTimeout(config.pingTimeoutInMillis)
      }

      wsSession?.let { session ->
        readMessagesJob = readIncomingMessages(session, wsListener)
        println("Sockrates: connected to ${session.call}")
      }
    } catch (e: CancellationException) {
      println("Sockrates: error while connecting: ${e.message}")
      throw e // this is a coroutine upstream exception
    } catch (e: KtorResponseException) {
      println("Sockrates: error while connecting: ${e.response.status.value}; ${e.message}")
      wsListener?.onFailure(
        this,
        ConnectException(e.message ?: "Could not connect to $url"),
        "${e.message} -> ${e.response.status.value}"
      )
    } catch (e: URLParserException) {
      println("Sockrates: error while connecting: Invalid WS URL ${e.message}")
      wsListener?.onFailure(
        this,
        UrlParseException(e.message ?: "Invalid websocket url: $url"),
        "Invalid websocket URL: $url"
      )
    } catch (e: Exception) {
      println("Sockrates: error while connecting: " + e.message)
      wsListener?.onFailure(
        this,
        ConnectException(e.message ?: "Could not connect to $url"),
        e.message ?: "Could not connect to $url"
      )
    }
  }

  /** Opens a Ktor WebSocket session to [url] and marks the client [OPEN] on success. */
  private suspend fun connectInternal(): DefaultClientWebSocketSession {
    val session = wsClient.webSocketSession(url)
    readyState = OPEN
    return session
  }

  /**
   * Launches, in [scope], a loop that dispatches every incoming frame of [wsSession].
   *
   * The loop ends on a Close frame, when the receive channel is closed, on cancellation, or on
   * any other exception (which is only logged).
   *
   * @return the launched job, so callers can cancel the read loop.
   */
  private fun readIncomingMessages(
    wsSession: DefaultClientWebSocketSession,
    wsListener: SockratesWSListener?
  ) = scope.launch {
    try {
      while (isActive) {
        val frame = wsSession.incoming.receive()
        println("Sockrates: Received $frame")
        when (frame) {
          is Frame.Binary -> {
            handleBinaryFrame(frame, wsListener)
          }

          is Frame.Text -> {
            handleTextFrame(frame, wsListener)
          }

          is Frame.Close -> {
            handleCloseFrame(frame, wsListener)
            break
          }

          else -> {
            println("Sockrates: other frame received -> $frame")
          }
        }
      }
    } catch (e: ClosedReceiveChannelException) {
      // The receive channel of Ktor was closed due to disconnection. Can be from either side.
      println("Sockrates: receive channel closed")
      // Changing Sockrates state if connection was closed by Server.
      readyState = CLOSED
      setAndEmitConnectionState(
        WebSocketConnectionState.Disconnected(
          CloseReason.Codes.INTERNAL_ERROR.code.toInt(),
          "Unexpected EOF from server"
        )
      )
    } catch (e: CancellationException) {
      // The listen job was cancelled
      println("Sockrates: receive message job cancelled")
    } catch (e: Exception) {
      /*
      * note: This happens when we encounter any error while listening to incoming messages.
      * TODO: Check if we need to call disconnect() here.
      * */
      println("Sockrates: exception while listening -> ${e.message}")
      // disconnect()
      // wsListener?.onClosed(this@Sockrates, 1, "Connection closed. Unexpected EOF from server")
    }
  }

  /** Sends [message] as a text frame. No-op if no session has been established yet. */
  suspend fun send(message: String) {
    wsSession?.send(Frame.Text(message))
  }

  /** Sends [message] as a single (fin) binary frame. No-op if no session has been established yet. */
  suspend fun send(message: ByteArray) {
    wsSession?.send(Frame.Binary(fin = true, data = message))
  }

  /**
   * Wraps [event], [messageId] and [payload] in a [SocketMessage], encodes it and sends it as a
   * binary frame.
   */
  suspend fun send(event: Int, messageId: String? = null, payload: ByteArray? = null) {
    // Serialize on Sockrates's dispatcher. The scope's Job is deliberately dropped: passing it to
    // withContext would detach this work from the caller's job and break structured concurrency.
    withContext(scope.coroutineContext.minusKey(Job)) {
      val socketMessage = SocketMessage(
        event = event,
        id = messageId,
        payload = payload?.toByteString(),
      )

      val binaryMessage = encode(socketMessage)
      send(binaryMessage)
    }
  }

  /**
   * Sends [event] with [payload] and suspends until a message for the same event carrying the
   * same id arrives, or until `config.responseTimeoutInMillis` elapses.
   *
   * @param messageId id used to correlate the response. Defaults to a fixed placeholder (see TODO).
   * @return the matching response message.
   * @throws ResponseTimeoutException if no response arrives in time.
   * @throws ResponseException if sending fails or the wait is aborted from inside Sockrates
   *   (e.g. by [disconnect]).
   * @throws CancellationException if the calling coroutine is cancelled.
   */
  @Throws(
    ResponseTimeoutException::class,
    ResponseException::class,
    kotlin.coroutines.cancellation.CancellationException::class
  )
  suspend fun requestResponse(
    event: Int,
    messageId: String? = null,
    payload: ByteArray? = null
  ): SocketMessage {
    // Child of the scope's job, so disconnect() (cancelChildren) also aborts a pending wait.
    val completableDeferred = CompletableDeferred<SocketMessage>(scope.coroutineContext.job)

    val requestId = messageId ?: "abcd" // TODO: replace abcd with a generate method
    val responseListener = object : SockratesEventListener {
      override fun onEvent(client: Sockrates?, event: Int, message: SocketMessage) {
        if (requestId == message.id) {
          unsubscribe(event, this)
          completableDeferred.complete(message)
        }
      }
    }

    // Subscribe before sending so that a fast response cannot be missed.
    subscribe(event, responseListener)
    try {
      // Inside the try so a failed send still unsubscribes the listener in `finally`.
      send(event, requestId, payload)
      return withTimeout(config.responseTimeoutInMillis) {
        completableDeferred.await()
      }
    } catch (e: TimeoutCancellationException) {
      println("Sockrates: requestResponse timeout!!")
      throw ResponseTimeoutException("Timed out waiting for response for ${config.responseTimeoutInMillis}ms")
    } catch (e: CancellationException) {
      // The caller was cancelled: propagate it instead of masking it as a ResponseException.
      if (!kotlin.coroutines.coroutineContext.isActive) throw e
      // Otherwise only the pending deferred was cancelled from inside Sockrates (e.g. disconnect()).
      println("Sockrates: exception in requestResponse -> ${e.message}")
      throw ResponseException(e.message ?: "Could not request a response for event: $event")
    } catch (e: Exception) {
      println("Sockrates: exception in requestResponse -> ${e.message}")
      throw ResponseException(e.message ?: "Could not request a response for event: $event")
    } finally {
      unsubscribe(event, responseListener)
      // If no response arrived (timeout/failure), detach the deferred from the scope's job.
      // Otherwise it would stay registered as a child of the scope forever. No-op once completed.
      completableDeferred.cancel()
    }
  }

  /** Registers [eventListener] for messages whose `event` equals [event]. Adding the same listener twice has no effect. */
  fun subscribe(event: Int, eventListener: SockratesEventListener) {
    activeEventListeners.getOrPut(event) { LinkedHashSet() }.add(eventListener)
  }

  /** Removes [eventListener] from [event]'s listeners; drops the entry once no listener remains. */
  fun unsubscribe(event: Int, eventListener: SockratesEventListener) {
    val eventListeners = activeEventListeners[event] ?: return
    eventListeners.remove(eventListener)
    if (eventListeners.isEmpty()) {
      activeEventListeners.remove(event)
    }
  }

  @Deprecated(
    message = "Use subscribe() to listen to a particular event or override the SockratesWSListener.onMessage() to listen to all incoming messages",
    level = HIDDEN
  )
  suspend fun receive(onReceive: (input: SocketMessage) -> Unit) {
    while (true) {
      val frame = wsSession?.incoming?.receive()

      if (frame is Frame.Binary) {
        val socketMessage = Proto.decode(frame.data)
        onReceive(socketMessage)
      }
    }
  }

  /**
   * Closes the session with a NORMAL close reason, cancels all background work (read loop, ping
   * timeout, reconnection, pending [requestResponse] waits), emits a Disconnected state and
   * removes all event subscriptions.
   *
   * Returns immediately if already closed and no reconnection is in progress.
   */
  suspend fun disconnect() {
    if (readyState == CLOSED && !reconnecting) return

    readyState = CLOSING
    println("Sockrates: Closing websocket session...")
    scope.coroutineContext.cancelChildren()
    // A cancelled reconnect() loop resets this flag only when it next resumes. Reset it now so a
    // later reconnection is not skipped.
    reconnecting = false
    val closeReason = CloseReason(CloseReason.Codes.NORMAL, "Connection closed by client")
    wsSession?.close(closeReason)

    /**
     * note: wsClient.close() closes the Ktor engine. In case of Sockrates it is Websockets plugin
     * which we install while initialising the wsClient.
     * After closing, we need to install the plugin again in order to establish a WS session.
     * That's why no need to close it here.
     * A good place to call this would be inside the clear() method.
     */
    // wsClient.close()
    readyState = CLOSED
    setAndEmitConnectionState(
      WebSocketConnectionState.Disconnected(
        closeReason.code.toInt(),
        closeReason.message
      )
    )
    clear()
    println("Sockrates: Websocket session closed")
  }

  /** Removes all event subscriptions. */
  fun clear() {
    activeEventListeners.clear()
  }

  /** Decodes a binary frame into a [SocketMessage] and notifies the listener and the event subscribers on Main. */
  private suspend fun handleBinaryFrame(frame: Frame.Binary, wsListener: SockratesWSListener?) {
    // readBytes() copies the payload; read it once and reuse it for decoding and the raw callback.
    val bytes = frame.readBytes()
    val socketMessage = decode(bytes)
    runOnMain {
      wsListener?.onMessage(this@Sockrates, WebSocketMessage.Binary(bytes))
      wsListener?.onMessage(this@Sockrates, socketMessage)
      notifyEventSubscribers(socketMessage)
    }
  }

  /** Answers custom pings; forwards any other text frame to the listener on Main. */
  private suspend fun handleTextFrame(frame: Frame.Text, wsListener: SockratesWSListener?) {
    if (frame.isCustomPing()) {
      println("Sockrates: Received custom ping")
      handleCustomPing()
    } else {
      runOnMain {
        wsListener?.onMessage(this@Sockrates, WebSocketMessage.Text(frame.readText()))
      }
    }
  }

  // Todo: Check if a call to disconnect is needed.
  /** Marks the client CLOSED and reports the server's close reason (1001 if none was sent). */
  private suspend fun handleCloseFrame(frame: Frame.Close, wsListener: SockratesWSListener?) {
    readyState = CLOSED
    val closeReason = frame.readReason() ?: CloseReason(1001, "Unexpected EOF from server")
    runOnMain {
      wsListener?.onClosed(this@Sockrates, closeReason.code.toInt(), closeReason.message)
    }
    setAndEmitConnectionState(
      WebSocketConnectionState.Disconnected(
        closeReason.code.toInt(),
        closeReason.message
      )
    )
  }

  /** Restarts the ping timeout (if enabled) and replies with a custom pong. */
  private suspend fun handleCustomPing() {
    clearPingTimeout()
    sendCustomPongMessage()
    if (shouldSetPingTimeout(config)) {
      setPingTimeout(config.pingTimeoutInMillis)
    }
  }

  private fun shouldSetPingTimeout(config: SockratesConfiguration): Boolean {
    return config.disconnectOnPingTimeout
  }

  /**
   * Arms a timer that, after [timeoutInMillis] without being cleared, closes the session and starts
   * the reconnection flow after one backoff delay.
   */
  private fun setPingTimeout(timeoutInMillis: Long) {
    try {
      pingTimeoutJob = setTimeout(timeoutInMillis) {
        reconnectionJob = scope.launch {
          disconnectOnPingTimeout()
          backoff.reset() // Resetting the backoff attempts before starting reconnection flow
          delay(backoff.nextBackoffMillis()) // 1st attempt
          reconnect()
        }
      }
    } catch (e: Exception) {
      // Defensive only: setTimeout() just launches a coroutine. The timeout callback runs later
      // inside that coroutine, so its failures are not caught here.
      println("Sockrates: Ping timeout exception -> ${e.message}")
    }
  }

  /** Stops the read loop, closes the session with the ping-timeout close code and emits Disconnected. */
  private suspend fun disconnectOnPingTimeout() {
    if (readyState == CLOSED) return

    readyState = CLOSING
    println("Sockrates: Closing websocket session on Ping timeout...")

    readMessagesJob?.cancel()
    wsSession?.close(CloseReason(CLOSE_CODE_PING_TIMEOUT, CLOSE_REASON_PING_TIMEOUT))

    readyState = CLOSED
    println("Sockrates: Websocket session closed on Ping timeout!")
    setAndEmitConnectionState(
      WebSocketConnectionState.Disconnected(
        CLOSE_CODE_PING_TIMEOUT.toInt(),
        CLOSE_REASON_PING_TIMEOUT
      )
    )
  }

  /**
   * Retries the connection with [backoff] delays until connected or
   * `config.maxReconnectionAttempts` is reached, then emits Reconnected or ReconnectFailed.
   * Returns immediately if already connected or a reconnection is already running.
   */
  private suspend fun reconnect() {
    if (connected || reconnecting) {
      return
    }

    reconnecting = true
    try {
      setAndEmitConnectionState(WebSocketConnectionState.Reconnecting)

      while (shouldTryReconnection()) {
        try {
          reconnectInternal()
        } catch (e: CancellationException) {
          throw e
        } catch (e: Exception) {
          delay(backoff.nextBackoffMillis())
        }
      }
    } finally {
      // Always clear the flag, including on cancellation; otherwise every future reconnect() is skipped.
      reconnecting = false
    }

    if (connected) {
      backoff.reset()
      setAndEmitConnectionState(WebSocketConnectionState.Reconnected)
    } else {
      setAndEmitConnectionState(WebSocketConnectionState.ReconnectFailed)
    }
  }

  private fun shouldTryReconnection(): Boolean {
    return (backoff.attemptsTillNow < config.maxReconnectionAttempts) && !connected
  }

  /** One reconnection attempt: reconnects, re-arms the ping timeout and restarts the read loop; rethrows on failure. */
  private suspend fun reconnectInternal() {
    try {
      setAndEmitConnectionState(WebSocketConnectionState.AttemptingReconnect(backoff.attemptsTillNow))

      wsSession = connectInternal()

      clearPingTimeout()
      if (shouldSetPingTimeout(config)) {
        setPingTimeout(config.pingTimeoutInMillis)
      }

      wsSession?.let { session ->
        readMessagesJob = readIncomingMessages(session, wsListener)
      }
    } catch (e: CancellationException) {
      throw e
    } catch (e: Exception) {
      println("Sockrates: reconnection exception -> ${e.message}")
      setAndEmitConnectionState(WebSocketConnectionState.ReconnectAttemptFailed(backoff.attemptsTillNow))
      throw e
    }
  }

  private fun clearPingTimeout() {
    pingTimeoutJob?.let {
      if (it.isActive) {
        it.cancel()
        println("Sockrates: Ping timeout cleared!")
      }
    }
  }

  private suspend fun sendCustomPongMessage() {
    send(CUSTOM_PONG_MESSAGE)
    println("Sockrates: Sent custom pong")
  }

  /** Invokes every listener subscribed to [socketMessage]'s event. */
  private fun notifyEventSubscribers(socketMessage: SocketMessage) {
    // Iterate over a snapshot: listeners (e.g. requestResponse()'s) unsubscribe themselves inside
    // onEvent, which would otherwise throw ConcurrentModificationException on the live set.
    val eventListeners = activeEventListeners[socketMessage.event]?.toList() ?: return
    for (listener in eventListeners) {
      listener.onEvent(this, socketMessage.event, socketMessage)
    }
  }

  private fun encode(socketMessage: SocketMessage) = Proto.encode(socketMessage)

  private fun decode(message: ByteArray) = Proto.decode(message)

  private suspend fun runOnMain(block: () -> Unit) {
    withContext(Dispatchers.Main) {
      block()
    }
  }

  /** Launches a job in [scope] that calls [onTimeout] after [timeoutInMillis], unless the job is cancelled first. */
  @Throws(CancellationException::class)
  private fun setTimeout(timeoutInMillis: Long, onTimeout: () -> Unit): Job {
    return scope.launch {
      try {
        withTimeout(timeoutInMillis) {
          awaitCancellation()
        }
      } catch (e: TimeoutCancellationException) {
        println("Sockrates: Timeout in setTimeout()")
        onTimeout()
      } catch (e: CancellationException) {
        // The timeout job was cancelled
        println("Sockrates: Timeout cancelled!")
      } catch (e: Exception) {
        // Currently there is no need to process other exception that may happen in timeout logic
        println("Sockrates: setTimeout exception -> $e")
      }
    }
  }

  /** Stores [connectionState] and reports the stored state to the listener on Main. */
  private suspend fun setAndEmitConnectionState(connectionState: WebSocketConnectionState) {
    this.connectionState = connectionState
    runOnMain {
      wsListener?.onConnectionStateChanged(this, this.connectionState)
    }
  }

  companion object {
    private const val CUSTOM_PING_MESSAGE = "2"
    private const val CUSTOM_PONG_MESSAGE = "3"

    private const val CLOSE_CODE_PING_TIMEOUT: Short = 3002
    private const val CLOSE_REASON_PING_TIMEOUT = "Ping timeout"

    private fun Frame.Text.isCustomPing(): Boolean {
      return readText() == CUSTOM_PING_MESSAGE
    }
  }
}