﻿package pragma.protoc.plugin.custom
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.ObjectMapper
import com.github.mustachejava.DefaultMustacheFactory
import com.google.protobuf.DescriptorProtos
import com.google.protobuf.compiler.PluginProtos
import java.io.StringWriter
import java.nio.file.Path
import pragma.PragmaOptions

/**
 * Derives a namespace from the last dot-separated segment of the file's proto package.
 *
 * The segment is lowercased and then its first character is uppercased, e.g. `pragma.account` = `Account`.
 * When the last (or only) segment is `pragma` (compared case-insensitively) the empty string is returned.
 */
fun namespaceFromPackage(file: DescriptorProtos.FileDescriptorProto): String {
    // substringAfterLast yields the whole package when it contains no '.'.
    val lastSegment = file.`package`.substringAfterLast('.').lowercase()
    if (lastSegment == "pragma") return ""
    return lastSegment.replaceFirstChar { it.uppercase() }
}

/**
 * Returns the backend namespace in the /v1/types definition. e.g. accountRpc.proto = AccountRpc
 *
 * Everything up to and including the last '/' is dropped, the first remaining character is uppercased,
 * and a trailing ".proto" is removed.
 */
fun backendNamespace(filename: String): String {
    val baseName = filename.substringAfterLast('/')
    val capitalized = baseName.replaceFirstChar { it.uppercase() }
    return capitalized.removeSuffix(".proto")
}
/** Descriptor overload: the backend namespace computed from the file's `name`. */
fun backendNamespace(file: DescriptorProtos.FileDescriptorProto): String {
    return backendNamespace(file.name)
}

/**
 * Returns the backend type name used in /v1/types: the file's backend namespace and [name] joined by a '.'.
 * e.g. AccountRpc.proto, LoginV1Request = AccountRpc.LoginV1Request
 */
fun backendTypeName(file: DescriptorProtos.FileDescriptorProto, name: String): String {
    val namespace = backendNamespace(file)
    return "$namespace.$name"
}

/**
 * Combines path and message name into a fully qualified path.
 * e.g.
 * packagePath: pragma.account
 * messageName: LoginV1Request
 * result: .pragma.account.LoginV1Request
 *
 * A [packagePath] that already starts with '.' (for example the fully qualified name of an enclosing
 * message) is used as-is. An empty [packagePath] (a proto file without a `package` statement) yields
 * `.messageName` rather than a name with a doubled leading dot.
 *
 * @param packagePath proto package, or an already-qualified parent path; may be empty.
 * @param messageName simple name of the message or enum.
 * @return the name with exactly one leading '.'-prefixed path in front of [messageName].
 */
fun fullyQualifiedTypeName(packagePath: String, messageName: String): String = when {
    // No package: the type lives at the root, so only the leading '.' is needed.
    packagePath.isEmpty() -> ".$messageName"
    // Note the '.' at the beginning is a quirk of how proto plugin processes the package name.
    packagePath.startsWith('.') -> "$packagePath.$messageName"
    else -> ".$packagePath.$messageName"
}

/**
 * Converts a snake_case string to PascalCase, e.g. `login_v1_request` = `LoginV1Request`.
 *
 * Each underscore together with the character that follows it is replaced by that character uppercased,
 * then the first character of the result is uppercased. An underscore with no following character
 * (a trailing underscore) is left in place.
 */
fun String.snakeToPascal(): String {
    val joined = Regex("_.").replace(this) { match ->
        // Strip every underscore in the two-character match, then uppercase what is left.
        match.value.filterNot { it == '_' }.uppercase()
    }
    return joined.replaceFirstChar { it.uppercase() }
}

/**
 * Converts a snake_case string to kebab-case, e.g. `login_v1_request` = `login-v1-request`.
 *
 * Only underscores that are followed by another character are rewritten to '-'; a trailing underscore
 * is left unchanged. Letter case is not modified.
 */
