package com.jetbrains.lang.parsing.builder

import com.intellij.psi.TokenType
import com.intellij.psi.tree.IElementType
import kala.collection.mutable.MutableArrayList
import kala.collection.mutable.MutableList
import kala.collection.mutable.primitive.MutableIntArrayList
import kala.collection.mutable.primitive.MutableIntList
import org.aya.kala.Int2IntOpenHashMap
import org.aya.kala.Int2ObjectOpenHashMap

/**
 * Packed, array-backed implementation of [ASTMarkers].
 *
 * Per-marker data (kind, collapsed/hasError flags, id, markers count and lexeme info) is stored in a
 * [Packer]; element types are kept in a parallel list indexed by marker position. Error messages are
 * stored out of line, keyed by marker id, and chameleons are keyed by lexeme index.
 *
 * Marker slots are appended with [pushBack] and filled in with [setMarker]. [mutate] applies changes to a
 * copy and returns it, leaving the receiver's storage as it was.
 */
class ASTMarkersImpl<T> : ASTMarkers<T> {
  private val packer: Packer
  private val elementTypes: MutableList<IElementType>
  private val errorMessages: Int2ObjectOpenHashMap<String> // marker id -> error message
  private val chameleonsMap: Int2ObjectOpenHashMap<ChameleonRef<T>> // lexemeIndex -> chameleon

  // Every marker id currently in use is below this value; new ids are handed out from here.
  private var nextId = 0

  /** Creates an empty marker table. */
  constructor() {
    packer = Packer()
    elementTypes = MutableArrayList.create(DEFAULT_CAPACITY)
    errorMessages = Int2ObjectOpenHashMap()
    chameleonsMap = Int2ObjectOpenHashMap()
  }

  /** Copy constructor used by [mutate]: copies the packed data and element types and clones both maps. */
  private constructor(origin: ASTMarkersImpl<T>) {
    packer = origin.packer.copy()
    elementTypes = MutableArrayList.from(origin.elementTypes)
    errorMessages = origin.errorMessages.clone()
    chameleonsMap = origin.chameleonsMap.clone()
    nextId = origin.nextId
  }

  override val size: Int get() = packer.size

  override fun kind(i: Int): Byte {
    return packer.kind(i)
  }

  /** Returns the error message of marker [i], or `null` when its hasError flag is not set. */
  override fun errorMessage(i: Int): String? {
    return if (hasError(i)) {
      errorMessages[id(i)]
    } else {
      null
    }
  }

  /** Whether the hasError flag of marker [i] is set. */
  fun hasError(i: Int): Boolean {
    return packer.hasErrors(i)
  }

  /** The id of marker [i]; error messages are keyed by this id. */
  fun id(i: Int): Int {
    return packer.id(i)
  }

  override fun lexemeCount(i: Int): Int {
    return packer.lexemeInfo(i).count
  }

  override fun lexemeRelOffset(i: Int): Int {
    return packer.lexemeInfo(i).relOffset
  }

  override fun collapsed(i: Int): Boolean {
    return packer.collapsed(i)
  }

  override fun markersCount(i: Int): Int {
    return packer.markersCount(i)
  }

  override fun elementType(i: Int): IElementType {
    return elementTypes[i]
  }

  override fun chameleonAt(lexemeIndex: Int): ChameleonRef<T> {
    return chameleonsMap[lexemeIndex]
  }

  override fun chameleons(): List<Pair<Int, ChameleonRef<T>>> {
    return chameleonsMap.toList()
  }

