package com.jetbrains.lang.parsing

import com.jetbrains.lang.syntax.CharOffset
import com.jetbrains.lang.syntax.SyntaxNode
import com.jetbrains.lang.syntax.children
import com.intellij.lang.Language
import com.intellij.lexer.Lexer
import com.intellij.psi.ITokenSequence
import com.intellij.psi.TokenType
import com.jetbrains.lang.parsing.builder.ASTMarkers
import com.jetbrains.lang.parsing.builder.MarkerKind
import com.jetbrains.lang.parsing.builder.firstChild
import com.jetbrains.lang.parsing.builder.nextSibling
import com.intellij.psi.tree.IElementType
import com.intellij.psi.tree.ILazyParseableElementType
import com.intellij.psi.tree.IReparseableElementType
import com.jetbrains.lang.parsing.builder.*
import java.util.Objects
import java.util.concurrent.atomic.AtomicReference

/**
 * Creates a builder that parses `lexemeCount` lexemes of `tokens`, starting at `startLexeme`, over `text`.
 * The builder's `root` is used as the resulting [AstMarkers] tree.
 */
typealias AstMarkersBuilderFactory =
  (ijLanguage: Language?, text: CharSequence, tokens: ITokenSequence, startLexeme: Int, lexemeCount: Int) -> AstMarkersPsiBuilder

/** Marker tree whose lazily parsed chameleon payloads are [AstMarkersChameleon]s. */
typealias AstMarkers = ASTMarkers<AstMarkersChameleon>

/** Marker builder whose chameleon payloads are [AstMarkersChameleon]s. */
typealias AstMarkersPsiBuilder = MarkerPsiBuilder<AstMarkersChameleon>

/** Hook handed to tokenization/parsing; callers in this file pass `{}` when no cancellation is wired in. */
typealias CancellationToken = () -> Unit

/**
 * Result of parsing a chameleon (lazy-parseable) element.
 *
 * @property customLexemeStore tokens produced by the chameleon's own inner lexer, or `null` when the chameleon was
 *   parsed directly over the surrounding token sequence.
 * @property ast marker tree describing the chameleon's contents.
 */
data class AstMarkersChameleon(val customLexemeStore: ITokenSequence?, val ast: AstMarkers)

/**
 * [SyntaxNode] view over a flat marker tree ([AstMarkers]) and the [ITokenSequence] it was built from.
 *
 * A node is either a *marker node* ([isMarker]: the element at [markerIndex] in `context.ast`) or a *token node*
 * (the single lexeme at [startLexemeIndex]). Nodes are materialized on demand while walking via [firstChild] and
 * [nextSibling]; [parent] and [prevSibling] are the nodes that produced this one during that walk.
 *
 * Collapsed lazy-parseable markers and lazy-parseable tokens are chameleons (see [isChameleon]): their children come
 * from an [AstMarkersChameleon] that is parsed on first access and cached in `context.ast`.
 */