fun String.snakeToKebab(): String =
    Regex("_.").replace(this) { match -> match.value.replace('_', '-') }

/**
 * Converts a PascalCase string to kebab-case, e.g. `AccountRpc` = `account-rpc`.
 *
 * A '-' is inserted before every uppercase ASCII letter that directly follows an ASCII letter, and the
 * whole result is lowercased. Uppercase letters that follow a digit or other non-letter get no dash,
 * e.g. `LoginV1Request` = `login-v1request`.
 */
fun String.pascalToKebab(): String {
    val dashed = Regex("(?<=[a-zA-Z])[A-Z]").replace(this) { match -> "-" + match.value }
    return dashed.lowercase()
}

/**
 * Returns the `oneofIndex` of every field in this message that is flagged `proto3Optional`,
 * in field declaration order.
 */
fun DescriptorProtos.DescriptorProto.optionalOneOfIndexes(): List<Int> =
    fieldList.mapNotNull { field -> if (field.proto3Optional) field.oneofIndex else null }

/** True when this field's `typeName` contains "Fixed128", compared case-insensitively. */
fun DescriptorProtos.FieldDescriptorProto.isFixed128(): Boolean {
    return typeName.contains("Fixed128", ignoreCase = true)
}

/**
 * Base class for SDK generator plugins. Declares the context types handed to templates and helpers for
 * rendering templates, selecting SDK API messages, and naming output files.
 */
abstract class SdkGenerator : ProtocPlugin() {
    // The following types are used to fill out context maps for compiling templates.

    /** Pairs a backend type name with its native type name. */
    protected data class BackendTypeTemplateContext(
        val backendName: String,
        val nativeType: String,
    )

    /** Template context for a single field of a message. */
    data class MessageFieldTemplateContext(
        val nativeFieldType: String,
        val name: String,
        val defaultValue: String,
        val annotations: String,
        val isMap: Boolean,
    )

    /** Template context for a message; open so generators can extend it. */
    open class MessageTemplateContext(
        val nativeType: String,
        val fields: List<MessageFieldTemplateContext>,
        val oneOfTypes: List<String>,
    )

    /** Template context for a single field inside a oneof. */
    data class OneOfFieldTemplateContext(
        val fieldType: String,
        val fieldName: String,
        val fieldNameLower: String,
        val fieldJsonType: String,
        val isStructOrEnum: Boolean,
        val isPrimitive: Boolean,
    )

    /** Template context for a oneof; open so generators can extend it. */
    open class OneOfTemplateContext(
        val oneOfType: String,
        val oneOfEnum: String,
        val fields: List<OneOfFieldTemplateContext>,
        val oneOfParentType: String,
        val oneOfParentFieldName: String,
        val isOptional: Boolean,
    )

    /** Template context for a single enum value. */
    data class EnumValueTemplateContext(
        val name: String,
        val value: Int,
    )

    /** Template context for an enum; open so generators can extend it. */
    open class EnumTemplateContext(
        val nativeType: String,
        val enumValues: List<EnumValueTemplateContext>,
    )

    /** The kinds of protobuf definitions a native type name can be generated for. */
    protected enum class ProtobufType {
        Enum,
        Message,
    }

    /**
     * Compiles the template at [templateFilepath] with a new [DefaultMustacheFactory], executes it against
     * [context], and returns the rendered text.
     */
    fun compileTemplate(templateFilepath: String, context: Any): String {
        val output = StringWriter()
        DefaultMustacheFactory().compile(templateFilepath).execute(output, context)
        return output.toString()
    }