  /**
   * Replaces the node whose start marker is at [i] (together with everything up to its end marker at
   * `i + markersCount(i)`) with all markers of [astMarkers].
   *
   * The incoming markers receive ids from this table's id space starting at [nextId] (after compacting
   * existing ids first when short mode would overflow). Their error messages and chameleons are copied
   * over, and the new root keeps the relative lexeme offset of the replaced node.
   *
   * @param lexemeIndex lexeme index by which the chameleon keys of [astMarkers] are shifted.
   */
  private fun substituteImpl(astMarkers: ASTMarkers<T>, i: Int, lexemeIndex: Int) {
    check(astMarkers is ASTMarkersImpl<T>) { "unexpected class: ${astMarkers.javaClass}" }
    check(kind(i) == MarkerKind.Start)
    val start = i
    val end = i + markersCount(i)
    val relOffset = lexemeRelOffset(i)
    val oldId2newId = Int2IntOpenHashMap()
    copyNewChameleons(lexemeIndex, astMarkers)
    removeMarkersFromMaps(start, end)
    val newNextId = computeIdsForNewMarkers(astMarkers, oldId2newId)

    copyNewAstMarkers(start, end, astMarkers)

    // After several substitutions some ids may no longer be used; in short mode compact the ids of the
    // surviving markers if the new ones would not fit otherwise.
    renumberIfNeeded(newNextId, start, start + astMarkers.size)

    // Replace the incoming markers' ids with ids from our own id space.
    renumberNewMarkers(start, oldId2newId, astMarkers)

    // The new root (start and end marker) keeps the relative offset of the node it replaced.
    val count = lexemeCount(i)
    setLexemeInfo(i, count, relOffset)
    setLexemeInfo(i + astMarkers.size - 1, count, relOffset)
    nextId += newNextId
  }

  /**
   * Drops our chameleons in `[startLexeme, startLexeme + astMarkers.lexemeCount(0))` and copies the
   * chameleons of [astMarkers], with their keys shifted by [startLexeme].
   *
   * NOTE(review): the dropped range is sized by the incoming tree's lexeme count rather than the
   * replaced node's; confirm callers guarantee both are equal.
   */
  private fun copyNewChameleons(startLexeme: Int, astMarkers: ASTMarkersImpl<T>) {
    val endLexeme = startLexeme + astMarkers.lexemeCount(0)
    chameleonsMap.keys
      .filter { startLexeme <= it && it < endLexeme }
      .forEach { chameleonsMap.remove(it) }
    astMarkers.chameleonsMap.forEach { (k, v) -> chameleonsMap[k + startLexeme] = v }
  }

  /**
   * Rewrites the ids of the markers just inserted at [start] from [astMarkers]' id space into ours
   * (`oldId2newId[oldId] + nextId`), copying their error messages under the new ids.
   */
  private fun renumberNewMarkers(
    start: Int,
    oldId2newId: Int2IntOpenHashMap,
    astMarkers: ASTMarkersImpl<T>,
  ) {
    val offset = nextId
    for (index in start until start + astMarkers.size) {
      val oldId = id(index)
      val newId = oldId2newId[oldId] + offset
      if (hasError(index)) {
        errorMessages[newId] = astMarkers.errorMessages[oldId]
      }
      packer.setId(index, newId)
    }
  }

  /**
   * Assigns dense ids `0, 1, 2, ...` to the distinct ids of [astMarkers] in marker order, storing the
   * mapping in [oldId2newId].
   *
   * @return the number of distinct ids.
   */
  private fun computeIdsForNewMarkers(
    astMarkers: ASTMarkersImpl<T>,
    oldId2newId: Int2IntOpenHashMap,
  ): Int {
    var nextNewId = 0
    for (index in 0 until astMarkers.size) {
      oldId2newId.computeIfAbsent(astMarkers.id(index)) { nextNewId++ }
    }
    return nextNewId
  }

  /**
   * In short mode, compacts the ids of all markers outside `[insertedStart, insertedEnd)` when
   * [nextNewId] more ids would push [nextId] past the 16-bit id range.
   */
  private fun renumberIfNeeded(nextNewId: Int, insertedStart: Int, insertedEnd: Int) {
    if (packer.shortMode && nextId + nextNewId > Char.MAX_VALUE.code) {
      nextId = 0
      renumber(insertedStart, insertedEnd)
    }
  }

  /** Replaces the element types and packed data of markers `start..end` with those of [astMarkers]. */
  private fun copyNewAstMarkers(
    start: Int,
    end: Int,
    astMarkers: ASTMarkersImpl<*>,
  ) {
    elementTypes.removeInRange(start, end + 1)
    elementTypes.insertAll(start, astMarkers.elementTypes)
    packer.replace(start, end + 1, astMarkers.packer)
  }