class AstMarkersSyntaxNode internal constructor(
  internal val parent: AstMarkersSyntaxNode?,
  internal val prevSibling: AstMarkersSyntaxNode?,
  internal val context: WalkerContext,
  val tokens: ITokenSequence,
  internal val startLexemeIndex: Int,
  // Lexeme index where the marker at [markerIndex] starts; equal to [startLexemeIndex] for marker nodes.
  internal val nextMarkerStartLexemeIndex: Int,
  // Marker index in `context.ast`; -1 when no marker remains at this level (only tokens follow).
  internal val markerIndex: Int,
) : SyntaxNode {

  companion object {
    // Root node of `context.ast` (marker 0), positioned at the context's first lexeme.
    private fun rootWithContext(context: WalkerContext, tokens: ITokenSequence): AstMarkersSyntaxNode =
      AstMarkersSyntaxNode(
        parent = null,
        prevSibling = null,
        context = context,
        markerIndex = 0,
        startLexemeIndex = context.startLexemeIndex,
        tokens = tokens,
        nextMarkerStartLexemeIndex = context.startLexemeIndex,
      )

    /**
     * Creates the root node of a file-level tree.
     *
     * @param text the whole text the [markers] and [tokens] describe.
     * @param markers the parsed marker tree.
     * @param lexer lexer that produced [tokens]; used to decide whether chameleons need re-lexing.
     * @param tokenizationPolicy used to tokenize chameleon text when its inner lexer differs.
     * @param builderFactory creates builders for lazily parsing chameleons.
     * @param tokens token sequence of [text]; the root works on a fork of it.
     */
    fun root(
      text: CharSequence,
      markers: AstMarkers,
      lexer: Lexer,
      tokenizationPolicy: TokenizationPolicy,
      builderFactory: AstMarkersBuilderFactory,
      tokens: ITokenSequence,
    ): AstMarkersSyntaxNode =
      rootWithContext(
        WalkerContext(
          text = text,
          ast = markers,
          lexer = lexer,
          tokenization = tokenizationPolicy,
          builderFactory = builderFactory,
        ),
        tokens.fork(),
      )

  }

  // Like a data-class copy, but every copy works on its own fork of [tokens].
  private fun copy(
    parent: AstMarkersSyntaxNode? = this.parent,
    prevSibling: AstMarkersSyntaxNode? = this.prevSibling,
    context: WalkerContext = this.context,
    startLexemeIndex: Int = this.startLexemeIndex,
    nextMarkerStartLexemeIndex: Int = this.nextMarkerStartLexemeIndex,
    markerIndex: Int = this.markerIndex,
  ): AstMarkersSyntaxNode =
    AstMarkersSyntaxNode(
      parent = parent,
      prevSibling = prevSibling,
      context = context,
      startLexemeIndex = startLexemeIndex,
      nextMarkerStartLexemeIndex = nextMarkerStartLexemeIndex,
      tokens = tokens.fork(),
      markerIndex = markerIndex,
    )

  // A node is a marker node when the pending marker starts exactly at this node's first lexeme.
  internal val isMarker = markerIndex != -1 && startLexemeIndex == nextMarkerStartLexemeIndex

  // Exclusive end lexeme: the marker's lexeme count for marker nodes, a single lexeme for token nodes.
  internal val endLexemeIndex = startLexemeIndex + (if (isMarker) context.ast.lexemeCount(markerIndex) else 1)

  override val type: Any get() = elementType

  /** Element type of the marker for marker nodes, or the lexeme type for token nodes. */
  val elementType: IElementType =
    when {
      isMarker -> context.ast.elementType(markerIndex)
      else -> tokens.lexType(startLexemeIndex)
    }

  /**
   * The language of the outermost ancestor-or-self whose element language is a kind of this node's element
   * language, falling back to this node's element language.
   */
  override val language: Language
    get() {
      val currentLang = elementType.language
      val specificLang = generateSequence(this) { it.parent }.mapNotNull {
        if (it.elementType.language.isKindOf(currentLang)) it.elementType.language else null
      }.lastOrNull()
      return specificLang ?: currentLang
    }

  // Two nodes are equal when they walk an equal token sequence at the same marker and lexeme position.
  override fun equals(other: Any?): Boolean =
    (other === this) ||
      (other is AstMarkersSyntaxNode &&
        other.tokens == tokens &&
        other.markerIndex == markerIndex &&
        other.startLexemeIndex == startLexemeIndex)

  // Hash exactly the fields that [equals] compares, so equal nodes always share a hash code
  // (`context` is not part of equality and must not be part of the hash).
  override fun hashCode(): Int =
    Objects.hash(tokens, markerIndex, startLexemeIndex)

  /** Text of this node; offsets are global, so they are shifted back by `context.offset` to index `context.text`. */
  override val text: CharSequence
    get() = context.text.subSequence((startOffset - context.offset).toInt(), (endOffset - context.offset).toInt())

  /** Global start offset; a node positioned past the last lexeme starts at the end of the context text. */
  override val startOffset: CharOffset
    get() =
      context.offset +
        when {
          startLexemeIndex == tokens.lexemeCount -> context.text.length.toLong()
          else -> tokens.lexStart(startLexemeIndex).toLong()
        }

  /** Global end offset: end of the marker's last lexeme, or end of the token for token nodes. */
  override val endOffset: CharOffset
    get() =
      context.offset +
        when {
          isMarker ->
            when (context.ast.kind(markerIndex)) {
              MarkerKind.Start, MarkerKind.Error ->
                tokens.endCharAt(endLexemeIndex - 1)

              MarkerKind.End -> error("should not be at the end")
              else -> error("no else")
            }

          else -> tokens.endCharAt(startLexemeIndex)
        }.toLong()

  /** Marker error message, or `"Bad token"` for an [TokenType.ERROR_ELEMENT] node without one. */
  override val errorMessage: String?
    get() =
      (if (isMarker) context.ast.errorMessage(markerIndex) else null)
        ?: (if (type == TokenType.ERROR_ELEMENT) "Bad token" else null)


  override fun parent(): SyntaxNode? = parent

  /**
   * First child of this node.
   *
   * Error markers have no children; chameleons expose the first child of their lazily parsed subtree; other markers
   * start either at their first child marker or, if a lexeme precedes it / there is none, at a token node.
   */
  override fun firstChild(): SyntaxNode? {
    return when {
      isMarker -> {
        val ast = context.ast
        when {
          ast.kind(markerIndex) == MarkerKind.Error -> null
          isChameleon() -> chameleonFirstChild()

          else -> {
            val childMarkerIndex = ast.firstChild(markerIndex)
            when {
              // No child markers but non-empty span: children are plain tokens.
              childMarkerIndex == -1 && startLexemeIndex != endLexemeIndex ->
                copy(
                  markerIndex = -1,
                  nextMarkerStartLexemeIndex = startLexemeIndex,
                  prevSibling = null,
                  parent = this,
                )

              childMarkerIndex == -1 -> null
              ast.kind(childMarkerIndex) == MarkerKind.Start || ast.kind(childMarkerIndex) == MarkerKind.Error -> {
                // Start at our first lexeme; the child marker becomes "pending" until its start lexeme is reached.
                copy(
                  markerIndex = childMarkerIndex,
                  prevSibling = null,
                  nextMarkerStartLexemeIndex = startLexemeIndex + ast.lexemeRelOffset(childMarkerIndex),
                  parent = this,
                )
              }

              ast.kind(childMarkerIndex) == MarkerKind.End -> error("should not be at the end")
              else -> error("no else")
            }
          }
        }
      }

      isChameleon() && !isCopyOfParent() -> chameleonFirstChild()

      else -> null
    }
  }

  // First child of the lazily parsed chameleon subtree, re-parented under this node.
  private fun chameleonFirstChild(): AstMarkersSyntaxNode? =
    (chameleonSyntaxNode?.firstChild() as AstMarkersSyntaxNode?)?.copy(parent = this)

  // True when the parent describes the same type, span and marker as this node.
  // NOTE(review): presumably prevents a chameleon token from re-expanding into itself — confirm.
  private fun isCopyOfParent(): Boolean =
    parent != null &&
      parent.type == this.type &&
      parent.startLexemeIndex == this.startLexemeIndex &&
      parent.endLexemeIndex == this.endLexemeIndex &&
      parent.markerIndex == this.markerIndex

  /** A collapsed lazy-parseable start marker, or a token whose type is lazy-parseable. */
  internal fun isChameleon(): Boolean =
    (isMarker &&
      context.ast.kind(markerIndex) == MarkerKind.Start &&
      context.ast.elementType(markerIndex) is ILazyParseableElementType &&
      context.ast.collapsed(markerIndex)) ||
      (!isMarker && type is ILazyParseableElementType)

  /** Last child, found by walking [nextSibling] from [firstChild] (linear in the number of children). */
  override fun lastChild(): SyntaxNode? {
    tailrec fun lastSibling(c: SyntaxNode?): SyntaxNode? =
      when (val s = c?.nextSibling()) {
        null -> c
        else -> lastSibling(s)
      }
    return lastSibling(firstChild())
  }

  /**
   * Next sibling: after a marker node, continue at its end lexeme with the marker's next sibling pending;
   * after a token node, advance one lexeme and keep the same pending marker.
   */
  override fun nextSibling(): AstMarkersSyntaxNode? =
    when {
      parent == null -> null
      isMarker -> {
        val siblingMarkerIndex = context.ast.nextSibling(markerIndex)
        val startLexemeIndex = when {
          siblingMarkerIndex == -1 -> endLexemeIndex
          else -> endLexemeIndex + context.ast.lexemeRelOffset(siblingMarkerIndex)
        }
        goForthToNextSibling(
          siblingLexemeIndex = endLexemeIndex,
          startLexemeIndex = startLexemeIndex,
          markerIndex = siblingMarkerIndex,
        )
      }

      else ->
        goForthToNextSibling(
          siblingLexemeIndex = startLexemeIndex + 1,
          startLexemeIndex = when {
            markerIndex != -1 -> nextMarkerStartLexemeIndex
            else -> startLexemeIndex + 1
          },
          markerIndex = markerIndex,
        )
    }

  override fun prevSibling(): SyntaxNode? = prevSibling

  // Builds the sibling at [siblingLexemeIndex] with [markerIndex] pending at [startLexemeIndex], or returns null
  // when no marker is pending and the parent's lexeme range is exhausted.
  private fun goForthToNextSibling(
    siblingLexemeIndex: Int,
    startLexemeIndex: Int,
    markerIndex: Int,
  ): AstMarkersSyntaxNode? {
    // do we change token set in chameleon node?
    // If so, the parent's lexeme indices do not apply and the whole token sequence is the bound.
    val parentEndLexeme = if (parent?.tokens != tokens) tokens.lexemeCount else parent.endLexemeIndex
    return when {
      markerIndex != -1 || siblingLexemeIndex < parentEndLexeme ->
        copy(
          startLexemeIndex = siblingLexemeIndex,
          nextMarkerStartLexemeIndex = startLexemeIndex,
          markerIndex = markerIndex,
          prevSibling = this,
        )

      else -> null
    }
  }

  /** Child whose `[startOffset, endOffset)` range contains [offset]; linear scan over children. */
  override fun childByOffset(offset: CharOffset): SyntaxNode? =
    children().firstOrNull { offset in it.startOffset until it.endOffset }

  // Root of this chameleon's parsed subtree, or null when the parsed tree is empty. When the chameleon has its own
  // tokens, the subtree gets a context over just the chameleon's text, offset to this node's global start.
  private val chameleonSyntaxNode: SyntaxNode?
    get() {
      val (chameleonTokens, tree) = getOrParseChameleon()
      return when {
        tree.size == 0 -> null
        chameleonTokens == null -> {
          rootWithContext(context.copy(ast = tree, startLexemeIndex = startLexemeIndex), tokens.fork())
        }

        else -> {
          // we should use lexemes for getting char offsets, because startOffset is a global offset in the whole file
          val startOffsetInContext = tokens.lexStart(startLexemeIndex)
          val endOffsetInContext = tokens.endCharAt(endLexemeIndex - 1)
          val chameleonText = context.text.subSequence(startOffsetInContext, endOffsetInContext)
          val newContextLexer = (type as ILazyParseableElementType).createInnerLexer() ?: context.lexer
          rootWithContext(
            context.copy(
              text = chameleonText,
              lexer = newContextLexer,
              startLexemeIndex = 0,
              offset = startOffset,
              ast = tree,
            ),
            chameleonTokens.fork(),
          )
        }
      }
    }

  // Returns the cached parse of this chameleon, parsing it on first access. The parse result is published through
  // the reference returned by `chameleonAt`, re-checked under a lock on that reference so it is computed once.
  private fun getOrParseChameleon(): AstMarkersChameleon {
    val c = try {
      context.ast.chameleonAt(startLexemeIndex - context.startLexemeIndex)
    } catch (e: NullPointerException) {
      // Keep the original exception as the cause so the failed lookup stays diagnosable.
      throw IllegalStateException("Chameleon not found", e)
    }
    return c.get() ?: synchronized(c) {
      c.get() ?: run {
        val (tokens, node) = parseChameleon(
          text = context.text,
          contextLexer = context.lexer,
          lexemeStore = tokens,
          factory = context.builderFactory,
          startLexeme = startLexemeIndex,
          lexemeCount = endLexemeIndex - startLexemeIndex,
          cancellationToken = {},
        )
        AstMarkersChameleon(tokens, node).also { c.set(it) }
      }
    }
  }

  /**
   * Parses this node's element type over the given lexeme range.
   *
   * If the element's inner lexer is not equivalent to [contextLexer], the covered text is re-tokenized with it and
   * the returned token sequence is that new one; otherwise parsing reuses [lexemeStore] and the returned tokens are
   * `null`.
   *
   * @return the custom token sequence (or `null`) paired with the parsed marker tree.
   */
  private fun parseChameleon(
    text: CharSequence,
    contextLexer: Lexer,
    lexemeStore: ITokenSequence,
    factory: AstMarkersBuilderFactory,
    startLexeme: Int,
    lexemeCount: Int,
    cancellationToken: CancellationToken,
  ): Pair<ITokenSequence?, AstMarkers> {
    val lexer = (elementType as ILazyParseableElementType).createInnerLexer()
    val chameleonText: CharSequence
    val chameleonTokens: ITokenSequence
    val chameleonStartLexeme: Int
    val chameleonLexemeCount: Int
    val lexerChanged = lexer != null && !lexer.isEquivalent(contextLexer)
    when {
      lexerChanged -> {
        val startOffsetInContext = lexemeStore.lexStart(startLexeme)
        val endOffsetInContext = lexemeStore.endCharAt(startLexeme + lexemeCount - 1)
        chameleonText = text.subSequence(startOffsetInContext, endOffsetInContext)
        chameleonTokens = context.tokenization.tokenize(
          text = chameleonText,
          emptyType = TokenType.WHITE_SPACE,
          lexer = lexer!!,
          cancellationToken = cancellationToken,
        )
        chameleonStartLexeme = 0
        chameleonLexemeCount = chameleonTokens.lexemeCount
      }

      else -> {
        chameleonText = text
        chameleonTokens = lexemeStore
        chameleonStartLexeme = startLexeme
        chameleonLexemeCount = lexemeCount
      }
    }
    val builder = factory(language, chameleonText, chameleonTokens, chameleonStartLexeme, chameleonLexemeCount)
    elementType.parse(builder)
    return chameleonTokens.takeIf { lexerChanged } to builder.root
  }

  /** Debug dump of the context, its text and AST, and the chain of nodes from this one up to the root. */
  @Suppress("UNUSED")
  fun reportState(): String =
    buildString {
      appendLine("context offset : ${context.offset}")
      appendLine("context startLexemeIndex : ${context.startLexemeIndex}")
      appendLine("file text")
      // Fences go on their own lines so the dumped content carries no extra indentation.
      appendLine("```")
      appendLine(context.text)
      appendLine("```")
      appendLine("ast tree :")
      appendLine("```")
      appendLine(context.ast)
      appendLine("```")
      var node: AstMarkersSyntaxNode? = this@AstMarkersSyntaxNode
      while (node != null) {
        appendLine(node)
        node = node.parent
      }
    }

  override fun toString(): String =
    "$type [m=$markerIndex, l=$startLexemeIndex], (sl=$nextMarkerStartLexemeIndex, el=$endLexemeIndex) " +
      if (isMarker) "mc=" + context.ast.markersCount(markerIndex) else ""

  /**
   * Walks down from this node through the children enclosing the changed range and finds the deepest node whose
   * [IReparseableElementType] accepts the new tokens (`isParsable`); the walk stops at the first reparseable type
   * that rejects them.
   *
   * Entering a lazy-parseable node whose inner lexer is not equivalent to the current one re-tokenizes that node's
   * text for deeper levels.
   *
   * @return the found node (or `null`, also when the new token at `startOffset` no longer has the node's type),
   *   the token sequence it should be reparsed with, and whether a lexer other than [rootLexer] was entered.
   */
  private fun findDeepestReparseableNode(
    rootLexer: Lexer,
    newTokens: ITokenSequence,
    newText: CharSequence,
    startOffset: Long,
    endOffset: Long,
    cancellationToken: CancellationToken,
  ): Triple<AstMarkersSyntaxNode?, ITokenSequence, Boolean> {
    fun AstMarkersSyntaxNode.asReparseableNode(newTokens: ITokenSequence): AstMarkersSyntaxNode? {
      val newType = when {
        startOffset >= 0 && startOffset < newTokens.textLength -> {
          newTokens.lexType(newTokens.lexemeIndexByChar(startOffset.toInt()))
        }

        else -> null
      }
      return when {
        newType == type -> this
        else -> null
      }
    }

    var node: AstMarkersSyntaxNode? = this
    var currentTokens = newTokens
    var nodeTokens = currentTokens
    var deepestReparseableNode: AstMarkersSyntaxNode? = null

    var currentLexer = rootLexer
    while (node != null && node.startOffset <= startOffset && node.endOffset > endOffset) {
      val type = node.type
      val parentType = node.parent?.elementType
      val lexer = (type as? ILazyParseableElementType)?.createInnerLexer()
      val nodeText = newText.subSequence(node.startOffset.toInt(), node.endOffset.toInt())
      if (type is IReparseableElementType) {
        val startLexemeIndex = currentTokens.lexemeIndexByChar((node.startOffset - node.context.offset).toInt())
        val endLexemeIndex = currentTokens.lexemeIndexByChar((node.endOffset - node.context.offset).toInt())
        if (type.isParsable(parentType, nodeText, currentTokens, startLexemeIndex, endLexemeIndex - startLexemeIndex)) {
          deepestReparseableNode = node
          nodeTokens = currentTokens
        } else {
          val remapped = nodeTokens
          return Triple(deepestReparseableNode?.asReparseableNode(remapped), remapped, currentLexer != rootLexer)
        }
      }
      if (lexer != null && !lexer.isEquivalent(currentLexer)) {
        currentLexer = lexer
        currentTokens = context.tokenization.tokenize(
          text = nodeText,
          emptyType = TokenType.WHITE_SPACE,
          lexer = lexer,
          cancellationToken = cancellationToken,
        )
      }
      node = node.children().firstOrNull { it.startOffset <= startOffset && it.endOffset >= endOffset } as AstMarkersSyntaxNode?
    }
    return Triple(deepestReparseableNode?.asReparseableNode(nodeTokens), nodeTokens, currentLexer != rootLexer)
  }

  /**
   * Attempts an incremental reparse of the smallest reparseable node enclosing a text change.
   *
   * @param start offset where the change begins.
   * @param end converted to `text.length - end` before the search.
   *   NOTE(review): presumably the length of the unchanged suffix — confirm.
   * @return the rebuilt marker tree paired with the token sequence to use from now on, or `null` when no
   *   reparseable node encloses the change.
   */
  fun tryReparse(
    builderFactory: AstMarkersBuilderFactory,
    lexer: Lexer,
    newTokens: ITokenSequence,
    text: CharSequence,
    start: Long,
    end: Long,
    cancellationToken: CancellationToken,
  ): Pair<AstMarkers, ITokenSequence>? {
    val (reparseableNode, nodeTokens, lexerChanged) = findDeepestReparseableNode(
      rootLexer = lexer,
      newTokens = newTokens,
      newText = text,
      startOffset = start,
      endOffset = text.length - end,
      cancellationToken = cancellationToken,
    )
    return if (reparseableNode != null) {
      val oldTokens = reparseableNode.tokens
      val startLexemeIndex = nodeTokens.lexemeIndexByChar((reparseableNode.startOffset - reparseableNode.context.offset).toInt())
      val endLexemeIndex = nodeTokens.lexemeIndexByChar((reparseableNode.endOffset - reparseableNode.context.offset).toInt())
      val oldLexemeCount = endLexemeIndex - startLexemeIndex

      // Outside the reparsed range, restore token types from the old sequence; after the range the old indices
      // are shifted by the change in total lexeme count.
      val diff = nodeTokens.lexemeCount - oldTokens.lexemeCount
      for (i in 0 until startLexemeIndex) {
        nodeTokens.remap(i, oldTokens.lexType(i))
      }

      for (i in endLexemeIndex until nodeTokens.lexemeCount) {
        nodeTokens.remap(i, oldTokens.lexType(i - diff))
      }

      val (lexemeStore, node) = reparseableNode.parseChameleon(
        text = text,
        contextLexer = lexer,
        lexemeStore = nodeTokens,
        factory = builderFactory,
        startLexeme = startLexemeIndex,
        lexemeCount = oldLexemeCount,
        cancellationToken = cancellationToken,
      )

      val newLexemeStorage = nodeTokens.takeUnless { lexerChanged } ?: newTokens
      substitute(reparseableNode, nodeTokens, AstMarkersChameleon(lexemeStore, node)) to newLexemeStorage
    } else null
  }
}