    /**
     * Is this message a request, response, or notification API type?
     *
     * Messages without external visibility are never API types. Otherwise the `pragmaMessageType` option
     * decides: REQUEST, RESPONSE and NOTIFICATION are API types; INTERNAL and UNRECOGNIZED are not. Any
     * other value reaches the `else` branch and raises an [IllegalStateException].
     *
     * Note that due to proto-filter plugin, everything parsed by this plugin will be pragma_visibility_public,
     * but will not necessarily have the option specified.
     */
    protected fun isSdkApiType(message: DescriptorProtos.DescriptorProto): Boolean {
        if (!hasExternalVisibility(message)) {
            return false
        }
        val messageType = message.options.getExtension(PragmaOptions.pragmaMessageType)
        return when (messageType) {
            PragmaOptions.PragmaMessageType.INTERNAL,
            PragmaOptions.PragmaMessageType.UNRECOGNIZED -> false
            PragmaOptions.PragmaMessageType.REQUEST,
            PragmaOptions.PragmaMessageType.RESPONSE,
            PragmaOptions.PragmaMessageType.NOTIFICATION -> true
            else -> error("SdkGenerator.isSdkApiType -- null PragmaOptions type. This shouldn't happen because proto should default the field.")
        }
    }

    /** True when the message's `externalVisibility` option equals `ExternalVisibility.SHOW`. */
    protected fun hasExternalVisibility(message: DescriptorProtos.DescriptorProto): Boolean {
        val visibility = message.options.getExtension(PragmaOptions.externalVisibility)
        return visibility == PragmaOptions.ExternalVisibility.SHOW
    }

    /** Returns the request's proto files that declare at least one top-level SDK API message. */
    protected fun filterSdkApiFiles(request: PluginProtos.CodeGeneratorRequest): List<DescriptorProtos.FileDescriptorProto> {
        return request.protoFileList.filter { protoFile ->
            protoFile.messageTypeList.any { message -> isSdkApiType(message) }
        }
    }

    /**
     * Builds the output file name for [requestFilename]: the backend namespace with "ServicePB" replaced by
     * "Service" and every "Pragma" removed, wrapped as `Pragma<namespace>Dto`.
     */
    open fun outputFilename(requestFilename: String): String {
        val trimmedNamespace = backendNamespace(requestFilename)
            .replace("ServicePB", "Service")
            .replace("Pragma", "")
        return "Pragma${trimmedNamespace}Dto"
    }
}

/**
 * Template context holding the collected info for one dependency.
 *
 * @property filename the file name recorded for the dependency.
 */
data class GeneratedDependencyInfoTemplateContext(
    val filename: String,
)

/** Lookup from a string key to the [GeneratedDependencyInfoTemplateContext] collected for it. */
typealias GeneratedDependencyInfoTemplateContextMap = Map<String, GeneratedDependencyInfoTemplateContext>

/**
 * A vertex of the message dependency graph: a message together with its fully qualified type name.
 *
 * Equality and hashing consider only [fullyQualifiedTypeName]; [message] is ignored, so two entries with
 * the same name are treated as the same vertex.
 */
data class DependencyGraphEntry(val fullyQualifiedTypeName: String, val message: DescriptorProtos.DescriptorProto) {
    // Data classes are final, so the `is` check matches exactly the instances of this class.
    override fun equals(other: Any?): Boolean =
        this === other ||
            (other is DependencyGraphEntry && fullyQualifiedTypeName == other.fullyQualifiedTypeName)

    override fun hashCode(): Int = fullyQualifiedTypeName.hashCode()
}

/** Graph of [DependencyGraphEntry] vertices. */
typealias DependencyGraph = DirectedUniqueGraph<DependencyGraphEntry>

/** Prefix that marks a plugin argument as carrying the comma-separated no-codegen service list. */
const val noCodegenServicesArgPrefix: String = "--noCodegenServices="

/**
 * Shared base for SDK generators: parses plugin arguments, maintains the lookup from fully qualified proto
 * type names to native type names, orders messages by dependency, and maps proto field types to JSON
 * helper type names.
 *
 * @param args plugin arguments. Entries starting with [noCodegenServicesArgPrefix] carry a comma-separated
 * list of service identifiers (stored lowercased). Entries that do not start with "--" are treated as paths
 * to JSON files that [populatePreviousProtos] reads.
 * @param fileWrapper reader used to load the files named in [args].
 */