  /** Removes the error messages of markers `start..end` (inclusive). */
  private fun removeMarkersFromMaps(start: Int, end: Int) {
    for (index in start..end) {
      val currentId = id(index)
      if (hasError(index)) {
        errorMessages.remove(currentId)
      }
    }
  }

  /**
   * Reassigns ids `nextId, nextId + 1, ...` to the distinct ids of all markers outside
   * `[insertedStart, insertedEnd)` and re-keys their error messages accordingly.
   */
  private fun renumber(insertedStart: Int, insertedEnd: Int) {
    val renumberMap = Int2IntOpenHashMap()
    // Collect the re-keyed messages separately: moving them within errorMessages in place could
    // overwrite the message of an old id that equals an already assigned new id but was not visited yet.
    val renumberedMessages = Int2ObjectOpenHashMap<String>()
    for (index in 0 until size) {
      // the freshly inserted markers get their ids from renumberNewMarkers
      if (index in insertedStart until insertedEnd) continue
      val oldId = id(index)
      val newId = renumberMap.computeIfAbsent(oldId) { nextId++ }
      if (hasError(index)) {
        renumberedMessages[newId] = errorMessages[oldId]
      }
      packer.setId(index, newId)
    }
    errorMessages.clear()
    renumberedMessages.forEach { (k, v) -> errorMessages[k] = v }
  }

  /** Registers [reference] as the chameleon at [lexemeIndex], replacing any previous one. */
  fun setChameleon(lexemeIndex: Int, reference: ChameleonRef<T>) {
    chameleonsMap[lexemeIndex] = reference
  }

  /** Sets the markers count of marker [i]; the storage switches to long mode if it does not fit. */
  fun setMarkersCount(i: Int, descCount: Int) {
    packer.setMarkersCount(i, descCount)
  }

  /** Sets lexeme count and relative offset of marker [i]; switches to long mode if they do not fit. */
  fun setLexemeInfo(i: Int, lexemeCount: Int, relOffset: Int) {
    packer.setLexemeInfo(i, lexemeCount, relOffset)
  }

  /**
   * Appends a zero-filled marker slot with element type [TokenType.ERROR_ELEMENT].
   *
   * @return the index of the new slot.
   */
  fun pushBack(): Int {
    val i = size
    packer.pushBack()
    elementTypes.insert(i, TokenType.ERROR_ELEMENT)
    return i
  }

  /**
   * Initializes the marker at [index], which must still be [MarkerKind.Undone].
   *
   * Raises [nextId] above [id] if needed, stores [errorMessage] (when non-null) under [id], and replaces
   * the element type when [elementType] is non-null.
   *
   * @throws AssertionError if the marker at [index] is not [MarkerKind.Undone].
   */
  fun setMarker(
    index: Int,
    id: Int,
    kind: Byte,
    collapsed: Boolean,
    errorMessage: String?,
    elementType: IElementType?,
  ) {
    if (kind(index) != MarkerKind.Undone) {
      throw AssertionError("marker $index has kind ${kind(index)}, expected ${MarkerKind.Undone}")
    }

    if (id + 1 > nextId) {
      nextId = id + 1
    }

    packer.setInitialInfo(index, id, kind, collapsed, errorMessage != null)

    if (errorMessage != null) {
      this.errorMessages.put(id, errorMessage)
    }

    if (elementType != null) {
      this.elementTypes[index] = elementType
    }
  }

  /** Debug dump: one indented line per marker followed by the chameleon map. */
  override fun toString(): String =
    buildString {
      var depth = 0
      elementTypes.forEachIndexed { index, type ->
        if (kind(index) == MarkerKind.End) {
          depth--
        }
        append("${"  ".repeat(depth)}$type ${kind(index)} ")
        append("e=${hasError(index)} ")
        append("c=${collapsed(index)} ")
        append("lo=${lexemeRelOffset(index)} ")
        append("lc${lexemeCount(index)} ")
        appendLine("mc=${markersCount(index)}")
        if (kind(index) == MarkerKind.Start) {
          depth++
        }
      }
      appendLine("{")
      for ((t, u) in chameleonsMap) {
        appendLine("$t -> ${u.get()}")
      }
      appendLine("}")
    }