// Two lexers are interchangeable for chameleon handling when they are equal per `equals`.
private fun Lexer.isEquivalent(contextLexer: Lexer): Boolean = this == contextLexer

/**
 * Builds a new marker tree in which [astMarkersSyntaxNode]'s subtree is replaced by [parsedChameleon].
 *
 * First the AST of the node's own context is mutated (chameleon slots re-keyed, lexeme/marker counts adjusted, or the
 * markers spliced in). Then, for every enclosing chameleon in an outer context, that context's AST is rebuilt with its
 * chameleon slot pointing at the AST rebuilt so far.
 *
 * @param astMarkersSyntaxNode the node that was reparsed.
 * @param nodeTokens token sequence used for the reparse; stored into the first enclosing chameleon that has its own
 *   lexeme store.
 * @param parsedChameleon result of reparsing [astMarkersSyntaxNode].
 * @return the last rebuilt AST (the outermost context reached).
 */
private fun substitute(
  astMarkersSyntaxNode: AstMarkersSyntaxNode,
  nodeTokens: ITokenSequence,
  parsedChameleon: AstMarkersChameleon,
): ASTMarkers<AstMarkersChameleon> {
  // Becomes true once nodeTokens has been stored into an enclosing chameleon.
  var newNodeTokensAdded = false
  var parent = astMarkersSyntaxNode.parent
  val newLexemeCount = parsedChameleon.ast.lexemeCount(0)
  // NOTE(review): equals the node's lexeme count for marker nodes (nextMarkerStartLexemeIndex == startLexemeIndex);
  // for token nodes nextMarkerStartLexemeIndex may differ from startLexemeIndex — confirm this is intended.
  val oldLexemeCount = astMarkersSyntaxNode.endLexemeIndex - astMarkersSyntaxNode.nextMarkerStartLexemeIndex
  // The lexeme delta only applies while the chameleon shares the surrounding token sequence.
  var diff = if (parsedChameleon.customLexemeStore == null) newLexemeCount - oldLexemeCount else 0
  // Skip ancestors in the same context; updateLexemes/updateLexemesAndMarkers walk those inside the mutation below.
  while (parent?.context == astMarkersSyntaxNode.context) parent = parent.parent
  // Chameleon keys are relative to context.startLexemeIndex: keys at/after the node's end shift by diff, and the
  // node's own slot is replaced by the freshly parsed chameleon.
  val pairs = astMarkersSyntaxNode.context.ast.chameleons().map { (key, value) ->
    val contextKey = key + astMarkersSyntaxNode.context.startLexemeIndex
    val newKey = if (contextKey >= astMarkersSyntaxNode.endLexemeIndex) key + diff else key
    if (contextKey == astMarkersSyntaxNode.startLexemeIndex) newKey to AtomicReference(parsedChameleon) else newKey to value
  }
  var reparsedAst = astMarkersSyntaxNode.context.ast.mutate {
    changeChameleons(pairs)
    if (astMarkersSyntaxNode.isChameleon()) {
      // A chameleon keeps its marker; only lexeme counts change, and only when it shares the outer tokens.
      if (astMarkersSyntaxNode.isMarker && parsedChameleon.customLexemeStore == null) {
        updateLexemes(astMarkersSyntaxNode, diff)
      }
    } else {
      // A regular node: fix up counts, then splice the new markers in place of the old subtree.
      updateLexemesAndMarkers(parsedChameleon, astMarkersSyntaxNode, diff)
      substitute(
        astMarkersSyntaxNode.markerIndex,
        astMarkersSyntaxNode.startLexemeIndex - astMarkersSyntaxNode.context.startLexemeIndex,
        parsedChameleon.ast,
      )
    }
  }

  // Propagate the rebuilt AST outward through enclosing chameleons in other contexts.
  while (parent != null) {
    if (parent.isChameleon()) {
      // Entries are processed in lexeme order: entries after the parent's slot shift by diff as updated
      // while handling that slot.
      val replacedChameleon = parent.context.ast.chameleons()
        .sortedBy { (lexemeIndex, _) -> lexemeIndex }
        .map { (lexemeIndex, ref) ->
          when {
            lexemeIndex < parent!!.startLexemeIndex - parent!!.context.startLexemeIndex -> {
              lexemeIndex to ref
            }

            lexemeIndex == parent!!.startLexemeIndex - parent!!.context.startLexemeIndex -> {
              val oldParsedChameleon = ref.get()
              checkNotNull(oldParsedChameleon) { "parent should be parsed " }
              val theSameTokens = oldParsedChameleon.customLexemeStore == null
              // Only the first chameleon with its own lexemes receives nodeTokens; later ones keep theirs.
              val newTokens = if (!theSameTokens && !newNodeTokensAdded) {
                newNodeTokensAdded = true
                nodeTokens
              } else oldParsedChameleon.customLexemeStore
              val newParsedChameleon = AstMarkersChameleon(newTokens, ast = reparsedAst)
              // Crossing a chameleon with its own lexemes stops the lexeme delta from propagating outward.
              diff = if (theSameTokens) diff else 0
              lexemeIndex to AtomicReference(newParsedChameleon)
            }

            else -> {
              (lexemeIndex + diff) to ref
            }
          }
        }
      reparsedAst = parent.context.ast.mutate {
        changeChameleons(replacedChameleon)
        if (parent!!.isMarker && diff != 0) {
          updateLexemes(parent!!, diff)
        }
      }
    }
    parent = parent.parent
  }
  return reparsedAst
}

