package io.dyte.core.socket.socketservice

import io.dyte.core.observability.DyteLogger
import io.dyte.sockrates.client.*
import io.dyte.sockrates.client.SockratesResult.Failure
import io.dyte.sockrates.client.SockratesResult.Success
import io.dyte.sockrates.client.logger.ExternalLogger
import kotlin.coroutines.cancellation.CancellationException
import kotlin.random.Random
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import message.v1.SocketMessage

/**
 * Contract for the socket-service client: connection lifecycle, fire-and-forget sends,
 * request/response calls, and listener registration for incoming events and connection-state
 * changes.
 */
internal interface ISockratesSocketService {

  /**
   * Opens the connection to the socket-service.
   *
   * NOTE(review): the signature exposes no result; the implementation in this file only logs
   * whether the attempt succeeded or failed.
   */
  suspend fun connect()

  /** Closes the connection to the socket-service. */
  suspend fun disconnect()

  /**
   * Sends a message without waiting for a response.
   *
   * @param event numeric event identifier of the message.
   * @param payload raw message body, or `null` for a message without a body.
   * @param messageId id to attach to the message; when `null` the implementation in this file
   *   generates one from the current peer id plus a random suffix.
   */
  suspend fun send(event: Int, payload: ByteArray?, messageId: String? = null)

  /**
   * Sends a message and suspends until the matching response arrives.
   *
   * @param event numeric event identifier of the request.
   * @param payload raw request body, or `null` for a request without a body.
   * @param messageId id to attach to the request; when `null` the implementation in this file
   *   generates one from the current peer id plus a random suffix.
   * @return the response payload bytes, or `null`. The implementation in this file returns `null`
   *   when the response carries no payload and also when the request fails for a reason other than
   *   a timeout.
   * @throws ResponseTimeoutException when no response is received in time.
   * @throws CancellationException when the calling coroutine is cancelled.
   */
  @Throws(ResponseTimeoutException::class, CancellationException::class)
  suspend fun requestResponse(
    event: Int,
    payload: ByteArray?,
    messageId: String? = null,
  ): ByteArray?

  /** Registers [listener] to be called for incoming messages whose event id equals [event]. */
  fun subscribe(event: Int, listener: SocketServiceEventListener)

  /** Removes a [listener] previously registered for [event] through [subscribe]. */
  fun unsubscribe(event: Int, listener: SocketServiceEventListener)

  /** Registers [listener] to be told about connection-state changes. */
  fun addConnectionStateListener(listener: SocketServiceConnectionStateListener)

  /** Removes a [listener] previously added through [addConnectionStateListener]. */
  fun removeConnectionStateListener(listener: SocketServiceConnectionStateListener)

  /**
   * Releases the state held by the service.
   *
   * NOTE(review): the implementation in this file drops every registered listener and clears the
   * underlying client; confirm whether other implementations are expected to do the same.
   */
  fun clear()

  /**
   * Generates a new socket-service URL with the provided peerId.
   *
   * **NOTE**: This is needed as part of a workaround that makes the socket-service work properly
   * after a reconnection.
   */
  fun refreshUrl(peerId: String)
}

/** Callback for incoming socket-service messages; registered per event id via `subscribe`. */
internal interface SocketServiceEventListener {
  /**
   * Called for each received message of a subscribed event.
   *
   * @param event numeric event identifier of the received message.
   * @param eventId id carried by the received message, if any.
   * @param payload message body as raw bytes, or `null` when the message has no payload.
   */
  fun onEvent(event: Int, eventId: String?, payload: ByteArray?)
}

/** Callback for changes in the state of the underlying WebSocket connection. */
internal interface SocketServiceConnectionStateListener {
  /**
   * Called when the connection moves to a new state.
   *
   * @param newState the state reported by the WebSocket client.
   */
  fun onConnectionStateChanged(newState: WebSocketConnectionState)
}

/**
 * Thrown by `requestResponse` when the socket-service does not answer a request in time.
 *
 * @property message human-readable description of the timeout.
 */
internal class ResponseTimeoutException(override val message: String) : RuntimeException(message)

/**
 * [ISockratesSocketService] implementation that talks to the socket-service through a [Sockrates]
 * WebSocket client.
 *
 * The client is created eagerly, configured to disconnect on ping timeout and to retry a failed
 * connection, with auto-reconnect disabled. [refreshUrl] replaces it with a new client bound to a
 * new peer id. Every suspending call runs on [workContext].
 *
 * Listeners are kept in plain, unsynchronised collections. Dispatch iterates over a snapshot, so a
 * listener may register or unregister listeners from inside its own callback without disturbing
 * the ongoing iteration.
 *
 * NOTE(review): registering listeners from another thread while a dispatch is in progress is still
 * unsynchronised — confirm the threading model of the [Sockrates] callbacks.
 *
 * @param baseDomain domain the socket-edge host name is derived from (`socket-edge.<baseDomain>`).
 * @param peerId peer id placed in the connection URL and used as prefix of generated message ids.
 * @param roomName value of the `roomID` query parameter.
 * @param authToken value of the `authToken` query parameter.
 * @param useHive value of the `useMediaV2` query parameter.
 * @param workContext dispatcher on which client calls are executed.
 */