  /** Applies [mutator] to a copy of this table and returns the copy; the receiver's storage is not modified. */
  override fun mutate(mutator: ASTMarkers.MutableContext<T>.() -> Unit): ASTMarkersImpl<T> =
    MutableContextImpl().also(mutator).ast

  data class LexemeInfo(val relOffset: Int, val count: Int)

  /**
   * Packs per-marker data into a flat int list. Fields are listed from the least significant bit upwards:
   *
   * - short mode, 2 ints per marker:
   *   `[kind(2) | collapsed(1) | hasError(1) | markersCount(12) | id(16)]`, `[lexemeRelOffset(16) | lexemeCount(16)]`
   * - long mode, 4 ints per marker:
   *   `[kind(2) | collapsed(1) | hasError(1) | markersCount(28)]`, `[id(32)]`, `[lexemeRelOffset(32)]`, `[lexemeCount(32)]`
   *
   * Storage starts in short mode and switches to long mode (see [grow]) as soon as a value does not fit;
   * it never switches back.
   */
  private class Packer {
    private val ints: MutableIntList

    var longMode = false
      private set

    constructor() {
      ints = MutableIntArrayList.create(DEFAULT_CAPACITY)
    }

    private constructor(origin: Packer) {
      longMode = origin.longMode
      ints = MutableIntArrayList.from(origin.ints)
    }

    val shortMode: Boolean
      get() = !longMode

    // Largest markers count representable in the current mode (also the mask of the unshifted field).
    private val maxMarkersCount: Int
      get() = if (longMode) MAX_LONG_MARKERS_COUNT else MAX_SHORT_MARKERS_COUNT

    /** Index of the first int of marker [i]. */
    private fun index(i: Int): Int = if (longMode) i * 4 else i * 2

    fun kind(i: Int): Byte {
      return (ints[index(i)] and KIND_MASK).toByte()
    }

    fun lexemeInfo(i: Int): LexemeInfo {
      val index = index(i)
      return if (longMode) {
        LexemeInfo(ints.get(index + 2), ints.get(index + 3))
      } else {
        val packed = ints.get(index + 1)
        LexemeInfo(packed and MAX_SHORT_LEXEME_VALUE, packed ushr 16)
      }
    }

    fun collapsed(i: Int): Boolean {
      return ints[index(i)] and 4 == 4
    }

    fun hasErrors(i: Int): Boolean {
      return ints[index(i)] and 8 == 8
    }

    fun id(i: Int): Int {
      val index = index(i)
      return if (longMode) ints[index + 1] else ints[index] ushr 16
    }

    fun markersCount(i: Int): Int =
      (ints.get(index(i)) ushr 4) and maxMarkersCount

    val size: Int
      get() = ints.size() / (if (longMode) 4 else 2)

    fun setLexemeInfo(i: Int, count: Int, relOffset: Int) {
      if (shortMode && (relOffset > MAX_SHORT_LEXEME_VALUE || count > MAX_SHORT_LEXEME_VALUE)) grow()
      val index = index(i)
      if (longMode) {
        ints.set(index + 2, relOffset)
        ints.set(index + 3, count)
      } else {
        ints.set(index + 1, relOffset or (count shl 16))
      }
    }

    fun setMarkersCount(i: Int, count: Int) {
      if (shortMode && count > MAX_SHORT_MARKERS_COUNT) grow()
      check(count <= MAX_LONG_MARKERS_COUNT) { "markers count $count is bigger than $MAX_LONG_MARKERS_COUNT" }
      val index = index(i)
      // Clear the whole field of the *current* mode: in long mode it spans bits 4..31, so a short-mode
      // mask would leave stale high bits behind when a large count is lowered.
      val fieldMask = maxMarkersCount shl 4
      ints.set(index, (ints.get(index) and fieldMask.inv()) or (count shl 4))
    }

    /** Overwrites the first int of marker [i] with kind and flags (markers count reset to 0) and sets its id. */
    fun setInitialInfo(i: Int, id: Int, kind: Byte, collapsed: Boolean, hasErrors: Boolean) {
      val collapsedInt = if (collapsed) 4 else 0
      val hasErrorInt = if (hasErrors) 8 else 0
      if (shortMode && id > MAX_SHORT_ID_VALUE) grow()
      val index = index(i)
      ints.set(index, kind.toInt() + collapsedInt + hasErrorInt)
      setId(i, id)
    }