/**
 * Character offset just past the lexeme at [lexemeIndex]: the start of the following lexeme, or
 * [ITokenSequence.textLength] for the last lexeme.
 *
 * Only the upper bound is validated; a negative index (e.g. `endLexemeIndex - 1` for an empty node) is not rejected.
 *
 * @throws IllegalStateException if [lexemeIndex] is not below `lexemeCount`.
 */
internal fun ITokenSequence.endCharAt(lexemeIndex: Int): Int {
  check(lexemeIndex < lexemeCount) { "lexemeIndex $lexemeIndex is out of bounds, lexemeCount: $lexemeCount" }
  return if (lexemeIndex == lexemeCount - 1) textLength else lexStart(lexemeIndex + 1)
}


/**
 * Strategy for turning [text] into an [ITokenSequence] with a given [Lexer].
 *
 * Used to re-lex chameleon text whose inner lexer differs from the surrounding one; callers in this file pass
 * `TokenType.WHITE_SPACE` as `emptyType`.
 */
fun interface TokenizationPolicy {
  fun tokenize(
    text: CharSequence,
    emptyType: IElementType,
    lexer: Lexer,
    cancellationToken: CancellationToken,
  ): ITokenSequence
}

/**
 * Immutable state shared by the nodes of one walk over a marker tree.
 *
 * @property text text of this context; nodes use global offsets, indexed here as `offset - this.offset`.
 * @property ast marker tree being walked.
 * @property lexer lexer the context's tokens come from; compared with chameleons' inner lexers.
 * @property tokenization re-tokenizes chameleon text when an inner lexer differs.
 * @property builderFactory creates builders used to parse chameleons.
 * @property startLexemeIndex index of the context's first lexeme; chameleon keys in [ast] are relative to it.
 * @property offset global character offset of the first character of [text].
 */