internal class SockratesSocketService(
  baseDomain: String,
  private var peerId: String,
  private val roomName: String,
  private val authToken: String,
  private val useHive: Boolean,
  private val workContext: CoroutineDispatcher = Dispatchers.Default,
) : ISockratesSocketService {
  // Declaration order matters: `url` is derived from `socketEdgeDomain`, `wsClient` from `url`.
  private val socketEdgeDomain: String = getSocketEdgeDomain(baseDomain)
  private var url: String = buildUrl(peerId)
  private var wsClient: Sockrates = createClient(url)

  /** Forwards the log output of the [Sockrates] client to [DyteLogger]. */
  private object Logger : ExternalLogger {
    override fun debug(message: String) {
      DyteLogger.debug(message)
    }

    override fun error(message: String) {
      DyteLogger.error(message)
    }

    override fun info(message: String) {
      DyteLogger.info(message)
    }

    override fun warn(message: String) {
      DyteLogger.warn(message)
    }
  }

  /** Event id -> listeners subscribed to that event, in subscription order. */
  private val activeEventListeners: MutableMap<Int, MutableSet<SocketServiceEventListener>> =
    HashMap()

  /** Connection-state listeners, in registration order. */
  private val connectionStateListeners: MutableSet<SocketServiceConnectionStateListener> =
    LinkedHashSet()

  /**
   * Connects the current client and wires its callbacks to the registered listeners.
   *
   * The outcome is only logged; a failed connection attempt does not throw.
   */
  override suspend fun connect() {
    DyteLogger.debug("SockratesSocketService::connect::")
    withContext(workContext) {
      val result =
        wsClient.connect(
          object : SockratesWSListener {
            override fun onConnectionStateChanged(
              client: Sockrates,
              newState: WebSocketConnectionState,
            ) {
              super.onConnectionStateChanged(client, newState)
              try {
                notifyConnectionStateListeners(newState)
              } catch (e: IllegalArgumentException) {
                // Best effort: a listener failure must not break the socket callback, but it
                // should not vanish without a trace either.
                DyteLogger.warn(
                  "SockratesSocketService::onConnectionStateChanged::listener failed ${e.message}"
                )
              }
            }

            override fun onMessage(client: Sockrates, message: SocketMessage) {
              super.onMessage(client, message)
              notifyEventSubscribers(message)
            }
          }
        )

      when (result) {
        is Failure -> {
          DyteLogger.error("SocketService::connect::failed ${result.reason}")
        }
        is Success -> {
          DyteLogger.info("SocketService::connect::success")
        }
      }
    }
  }

  /** Disconnects the current client. */
  override suspend fun disconnect() {
    DyteLogger.info("SockratesSocketService::disconnect::")
    withContext(workContext) { wsClient.disconnect() }
  }

  /** Drops every registered listener and clears the current client. */
  override fun clear() {
    DyteLogger.info("SockratesSocketService::clear::")
    activeEventListeners.clear()
    connectionStateListeners.clear()
    wsClient.clear()
  }

  /**
   * Sends a message without waiting for a response. A failure is logged, not thrown.
   *
   * @param event numeric event identifier of the message.
   * @param payload raw message body, or `null`.
   * @param messageId id of the message; generated from the current peer id when `null`.
   */
  override suspend fun send(event: Int, payload: ByteArray?, messageId: String?) {
    withContext(workContext) {
      val result =
        wsClient.send(
          event = event,
          messageId = messageId ?: generateMessageId(peerId),
          payload = payload,
        )

      when (result) {
        is Failure -> {
          DyteLogger.error("SocketService::send event $event::failed ${result.reason}")
        }
        is Success -> {
          // no-op
        }
      }
    }
  }

  /**
   * Sends a request and suspends until the response arrives.
   *
   * @param event numeric event identifier of the request.
   * @param payload raw request body, or `null`.
   * @param messageId id of the request; generated from the current peer id when `null`.
   * @return the response payload bytes; `null` when the response has no payload, when the socket
   *   is not connected, or when the request fails for any other non-timeout reason.
   * @throws ResponseTimeoutException when the client reports a response timeout.
   */
  override suspend fun requestResponse(
    event: Int,
    payload: ByteArray?,
    messageId: String?,
  ): ByteArray? =
    withContext(workContext) {
      val result =
        wsClient.requestResponse(
          event = event,
          messageId = messageId ?: generateMessageId(peerId),
          payload = payload,
        )
      when (result) {
        is Success -> result.value.payload?.toByteArray()
        is Failure -> {
          val reason = result.reason
          DyteLogger.error("SocketService::requestResponse event $event::failure $reason")
          when (reason) {
            is RequestResponseFailureReason.ResponseTimeout ->
              throw ResponseTimeoutException(
                "SocketService response timeout after ${reason.timeoutInMillis}ms"
              )
            // Non-timeout failures are reported to the caller as "no response".
            is RequestResponseFailureReason.SocketNotConnected,
            is RequestResponseFailureReason.Other -> null
          }
        }
      }
    }

  /** Registers [listener] for [event]; registering the same listener twice has no extra effect. */
  override fun subscribe(event: Int, listener: SocketServiceEventListener) {
    activeEventListeners
      .getOrPut(event) { LinkedHashSet<SocketServiceEventListener>() }
      .add(listener)
  }

  /** Removes [listener] from the listeners of [event], if it is registered. */
  override fun unsubscribe(event: Int, listener: SocketServiceEventListener) {
    activeEventListeners[event]?.remove(listener)
  }

  /** Registers [listener] for connection-state changes. */
  override fun addConnectionStateListener(listener: SocketServiceConnectionStateListener) {
    connectionStateListeners.add(listener)
  }

  /** Removes [listener] from the connection-state listeners, if it is registered. */
  override fun removeConnectionStateListener(listener: SocketServiceConnectionStateListener) {
    connectionStateListeners.remove(listener)
  }

  /**
   * Switches to [peerId]: rebuilds the URL and replaces the client with a fresh, not yet connected
   * one. Registered listeners are kept.
   *
   * NOTE(review): the previous client is neither disconnected nor cleared here — confirm callers
   * do that before refreshing.
   */
  override fun refreshUrl(peerId: String) {
    DyteLogger.info("SockratesSocketService::refreshUrl::$peerId")
    this.peerId = peerId
    url = buildUrl(peerId)
    wsClient = createClient(url)
  }

  /** Builds the connection URL for [forPeerId] from this service's room, token and flags. */
  private fun buildUrl(forPeerId: String): String =
    createSocketServiceUrl(socketEdgeDomain, forPeerId, roomName, authToken, useHive, true)

  /** Single place where the [Sockrates] client and its configuration are created. */
  private fun createClient(socketUrl: String): Sockrates =
    Sockrates(
      socketUrl,
      Logger,
      config =
        SockratesConfiguration(
          disconnectOnPingTimeout = true,
          autoReconnect = false,
          retryConnectionOnFailure = true,
        ),
    )

  /** Delivers [socketMessage] to every listener subscribed to its event. */
  private fun notifyEventSubscribers(socketMessage: SocketMessage) {
    // Snapshot first: a listener may subscribe/unsubscribe from inside onEvent.
    val listeners = activeEventListeners[socketMessage.event]?.toList().orEmpty()
    for (listener in listeners) {
      listener.onEvent(
        socketMessage.event,
        socketMessage.id,
        // Converted per listener so that each one receives its own array.
        socketMessage.payload?.toByteArray(),
      )
    }
  }

  /** Delivers [connectionState] to every registered connection-state listener. */
  private fun notifyConnectionStateListeners(connectionState: WebSocketConnectionState) {
    // Snapshot first: a listener may add/remove listeners from inside its callback.
    connectionStateListeners.toList().forEach { it.onConnectionStateChanged(connectionState) }
  }

  companion object {
    /** Characters a generated message-id suffix is drawn from (base-36 digits). */
    private const val MESSAGE_ID_ALPHABET = "0123456789abcdefghijklmnopqrstuvwxyz"

    /** Length of the random suffix appended to generated message ids. */
    private const val MESSAGE_ID_SUFFIX_LENGTH = 5

    /**
     * Assembles the socket-service WebSocket URL.
     *
     * The values are interpolated verbatim; no percent-encoding is applied.
     */
    private fun createSocketServiceUrl(
      socketEdgeDomain: String,
      peerId: String,
      roomName: String,
      authToken: String,
      useHive: Boolean,
      shouldPingPong: Boolean,
    ): String {
      return "wss://${socketEdgeDomain}/ws?roomID=${roomName}&peerID=${peerId}&authToken=${authToken}&useMediaV2=${useHive}&ping=${shouldPingPong}"
    }

    /** Returns `<peerId>-<random suffix>`. */
    private fun generateMessageId(peerId: String): String {
      val randomSuffix = generateRandomString(MESSAGE_ID_SUFFIX_LENGTH)
      return "${peerId}-${randomSuffix}"
    }

    /*
    This is not finalised yet. Currently using the most straightforward method to
    generate a random alphanumeric string.
    Will need to come up with a more efficient method or implement Base36 encoding in
    Kotlin.
     */
    /** Returns [idLength] random characters from [MESSAGE_ID_ALPHABET]; empty if not positive. */
    private fun generateRandomString(idLength: Int = MESSAGE_ID_SUFFIX_LENGTH): String =
      buildString {
        repeat(idLength) {
          append(MESSAGE_ID_ALPHABET[Random.nextInt(until = MESSAGE_ID_ALPHABET.length)])
        }
      }

    /** Returns the socket-edge host name for [baseDomain]. */
    private fun getSocketEdgeDomain(baseDomain: String): String {
      // For both Dyte and non-Dyte domains we just need to append baseDomain
      return "socket-edge.${baseDomain}"
    }
  }
}