    fun setId(i: Int, id: Int) {
      // Without this, `id shl 16` would silently drop the high bits of ids that exceed 16 bits.
      if (shortMode && id > MAX_SHORT_ID_VALUE) grow()
      val index = index(i)
      if (longMode) {
        ints.set(index + 1, id)
      } else {
        ints.set(index, (ints.get(index) and Char.MAX_VALUE.code) or (id shl 16))
      }
    }

    /**
     * Replaces markers `[start, end)` with all markers of [packer], first bringing both to the same mode.
     * [packer] itself is never modified: when it must be widened, a grown copy is used instead.
     */
    fun replace(start: Int, end: Int, packer: Packer) {
      val source = if (longMode && !packer.longMode) packer.copy().also { it.grow() } else packer
      if (!longMode && source.longMode) grow()
      ints.removeInRange(index(start), index(end))
      ints.insertAll(index(start), source.ints)
    }

    /** Converts the storage from short to long mode in place, preserving every marker's data. */
    private fun grow() {
      check(shortMode) { "Already in long mode" }
      val end = size - 1
      repeat(ints.size()) { ints.append(0) }
      longMode = true
      // Walk backwards: marker i is read from ints[2i..2i+1] and written to ints[4i..4i+3], so the
      // short-mode data of markers below i is never overwritten before it has been read.
      (end downTo 0).forEach { i ->
        val firstInt = ints[i * 2]
        val secondInt = ints[i * 2 + 1]
        setInitialInfo(
          i,
          firstInt ushr 16,
          (firstInt and 3).toByte(),
          (firstInt and 4) == 4,
          (firstInt and 8) == 8,
        )
        setMarkersCount(i, (firstInt ushr 4) and MAX_SHORT_MARKERS_COUNT)
        setLexemeInfo(i, secondInt ushr 16, secondInt and MAX_SHORT_LEXEME_VALUE)
      }
    }

    /** Appends one zero-filled marker slot. */
    fun pushBack() {
      repeat(if (longMode) 4 else 2) { ints.append(0) }
    }

    fun copy(): Packer = Packer(this)

    companion object {
      private const val KIND_MASK = 3
      private const val MAX_SHORT_ID_VALUE = Char.MAX_VALUE.code
      private const val MAX_SHORT_LEXEME_VALUE = Char.MAX_VALUE.code
      private const val MAX_SHORT_MARKERS_COUNT = Char.MAX_VALUE.code ushr 4
      private const val MAX_LONG_MARKERS_COUNT = Int.MAX_VALUE ushr 3
    }
  }

  /** Mutation context handed to [mutate]; every operation targets [ast], a copy of the enclosing table. */
  private inner class MutableContextImpl : ASTMarkers.MutableContext<T> {
    val ast = ASTMarkersImpl(this@ASTMarkersImpl)

    override fun substitute(i: Int, lexemeIndex: Int, astMarkers: ASTMarkers<T>) {
      ast.substituteImpl(astMarkers, i, lexemeIndex)
    }

    override fun changeChameleons(pairs: List<Pair<Int, ChameleonRef<T>>>) {
      ast.chameleonsMap.clear()
      for ((key, value) in pairs) {
        ast.chameleonsMap[key] = value
      }
    }

    override fun changeLexCount(startMarker: Int, endMarker: Int, lexCount: Int) {
      // both markers take the start marker's relative offset
      val relOffset = ast.lexemeRelOffset(startMarker)
      ast.setLexemeInfo(startMarker, lexCount, relOffset)
      ast.setLexemeInfo(endMarker, lexCount, relOffset)
    }

    override fun changeMarkerCount(startMarker: Int, endMarker: Int, markerCount: Int) {
      ast.setMarkersCount(startMarker, markerCount)
      ast.setMarkersCount(endMarker, markerCount)
    }
  }

  companion object {
    private const val DEFAULT_CAPACITY = 256
  }
}