abstract class MustacheSdkGenerator(
    args: Array<String> = arrayOf(),
    private val fileWrapper: PragmaProtoFileWrapper = PragmaProtoFileWrapper()
) : SdkGenerator() {

    // Lowercased identifiers gathered from every argument that starts with the no-codegen prefix.
    private val noCodegenServices: List<String> = args
        .filter { it.startsWith(noCodegenServicesArgPrefix) }
        .flatMap { it.replace(noCodegenServicesArgPrefix, "").lowercase().split(",") }

    // Every argument that is not a "--" option is treated as a path to a previous-protos JSON file.
    private val previousProtosPaths = args.filter { !it.startsWith("--") }

    // Fully-qualified type name to native type.
    protected var qualifiedToNativeType: MutableMap<String, String> = mutableMapOf()

    // Fully-qualified type name to native type, loaded from the files in previousProtosPaths.
    private var previousProtos: MutableMap<String, String> = mutableMapOf()

    /**
     * Picks the namespace for [file] from its `unrealNamespace` option:
     * the literal string "null" means no namespace (empty string), any other non-empty value is used as-is,
     * and an empty value falls back to [namespaceFromPackage].
     */
    protected fun generateCustomNamespace(file: DescriptorProtos.FileDescriptorProto): String {
        val customNamespace = file.options.getExtension(PragmaOptions.unrealNamespace)
        return when {
            customNamespace == "null" -> ""
            customNamespace.isNotEmpty() -> customNamespace
            else -> namespaceFromPackage(file)
        }
    }

    /** Returns the part of [pkg] before its first '.', with the first character uppercased. */
    protected fun packageRoot(pkg: String) = pkg.substringBefore('.').replaceFirstChar { it.uppercase() }

    /** True when the lowercased concatenation of [pkgRoot] and [serviceName] was listed in the no-codegen argument. */
    protected fun isNoCodegenService(pkgRoot: String, serviceName: String) = noCodegenServices.contains("$pkgRoot$serviceName".lowercase())

    /** Builds the template context for [message] together with the contexts for its oneofs. */
    protected abstract fun createMessageTemplateContexts(namespace: String, message: DescriptorProtos.DescriptorProto): Pair<MessageTemplateContext, List<OneOfTemplateContext>>

    /** Builds the template context for [enum]. */
    protected abstract fun createEnumTemplateContext(namespace: String, enum: DescriptorProtos.EnumDescriptorProto): EnumTemplateContext

    /** Builds the oneof template contexts for [message]. The default implementation returns an empty list. */
    protected open fun createOneOfTemplateContexts(namespace: String, message: DescriptorProtos.DescriptorProto): List<OneOfTemplateContext> {
        return listOf()
    }

    /**
     * Returns a map of 'dependency' from the CodeGeneratorRequest.File.dependencyList to collected info about that dependency.
     *
     * Every file in the request's `protoFileList` is keyed by its name and mapped to its [outputFilename].
     */
    protected fun generateDependencyInfo(request: PluginProtos.CodeGeneratorRequest): GeneratedDependencyInfoTemplateContextMap {
        return request.protoFileList.associateBy(
            { it.name },
            {
                GeneratedDependencyInfoTemplateContext(
                    outputFilename(it.name)
                )
            }
        )
    }

    /**
     * Reads every previous-protos file passed in the constructor arguments, parses it as a JSON object of
     * string keys to string values, and merges the entries into the previous-protos lookup. Later files
     * overwrite earlier entries with the same key.
     */
    fun populatePreviousProtos() {
        // The mapper and the target type do not depend on the file being read, so build them once.
        val mapper = ObjectMapper()
        val type = mapper.typeFactory.constructMapType(Map::class.java, String::class.java, String::class.java)
        previousProtosPaths.forEach { path ->
            val fileContents = fileWrapper.read(path)
            previousProtos.putAll(mapper.readValue(fileContents, type) as Map<String, String>)
        }
    }

    /**
     * Loads the previous protos, then registers a native type name for every enum and message (including
     * nested messages and map entries) declared in [files].
     */
    fun populateQualifiedToNativeTypeMap(files: List<DescriptorProtos.FileDescriptorProto>) {
        populatePreviousProtos()
        // Map fully qualified names to generated types so that we can refer to them correctly across messages.
        // This is only necessary because the native type requires a prefix of the type (e.g., E or F) that we cannot know
        // from the full typename alone.
        for (file in files) {
            val namespace = generateCustomNamespace(file)
            qualifiedToNativeType.putAll(
                file.enumTypeList.map { enum ->
                    fullyQualifiedTypeName(
                        file.`package`,
                        enum.name
                    ) to generatedType(namespace, enum)
                }
            )
            // Since maps can reference other types, we need to have them in our qualifiedToNativeType map before we
            // process the maps, so we collect messages in a two-pass system.
            collectMessageNativeTypes(file.messageTypeList, file.`package`, namespace, collectMaps = false)
            collectMessageNativeTypes(file.messageTypeList, file.`package`, namespace, collectMaps = true)
        }
    }

    /**
     * Returns the top-level messages of [file] in the order produced by the dependency graph's topological sort.
     *
     * A message depends on another top-level message of the same file when it has a message-typed field of
     * that type, or a nested map-entry type whose "value" field is of that type. Types that are not
     * top-level messages of this file are ignored.
     *
     * @throws IllegalStateException when the sort reports a cycle.
     */
    protected fun getDependencySortedMessages(file: DescriptorProtos.FileDescriptorProto): List<DescriptorProtos.DescriptorProto> {
        // Some languages (e.g. C++) require that types are declared in the order they are used.
        val fullyQualifiedTypeNameToMessage =
            file.messageTypeList.associateBy { fullyQualifiedTypeName(file.`package`, it.name) }
        val graph = DependencyGraph()
        for (message in file.messageTypeList) {
            val messageEntry =
                DependencyGraphEntry(fullyQualifiedTypeName(file.`package`, message.name), message)
            graph.addVertex(messageEntry)
            for (nested in message.nestedTypeList) {
                // Only map entries are inspected here; test the option's value, as collectMessageNativeTypes does.
                if (!nested.options.mapEntry) {
                    continue
                }
                val valueField = nested.fieldList.find { it.name == "value" } ?: continue
                val fieldTypeEntry = fieldDependency(valueField.typeName, fullyQualifiedTypeNameToMessage) ?: continue
                messageEntry.dependsOn(fieldTypeEntry, graph)
            }
            for (field in message.fieldList) {
                if (field.type != DescriptorProtos.FieldDescriptorProto.Type.TYPE_MESSAGE) {
                    continue
                }
                val fieldTypeEntry = fieldDependency(field.typeName, fullyQualifiedTypeNameToMessage) ?: continue
                messageEntry.dependsOn(fieldTypeEntry, graph)
            }
        }
        val dependencySortResult = graph.topologicalSort()
        if (dependencySortResult.isCyclical) {
            error(
                "Messages within a file cannot have cyclical dependencies on each other or themselves, " +
                    "as this is not supported in some languages (C++). Cycle found while parsing these messages:\n" +
                    dependencySortResult.values.joinToString("\n") { it.fullyQualifiedTypeName }
            )
        }
        return dependencySortResult.values.map { it.message }
    }

    /** Records that this entry depends on [other] by adding the edge (other, this) to [parent]. */
    private fun DependencyGraphEntry.dependsOn(other: DependencyGraphEntry, parent: DependencyGraph) {
        parent.addEdge(other, this)
    }

    /** Returns a graph entry for [typeName] when it is present in [messageLookup], otherwise null. */
    private fun fieldDependency(typeName: String, messageLookup: Map<String, DescriptorProtos.DescriptorProto>): DependencyGraphEntry? {
        val localMessage = messageLookup[typeName] ?: return null
        return DependencyGraphEntry(typeName, localMessage)
    }

    /** Builds the name `<namespace>_<messageName>_oneof_<oneofIndex>`. */
    protected fun qualifiedOneofName(namespace: String, messageName: String, oneofIndex: Int): String {
        return "${namespace}_${messageName}_oneof_${oneofIndex}"
    }

    /**
     * Collects message fully-qualified type name to native type pairs recursively through nested messages.
     *
     * Regular messages are registered on every call. Map-entry messages are registered only when
     * [collectMaps] is true, using the map overload of [generatedType] on their "key" and "value" fields.
     *
     * @throws IllegalStateException when a map entry lacks a key or value field, or when generating the
     * map type fails; in the latter case the original failure is attached as the cause.
     */
    private fun collectMessageNativeTypes(
        messages: List<DescriptorProtos.DescriptorProto>,
        packagePath: String,
        namespace: String,
        collectMaps: Boolean
    ) {
        messages.forEach { message ->
            val qualifiedTypeName = fullyQualifiedTypeName(packagePath, message.name)
            if (message.options.mapEntry) {
                if (collectMaps) {
                    val keyType = message.fieldList.find { it.name == "key" }
                    val valueType = message.fieldList.find { it.name == "value" }
                    if (keyType == null || valueType == null) {
                        error("Map type $qualifiedTypeName is missing a key or value.")
                    }
                    try {
                        qualifiedToNativeType[qualifiedTypeName] =
                            generatedType(valueType, keyType)
                    } catch (ex: IllegalStateException) {
                        // Same exception type and message as before, but keep the original failure as the
                        // cause so its message and stack trace are not lost.
                        throw IllegalStateException(
                            "Unable to find an existing type for value type '$valueType' in map entry '$qualifiedTypeName'. Note that nested maps are currently unsupported.",
                            ex
                        )
                    }
                }
            } else {
                qualifiedToNativeType[qualifiedTypeName] =
                    generatedType(namespace, message)
            }
            // The message's own qualified name becomes the path for the types nested inside it.
            collectMessageNativeTypes(message.nestedTypeList, qualifiedTypeName, namespace, collectMaps)
        }
    }

    /** Return the enum type name */
    open fun getEnumName(type: DescriptorProtos.EnumDescriptorProto): String {
        return type.name
    }

    /** Generate a type name for an enum */
    fun generatedType(namespace: String, type: DescriptorProtos.EnumDescriptorProto): String {
        return generatedTypeInternal(ProtobufType.Enum, namespace, getEnumName(type))
    }

    /** Return the type name */
    open fun getTypeName(type: DescriptorProtos.DescriptorProto): String {
        return type.name
    }

    /** Generate a type name for a message */
    fun generatedType(namespace: String, type: DescriptorProtos.DescriptorProto): String {
        return generatedTypeInternal(ProtobufType.Message, namespace, getTypeName(type))
    }

    /** Generate a type name for a field or primitive type */
    abstract fun generatedType(field: DescriptorProtos.FieldDescriptorProto): String

    /** Generate a type name for a map. The default implementation returns an empty string. */
    open fun generatedType(valueType: DescriptorProtos.FieldDescriptorProto?, keyType: DescriptorProtos.FieldDescriptorProto?) : String {
        return ""
    }

    /** Generate the native type name for an enum or message called [typeName] in [namespace]. */
    protected abstract fun generatedTypeInternal(protoType: ProtobufType, namespace: String, typeName: String): String

    /**
     * Looks up the native type for the field's `typeName`, first among the types collected from the
     * current files and then among the previous protos.
     *
     * @throws IllegalStateException when the type name is in neither lookup.
     */
    fun buildMessageType(field: DescriptorProtos.FieldDescriptorProto): String {
        return qualifiedToNativeType[field.typeName] ?: previousProtos[field.typeName]
        ?: error("Unknown type '${field.typeName}' for field named '${field.name}' @ field number '${field.number}'")
    }

    /** Returns the default value text for [field]. The default implementation returns an empty string. */
    open fun generatedTypeDefaultValue(field: DescriptorProtos.FieldDescriptorProto): String {
        return ""
    }

    /**
     * Returns the JSON helper type name for [field]. Fields whose type name contains "Fixed128" are
     * treated as strings regardless of their declared type.
     *
     * @throws IllegalStateException for proto groups or a null field type.
     */
    protected fun jsonType(field: DescriptorProtos.FieldDescriptorProto): String {
        val fieldType = if (field.isFixed128()) {
            // Temp override until we handle Fixed128 type properly in Unreal.
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING
        } else {
            field.type
        }
        return when (fieldType) {
            // These map to the PragmaJson::JsonValueTo* and PragmaJson::*ToJsonValue functions.
            // NOTE(review): FIXED32/SFIXED32/FIXED64/SFIXED64 are fixed-width integer types in protobuf, yet
            // they are grouped with FLOAT/DOUBLE here -- confirm this mapping is intended.
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_FIXED64,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_SFIXED64,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_DOUBLE -> "Double"
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_FIXED32,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_SFIXED32,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_FLOAT -> "Float"
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT32,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_SINT32 -> "Int32"
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_UINT32 -> "UInt32"
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_SINT64 -> "Int64"
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_UINT64 -> "UInt64"
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_BOOL -> "Bool"
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING -> "String"
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_ENUM -> "Enum"
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_MESSAGE -> "Struct"
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_GROUP -> error("Proto 'groups' are deprecated and unsupported. https://developers.google.com/protocol-buffers/docs/proto#groups")
            // [PRAG-257]: Support protobuf 'bytes' when we build proto path.
            // For the moment, we use FString for bytes, because IntAny comes across as an escaped json string.
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES -> "String"
            null -> error("Proto type is null: '${field.name}', field number: ${field.number}. This shouldn't happen.")
        }
    }

    /**
     * True for numeric, bool and enum fields; false for string, bytes and message fields. The decision
     * uses the field's declared type only (no Fixed128 override is applied here).
     *
     * @throws IllegalStateException for proto groups or a null field type.
     */
    protected fun isPrimitiveField(field: DescriptorProtos.FieldDescriptorProto): Boolean {
        return when (field.type) {
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_FIXED64,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_SFIXED64,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_DOUBLE,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_FIXED32,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_SFIXED32,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_FLOAT,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT32,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_SINT32,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_UINT32,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_SINT64,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_UINT64,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_ENUM,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_BOOL -> true
            // Bytes are grouped with strings and messages as non-primitive.
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES,
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_MESSAGE -> false
            DescriptorProtos.FieldDescriptorProto.Type.TYPE_GROUP -> error("Proto 'groups' are deprecated and unsupported. https://developers.google.com/protocol-buffers/docs/proto#groups")
            null -> error("Proto type is null: '${field.name}', field number: ${field.number}. This shouldn't happen.")
        }
    }

    /** Returns [pragmaProtoJsonObject] rendered as a JSON string. */
    fun pragmaProtoJson(): String {
        return pragmaProtoJsonObject().toString()
    }

    /** Returns the fully-qualified-name to native-type lookup as a JSON tree. */
    fun pragmaProtoJsonObject(): JsonNode {
        return ObjectMapper().valueToTree(qualifiedToNativeType)
    }
}

/** Thin wrapper around file reading so that generators can be given a different reader. */
class PragmaProtoFileWrapper {
    /**
     * Reads the whole file at [path] as text. The path is normalized and made absolute before reading.
     */
    fun read(path: String): String {
        val absolutePath = Path.of(path).normalize().toAbsolutePath()
        return absolutePath.toFile().readText()
    }
}
