package io.privy.auth.jwt

import kotlin.io.encoding.Base64
import kotlin.io.encoding.ExperimentalEncodingApi
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import me.tatarka.inject.annotations.Inject

/**
 * Parses a JWT string into a [Jwt] holding its header, claims and signature.
 *
 * This is a `fun interface`, so lightweight implementations (e.g. test fakes) can be written as a
 * lambda: `DecodeJwt { token -> ... }`.
 *
 * Implementations are not required to verify the token's signature ([RealDecodeJwt] does not), and
 * may throw [IllegalArgumentException] for malformed input.
 */
public fun interface DecodeJwt {
  /** Decodes [jwtToken] into a [Jwt]. */
  public operator fun invoke(jwtToken: String): Jwt
}

/**
 * Default [DecodeJwt] for compact-serialized JWTs of the form `header.payload.signature`.
 *
 * The header and payload segments are base64url-decoded and parsed as JSON. The signature segment
 * is passed through unchanged. **No signature verification is performed**, so the decoded claims
 * should not be trusted on their own.
 */
@OptIn(ExperimentalEncodingApi::class)
@Inject
public class RealDecodeJwt : DecodeJwt {
  // Parser for the decoded header/payload JSON; created on first use.
  private val json by lazy { Json { ignoreUnknownKeys = true } }

  /**
   * Splits [jwtToken] on `.` and decodes its header and claims segments.
   *
   * @return a [Jwt] whose `header` and `claims` contain only the primitive-valued top-level members
   *   of the respective JSON objects (strings, numbers, booleans and JSON `null`). Array- or
   *   object-valued members are omitted. A segment whose JSON is not an object yields an empty map.
   *   The `signature` is the raw, still-encoded third segment.
   * @throws IllegalArgumentException if the token does not have exactly three segments, or a segment
   *   is not valid base64url.
   * @throws kotlinx.serialization.SerializationException if a decoded segment is not valid JSON.
   */
  public override operator fun invoke(jwtToken: String): Jwt {
    val parts = jwtToken.split(".")
    require(parts.size == 3) { "Invalid JWT format: expected 3 parts" }

    val (encodedHeader, encodedClaims, signature) = parts
    return Jwt(
        header = decodeJWTPart(encodedHeader),
        claims = decodeJWTPart(encodedClaims),
        signature = signature,
    )
  }

  /** Decodes one base64url JWT segment and returns its primitive-valued top-level JSON members. */
  private fun decodeJWTPart(base64UrlPart: String): Map<String, JsonPrimitive> {
    val jsonElement = json.parseToJsonElement(decodeBase64UrlToString(base64UrlPart))

    // Only a JSON object has named members; any other top-level value produces no entries.
    if (jsonElement !is JsonObject) return emptyMap()

    // Keep primitives (JsonNull is a JsonPrimitive, so explicit nulls are kept too). Drop arrays and
    // nested objects. Insertion order of the original object is preserved.
    return buildMap<String, JsonPrimitive> {
      for ((key, value) in jsonElement) {
        if (value is JsonPrimitive) put(key, value)
      }
    }
  }

  /** Base64url-decodes [segment] and interprets the resulting bytes as UTF-8 text. */
  private fun decodeBase64UrlToString(segment: String): String {
    // JWT segments omit trailing '=' padding, but the stdlib decoder expects padded input. Restore it
    // up to the next multiple of 4. A length of 1 mod 4 is invalid and is rejected by decode().
    val padded = segment.padEnd(segment.length + (4 - segment.length % 4) % 4, '=')
    return Base64.UrlSafe.decode(padded).decodeToString()
  }
}
