package pragma.protoc.plugin.custom
import java.util.Stack

/**
 * Directed graph.
 * Vertices are unique within the graph based on their value.
 * Edges are unique based on the values of the vertices.
 * Edges both ways are supported.
 * Reflexive edges are supported.
 */
class DirectedUniqueGraph<T> {
    /**
     * A graph node, unique within the graph by its [value]: equality and hash are derived from [value] only.
     * Outgoing edges are held in a set, so an edge to a given vertex is stored at most once.
     */
    private class Vertex<T>(val value: T) {
        private val edges = mutableSetOf<Vertex<T>>()

        /** Read-only view of the vertices this vertex points to. */
        fun edges(): Set<Vertex<T>> = edges

        /** Adds an edge from this vertex to [vertex]. Returns false if that edge was already present. */
        fun addNext(vertex: Vertex<T>): Boolean = edges.add(vertex)

        override fun equals(other: Any?): Boolean =
            this === other || (other is Vertex<*> && value == other.value)

        override fun hashCode(): Int = value?.hashCode() ?: 0
    }

    /**
     * Mutable state of a single depth-first topological sort.
     *
     * @param abortOnCycle when true, [visit] stops at the first cycle it finds; when false, the edge that closes
     * the cycle is skipped and the walk carries on.
     */
    private class DepthFirstSort<T>(private val abortOnCycle: Boolean) {
        /** Every vertex whose visit has started. Guarantees each vertex is visited at most once. */
        val visited = mutableSetOf<Vertex<T>>()

        /** Values in the order their vertices finished. The reverse of this list is the topological ordering. */
        val finished = mutableListOf<T>()

        /**
         * One entry per cycle found: the values on the recurse path at that moment, followed by the value of the
         * vertex that closed the cycle (so that value appears twice and the loop is easy to read).
         */
        val cycles = mutableListOf<List<T>>()

        // The current recurse path, root first.
        private val path = mutableListOf<Vertex<T>>()

        // Same vertices as [path], kept as a set so the "is it on the path?" test is a hash lookup
        // rather than a scan of the whole path for every edge.
        private val onPath = mutableSetOf<Vertex<T>>()

        /**
         * Visits [vertex] and, recursively, every not-yet-visited vertex reachable from it.
         * The caller must already have added [vertex] to [visited].
         *
         * The recursion depth equals the length of the path being walked.
         *
         * @return false if the walk was aborted because a cycle was found and [abortOnCycle] is set; true otherwise.
         */
        fun visit(vertex: Vertex<T>): Boolean {
            path.add(vertex)
            onPath.add(vertex)
            for (next in vertex.edges()) {
                if (next in onPath) {
                    // Snapshot the path (never hand out the live list) and append the closing vertex for clarity.
                    val cycle = ArrayList<T>(path.size + 1)
                    path.mapTo(cycle) { it.value }
                    cycle.add(next.value)
                    cycles.add(cycle)
                    if (abortOnCycle) return false
                    continue
                }
                if (visited.add(next) && !visit(next)) {
                    // Cycle found further down, abort.
                    return false
                }
            }
            path.removeAt(path.lastIndex)
            onPath.remove(vertex)
            finished.add(vertex.value)
            return true
        }
    }

    private val vertices = mutableMapOf<T, Vertex<T>>()

    /** Number of vertices currently in the graph. */
    val size: Int
        get() = vertices.size

    /**
     * Add a vertex to the graph. Note that addEdge will add any vertices specified,
     * but this is the only way to add an edgeless vertex.
     */
    fun addVertex(value: T) {
        getOrAddVertex(value)
    }

    /**
     * Connects two vertices in the direction specified. Vertices will be created if they don't exist.
     *
     * @return true if the edge was added, false if it was already present.
     */
    fun addEdge(from: T, to: T): Boolean {
        // Resolve 'from' before 'to' so vertex creation order is stable.
        val source = getOrAddVertex(from)
        val target = getOrAddVertex(to)
        return source.addNext(target)
    }

    /** Returns true if a vertex with this value has been added, directly or through [addEdge]. */
    fun containsVertex(value: T): Boolean = vertices.containsKey(value)

    /** Returns true if an edge in the direction [from] -> [to] has been added. */
    fun containsEdge(from: T, to: T): Boolean {
        val source = vertices[from] ?: return false
        // An edge can only point at a vertex that is registered in the graph, so a missing 'to' means no edge.
        val target = vertices[to] ?: return false
        return target in source.edges()
    }

    /**
     * Class returned by topologicalSort.
     *
     * [cycles] holds one list per cycle found. Each list is the recurse stack at the moment the cycle was
     * detected, ending with the vertex that closed the cycle; it is empty when no cycle was found.
     *
     * [values] is the sorted list, except when the sort was run with
     * [CycleHandlingStrategy.ErrorAndReturnRecurseStack] and a cycle was found: then it is the recurse stack
     * in which that cycle was found (the same content as the single entry of [cycles]).
     */
    data class TopologicalSortResult<T>(val values: List<T>, val cycles: List<List<T>>)

    /**
     * Determines how the topological sort proceeds when cycles are found in the graph during sorting.
     */
    enum class CycleHandlingStrategy {
        /**
         * Stop the sorting and return the list of values that contains the cycle.
         */
        ErrorAndReturnRecurseStack,

        /**
         * Break the cyclic recurse and continue sorting.
         *
         * Note that this strategy doesn't guarantee any particular ordering, for example in the graph below:
         *     X
         *     ↓
         *     A
         *   ↗ ↓ ↖
         * D ← B → C
         *
         * The sorted list may return either:
         *      X - A - B - C - D
         *      X - A - B - D - C
         * Depending on which edge the algorithm happens to take.
         *
         * The result will be deterministic given the same input.
         */
        Continue,
    }

    /**
     * Run a topological sort on the graph.
     * If successful, result.values will contain a list sorted by:
     * "a linear ordering of its vertices such that for every directed edge uv from vertex u to vertex v, u comes before v in the ordering."
     * (Definition: Wikipedia) https://en.wikipedia.org/wiki/Topological_sorting
     *
     * The ordering of equivalent elements is arbitrary, meaning there can be multiple orderings for the same
     * graph, as long as that ordering is still topological.
     *
     * When cycles are encountered, the CycleHandlingStrategy determines how it proceeds.
     *
     * @param cycleHandlingStrategy what to do when a cycle is found.
     * @return the ordering and any cycles found; see [TopologicalSortResult] for how to read it.
     */
    fun topologicalSort(cycleHandlingStrategy: CycleHandlingStrategy): TopologicalSortResult<T> {
        val abortOnCycle = when (cycleHandlingStrategy) {
            CycleHandlingStrategy.ErrorAndReturnRecurseStack -> true
            CycleHandlingStrategy.Continue -> false
        }
        val sort = DepthFirstSort<T>(abortOnCycle)
        for (root in vertices.values) {
            // Already reached from an earlier root.
            if (!sort.visited.add(root)) continue
            if (!sort.visit(root)) {
                // Aborted: report the recurse stack that contains the cycle instead of an ordering.
                return TopologicalSortResult(sort.cycles.last(), sort.cycles.toList())
            }
        }
        // Vertices finish after everything they point to, so the reverse finishing order is topological.
        return TopologicalSortResult(sort.finished.reversed(), sort.cycles.toList())
    }

    /** Returns the vertex registered for [value], creating and registering it first if needed. */
    private fun getOrAddVertex(value: T): Vertex<T> = vertices.getOrPut(value) { Vertex(value) }
}
