package io.dyte.sockrates.client

import io.dyte.sockrates.client.WebSocketReadyState.CLOSED
import io.dyte.sockrates.client.WebSocketReadyState.CLOSING
import io.dyte.sockrates.client.WebSocketReadyState.CONNECTING
import io.dyte.sockrates.client.WebSocketReadyState.OPEN
import io.ktor.client.*
import io.ktor.client.plugins.ResponseException
import io.ktor.client.plugins.websocket.*
import io.ktor.http.*
import io.ktor.websocket.*
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.cancel
import kotlinx.coroutines.isActive
import kotlinx.coroutines.job
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeout
import message.v1.SocketMessage
import okio.ByteString.Companion.toByteString
import kotlin.DeprecationLevel.HIDDEN

/*
All public methods of Sockrates call Ktor's methods directly instead of switching the
coroutine context (Dispatcher), because Ktor's methods are main-safe and take care of
switching threads internally.
*/
/**
 * WebSocket client for the server at [url].
 *
 * Text frames are forwarded as strings to the [WebSocketListener] passed to [connect]. Binary frames are
 * forwarded both as raw bytes and as a [SocketMessage] decoded with `Proto`. Decoded messages are also delivered
 * to the [SockratesEventListener]s registered for their event through [subscribe].
 *
 * NOTE(review): [disconnect] cancels the internal coroutine scope and closes the underlying [HttpClient]. Both are
 * `val`s, so an instance is not expected to be reusable after it has been disconnected.
 */