internal data class WalkerContext(
  val text: CharSequence,
  val ast: ASTMarkers<AstMarkersChameleon>,
  val lexer: Lexer,
  val tokenization: TokenizationPolicy,
  val builderFactory: AstMarkersBuilderFactory,
  val startLexemeIndex: Int = 0,
  val offset: CharOffset = 0,
)

/**
 * Adjusts counts after the markers of [reparseableNode] are replaced by those of [parsedChameleon].
 *
 * Lexeme counts are shifted by [diff] via [updateLexemes]. Then each ancestor that shares the node's AST, and finally
 * the root marker `0`, gets its marker count corrected by (new subtree markers - old subtree markers).
 */
private fun ASTMarkers.MutableContext<AstMarkersChameleon>.updateLexemesAndMarkers(
  parsedChameleon: AstMarkersChameleon,
  reparseableNode: AstMarkersSyntaxNode,
  diff: Int,
) {
  updateLexemes(reparseableNode, diff)
  val addedMarkers = parsedChameleon.ast.markersCount(0)
  val ownAst = reparseableNode.context.ast
  val removedMarkers = ownAst.markersCount(reparseableNode.markerIndex)

  generateSequence(reparseableNode.parent) { it.parent }
    .takeWhile { ancestor -> ancestor.context.ast == ownAst }
    .forEach { ancestor ->
      val first = ancestor.markerIndex
      val count = ancestor.context.ast.markersCount(first)
      changeMarkerCount(first, first + count, count - removedMarkers + addedMarkers)
    }

  val rootCount = ownAst.markersCount(0)
  changeMarkerCount(0, rootCount, rootCount - removedMarkers + addedMarkers)
}

/**
 * Shifts by [diff] the lexeme count of [reparseableNode], of each ancestor sharing its AST, and of root marker `0`.
 */
private fun ASTMarkers.MutableContext<AstMarkersChameleon>.updateLexemes(
  reparseableNode: AstMarkersSyntaxNode,
  diff: Int,
) {
  val ownAst = reparseableNode.context.ast
  generateSequence(reparseableNode) { it.parent }
    .takeWhile { current -> current.context.ast == ownAst }
    .forEach { current -> updateLexemeCount(current.markerIndex, current.context.ast, diff) }
  updateLexemeCount(0, ownAst, diff)
}

/**
 * Sets the lexeme count of the marker at [startMarkerIndex] (and the marker range it spans in [ast]) to its
 * current count in [ast] plus [diff].
 */
private fun ASTMarkers.MutableContext<AstMarkersChameleon>.updateLexemeCount(
  startMarkerIndex: Int,
  ast: ASTMarkers<*>,
  diff: Int,
) {
  changeLexCount(
    startMarkerIndex,
    startMarkerIndex + ast.markersCount(startMarkerIndex),
    ast.lexemeCount(startMarkerIndex) + diff,
  )
}