class Sockrates(
  private val url: String,
  var config: SockratesConfig? = null
) {
  private val wsClient = HttpClient {
    install(WebSockets)
  }

  private var wsSession: DefaultClientWebSocketSession? = null

  var readyState: WebSocketReadyState = CONNECTING
    private set

  private val activeEventListeners: MutableMap<Int, MutableSet<SockratesEventListener>> = HashMap()

  // Owns the background listening coroutine; cancelled by disconnect().
  private val scope: CoroutineScope = CoroutineScope(Dispatchers.Default + Job())

  /**
   * Opens the websocket session and starts listening for incoming frames in the background.
   *
   * Does nothing if the client is already [OPEN] or [CLOSING]. A `ResponseException` or `URLParserException`
   * raised while connecting is reported through [WebSocketListener.onError] instead of being thrown.
   *
   * @param webSocketListener receives open/message/close/error callbacks for this session.
   * @throws CancellationException if the calling coroutine is cancelled while connecting.
   * @throws Exception any other connection failure is logged and rethrown.
   */
  suspend fun connect(webSocketListener: WebSocketListener? = null) {
    if (readyState == OPEN || readyState == CLOSING) return

    try {
      println("Sockrates: connecting..")
      val session = wsClient.webSocketSession(url)
      wsSession = session
      readyState = OPEN
      webSocketListener?.onOpen(this@Sockrates)

      listen(session, webSocketListener)

      println("Sockrates: connected to $url")
    } catch (e: CancellationException) {
      println("Sockrates: error while connecting: ${e.message}")
      throw e // this is a coroutine upstream exception
    } catch (e: ResponseException) {
      println("Sockrates: error while connecting: ${e.response.status.value}; ${e.message}")
      webSocketListener?.onError(this, "Invalid WS URL")
      return
    } catch (e: URLParserException) {
      println("Sockrates: error while connecting: Invalid WS URL ${e.message}")
      webSocketListener?.onError(this, "Invalid WS URL: $url")
      return
    } catch (e: Exception) {
      println("Sockrates: error while connecting: " + e.message)
      throw e
    }
  }

  /**
   * Launches a coroutine in [scope] that reads frames from [wsSession]. It stops when the server sends a Close
   * frame, when reading or handling a frame throws, or when [scope] is cancelled by [disconnect].
   */
  private fun listen(
    wsSession: DefaultClientWebSocketSession,
    webSocketListener: WebSocketListener?
  ): Job = scope.launch {
    try {
      while (isActive) {
        val frame = wsSession.incoming.receive()
        println("Sockrates: Received $frame")
        when (frame) {
          is Frame.Binary -> {
            // readBytes() copies the payload; read it once and reuse it for both callbacks.
            val bytes = frame.readBytes()
            webSocketListener?.onMessage(this@Sockrates, bytes)
            val socketMessage = decode(bytes)
            webSocketListener?.onMessage(this@Sockrates, socketMessage)
            notifyEventSubscribers(socketMessage)
          }

          is Frame.Text -> {
            webSocketListener?.onMessage(this@Sockrates, frame.readText())
          }

          is Frame.Close -> {
            // Go through disconnect() so the session, HttpClient and scope are released. Setting readyState to
            // CLOSED directly would make any later disconnect() return early and leak the HttpClient.
            disconnect()
            webSocketListener?.onClosed(this@Sockrates, "Websocket connection closed by server")
            break
          }

          else -> {
            println("Sockrates: other frame received -> $frame")
          }
        }
      }
    } catch (e: Exception) {
      // If our own coroutine is no longer active, disconnect() cancelled the scope. That is not an unexpected
      // close, so propagate the cancellation instead of re-entering disconnect() and reporting an EOF.
      if (e is CancellationException && !isActive) throw e
      println("Sockrates: exception while listening ${e.message}")
      disconnect()
      webSocketListener?.onClosed(this@Sockrates, "Connection closed. Unexpected EOF from server")
    }
  }

  /**
   * Delivers [socketMessage] to every listener subscribed to its event.
   *
   * Iterates over a snapshot of the listener set. Listeners may unsubscribe themselves from inside `onEvent`,
   * as the one created by [requestResponse] does, and removing from the live set while iterating it would throw
   * `ConcurrentModificationException`.
   */
  private fun notifyEventSubscribers(socketMessage: SocketMessage) {
    val eventListeners = activeEventListeners[socketMessage.event]?.toList() ?: return
    for (listener in eventListeners) {
      listener.onEvent(this, socketMessage.event, socketMessage)
    }
  }

  /** Sends [message] as a text frame; does nothing if no session has been opened yet. */
  suspend fun send(message: String) {
    wsSession?.send(Frame.Text(message))
  }

  /** Sends [message] as a single final binary frame; does nothing if no session has been opened yet. */
  suspend fun send(message: ByteArray) {
    wsSession?.send(Frame.Binary(fin = true, data = message))
  }

  /** Builds a [SocketMessage] from the given fields, encodes it with `Proto` and sends it as a binary frame. */
  suspend fun send(event: Int, messageId: String? = null, payload: ByteArray? = null) {
    val socketMessage = SocketMessage(
      event = event,
      id = messageId,
      payload = payload?.toByteString(),
    )

    val binaryMessage = encode(socketMessage)
    send(binaryMessage)
  }

  /**
   * Sends a request for [event] and suspends until a message for the same event with a matching id arrives.
   *
   * @param messageId id used to correlate the response; a fixed placeholder id is used when null.
   * @return the response; or null if none arrives within [REQUEST_RESPONSE_TIMEOUT_IN_MILLIS] ms, if waiting
   *   fails, or if the client is disconnected while waiting.
   * @throws CancellationException if the calling coroutine itself is cancelled.
   * @throws Exception failures raised while sending the request propagate to the caller.
   */
  suspend fun requestResponse(
    event: Int,
    messageId: String? = null,
    payload: ByteArray? = null
  ): SocketMessage? {
    // Parented to scope's job, so cancelling the scope in disconnect() also cancels a pending wait.
    val completableDeferred = CompletableDeferred<SocketMessage>(scope.coroutineContext.job)

    val requestId = messageId ?: "abcd" // TODO: replace abcd with a generate method
    val responseListener = object : SockratesEventListener {
      override fun onEvent(client: Sockrates?, event: Int, message: SocketMessage) {
        if (requestId == message.id) {
          unsubscribe(event, this)
          completableDeferred.complete(message)
        }
      }
    }

    subscribe(event, responseListener)
    try {
      send(event, requestId, payload)
      return try {
        withTimeout(REQUEST_RESPONSE_TIMEOUT_IN_MILLIS) { completableDeferred.await() }
      } catch (e: TimeoutCancellationException) {
        println("Sockrates: requestResponse timeout!! returning null")
        null
      } catch (e: CancellationException) {
        // Only cancellation of the deferred itself (scope cancelled by disconnect()) maps to null.
        // Cancellation of the calling coroutine must propagate.
        if (!completableDeferred.isCancelled) throw e
        println("Sockrates: exception in requestResponse ${e.message}")
        null
      } catch (e: Exception) {
        println("Sockrates: exception in requestResponse ${e.message}")
        null
      }
    } finally {
      // Always drop the listener (even if send() threw). Also cancel the deferred, so an unanswered request
      // does not remain an active child of scope's job. Cancelling an already-completed deferred is a no-op.
      unsubscribe(event, responseListener)
      completableDeferred.cancel()
    }
  }

  /** Registers [eventListener] to be called for every incoming decoded message whose event is [event]. */
  fun subscribe(event: Int, eventListener: SockratesEventListener) {
    activeEventListeners.getOrPut(event) { LinkedHashSet() }.add(eventListener)
  }

  /** Removes [eventListener] from the listeners of [event]; does nothing if it was not subscribed. */
  fun unsubscribe(event: Int, eventListener: SockratesEventListener) {
    activeEventListeners[event]?.remove(eventListener)
  }

  // Legacy blocking receive loop; hidden from callers in favour of subscribe()/WebSocketListener.onMessage().
  @Deprecated(
    message = "Use subscribe() to listen to a particular event or override the WebsocketListener.onMessage() to listen to all incoming messages",
    level = HIDDEN
  )
  suspend fun receive(onReceive: (input: SocketMessage) -> Unit) {
    while (true) {
      val frame = wsSession?.incoming?.receive()

      if (frame is Frame.Binary) {
        val socketMessage = Proto.decode(frame.data)
        onReceive(socketMessage)
      }
    }
  }

  /**
   * Tears the connection down. It cancels background listening, sends a NORMAL close to the server, closes the
   * HTTP client and removes all event subscriptions.
   *
   * Does nothing if the client is already [CLOSED], or is [CLOSING] because another call is in progress.
   * The final state is always [CLOSED], even if closing the session throws.
   */
  suspend fun disconnect() {
    if (readyState == CLOSED || readyState == CLOSING) return

    readyState = CLOSING
    println("Sockrates: Closing websocket session...")
    try {
      scope.cancel()
      wsSession?.close(CloseReason(CloseReason.Codes.NORMAL, "Connection closed by client"))
    } finally {
      wsClient.close()
      readyState = CLOSED
      clear()
    }
    println("Sockrates: Websocket session closed")
  }

  /** Removes every event subscription registered through [subscribe]. */
  fun clear() {
    activeEventListeners.clear()
  }

  private fun encode(socketMessage: SocketMessage) = Proto.encode(socketMessage)

  private fun decode(message: ByteArray) = Proto.decode(message)

  companion object {
    private const val REQUEST_RESPONSE_TIMEOUT_IN_MILLIS: Long = 10000L
  }
}