package com.intellij.psi.builder;

import org.aya.ij.AyaModified;
import com.intellij.openapi.util.Pair;
import com.intellij.psi.TokenType;
import com.intellij.psi.tree.IElementType;
import kala.collection.mutable.MutableArrayList;
import kala.collection.mutable.MutableList;
import kala.collection.mutable.primitive.MutableIntArrayList;
import kala.collection.mutable.primitive.MutableIntList;
import kala.range.primitive.IntRange;
import kala.tuple.Unit;
import kala.value.Var;
import kala.value.primitive.MutableIntValue;
import org.aya.kala.Int2IntOpenHashMap;
import org.aya.kala.Int2ObjectOpenHashMap;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.IntStream;

/**
 * Java rewrite of IntelliJ's {@code ASTMarkersImpl.kt}: an index-addressed list of tree markers.
 *
 * <p>Per-marker kind, flags, id, markers count and lexeme info are bit-packed by {@link Packer};
 * element types are kept in a parallel list, error messages are keyed by marker id, and
 * chameleons are keyed by lexeme index.
 */
@AyaModified
public class ASTMarkersImpl<T> implements ASTMarkers<T> {
  private static final int DEFAULT_CAPACITY = 256;
  private final @NotNull Packer packer;
  /** Element type of each marker, parallel to the packer's markers. */
  private final @NotNull MutableList<IElementType> elementTypes;
  /** Marker id -> error message; only consulted for markers whose error flag is set. */
  private final @NotNull Int2ObjectOpenHashMap<String> errorMessages;
  private final @NotNull Int2ObjectOpenHashMap<AtomicReference<T>> chameleonsMap; // lexemeIndex -> chameleon
  /** Next free marker id; {@link #setMarker} keeps it above every id it is given. */
  private int nextId;

  public ASTMarkersImpl() {
    packer = new Packer();
    elementTypes = MutableArrayList.create(DEFAULT_CAPACITY);
    errorMessages = new Int2ObjectOpenHashMap<>();
    chameleonsMap = new Int2ObjectOpenHashMap<>();
    nextId = 0;
  }

  /** Copy constructor used by {@link #mutate}, so that mutations are applied to a copy of {@code origin}. */
  private ASTMarkersImpl(@NotNull ASTMarkersImpl<T> origin) {
    packer = origin.packer.copy();
    elementTypes = MutableArrayList.from(origin.elementTypes);
    errorMessages = origin.errorMessages.clone();
    chameleonsMap = origin.chameleonsMap.clone();
    nextId = origin.nextId;
  }

  @Override public int getSize() {
    return packer.size();
  }

  @Override public byte kind(int i) {
    return packer.kind(i);
  }

  /** Returns the error message of marker {@code i}, or {@code null} when its error flag is not set. */
  @Override public @Nullable String errorMessage(int i) {
    return hasError(i) ? errorMessages.get(id(i)) : null;
  }

  boolean hasError(int i) {
    return packer.hasError(i);
  }

  int id(int i) {
    return packer.id(i);
  }

  @Override public int lexemeCount(int i) {
    return packer.lexemeInfo(i).count;
  }

  @Override public int lexemeRelOffset(int i) {
    return packer.lexemeInfo(i).relOffset;
  }

  @Override public boolean collapsed(int i) {
    return packer.collapsed(i);
  }

  @Override public int markersCount(int i) {
    return packer.markersCount(i);
  }

  @Override public @NotNull IElementType elementType(int i) {
    return elementTypes.get(i);
  }

  // NOTE(review): this is a plain map lookup; it may return null for an unregistered lexeme
  // index despite the @NotNull annotation — confirm callers only ask for registered indices.
  @Override public @NotNull AtomicReference<T> chameleonAt(int lexemeIndex) {
    return chameleonsMap.get(lexemeIndex);
  }

  @Override public @NotNull List<Pair<Integer, AtomicReference<T>>> chameleons() {
    return chameleonsMap.toList();
  }

  /**
   * Replaces the subtree whose Start marker is at {@code i} with the markers of {@code preAstMarkers}.
   * The subtree spans markers {@code i .. i + markersCount(i)} inclusive.
   *
   * @param preAstMarkers replacement markers; must be an {@link ASTMarkersImpl}
   * @param i             index of the Start marker of the subtree to replace
   * @param lexemeIndex   lexeme index at which the replaced subtree begins (used to rebase chameleons)
   */
  private void substituteImpl(@NotNull ASTMarkers<T> preAstMarkers, int i, int lexemeIndex) {
    if (!(preAstMarkers instanceof ASTMarkersImpl<T> astMarkers))
      throw new AssertionError("unexpected class: " + preAstMarkers.getClass());
    ASTMarkers.check(kind(i) == MarkerKind.Start, "");

    var start = i;
    var end = i + markersCount(i);
    var relOffset = lexemeRelOffset(i);
    var oldId2newId = new Int2IntOpenHashMap();
    copyNewChameleons(lexemeIndex, astMarkers);
    // must run before the markers are overwritten: it reads the ids being removed
    removeMarkersFromMaps(start, end);
    var newNextId = computeIdsForNewMarkers(astMarkers, oldId2newId);

    copyNewAstMarkers(start, end, astMarkers);

    // after several replace operations some ids could be missed.
    // If we are in the short mode, let's renumber our ids
    renumberIfNeeded(newNextId, start, start + astMarkers.getSize());

    // now we need to replace their ids with our ids
    renumberNewMarkers(start, oldId2newId, astMarkers);

    // fix offset for the root
    var count = lexemeCount(i);
    setLexemeInfo(i, count, relOffset);
    setLexemeInfo(i + astMarkers.getSize() - 1, count, relOffset);
    nextId += newNextId;
  }

  /**
   * Drops our chameleons that fall in {@code [startLexeme, startLexeme + astMarkers.lexemeCount(0))}
   * and imports the chameleons of {@code astMarkers}, shifted by {@code startLexeme}.
   */
  private void copyNewChameleons(int startLexeme, @NotNull ASTMarkersImpl<T> astMarkers) {
    var endLexeme = startLexeme + astMarkers.lexemeCount(0);
    // Iterate over a snapshot of the entries: removing from chameleonsMap while lazily streaming
    // its own key set would mutate the stream source mid-pipeline.
    for (var entry : chameleonsMap.toList()) {
      int lexeme = entry.first;
      if (startLexeme <= lexeme && lexeme < endLexeme) chameleonsMap.remove(lexeme);
    }
    astMarkers.chameleonsMap.forEach((k, v) -> chameleonsMap.put(k + startLexeme, v));
  }

  /**
   * Rewrites the ids of the freshly inserted markers from {@code astMarkers}' id space into ours
   * ({@code oldId2newId} value + current {@code nextId}) and carries their error messages over.
   */
  private void renumberNewMarkers(
    int start,
    @NotNull Int2IntOpenHashMap oldId2newId,
    @NotNull ASTMarkersImpl<T> astMarkers
  ) {
    var offset = nextId;
    IntStream.range(start, start + astMarkers.getSize()).forEach(index -> {
      var oldId = id(index);
      var newId = oldId2newId.get(oldId) + offset;
      if (hasError(index)) {
        errorMessages.put(newId, astMarkers.errorMessages.get(oldId));
      }
      packer.setId(index, newId);
    });
  }

  /**
   * Assigns dense ids {@code 0, 1, ...} to the distinct ids of {@code astMarkers}, in first-seen order.
   *
   * @return the number of distinct ids found
   */
  private int computeIdsForNewMarkers(
    @NotNull ASTMarkersImpl<T> astMarkers,
    @NotNull Int2IntOpenHashMap oldId2newId
  ) {
    var nextNewId = new Var<>(0);
    IntStream.range(0, astMarkers.getSize()).forEach(index -> oldId2newId.computeIfAbsent(astMarkers.id(index), i -> nextNewId.value++));
    return nextNewId.value;
  }

  /** In short mode, compacts our ids when appending {@code nextNewId} more ids would exceed 16 bits. */
  private void renumberIfNeeded(int nextNewId, int insertedStart, int insertedEnd) {
    if (packer.shortMode() && nextId + nextNewId > ((int) Character.MAX_VALUE)) {
      nextId = 0;
      renumber(insertedStart, insertedEnd);
    }
  }

  /** Replaces markers {@code start..end} (inclusive) with all markers of {@code astMarkers}. */
  private void copyNewAstMarkers(
    int start,
    int end,
    @NotNull ASTMarkersImpl<?> astMarkers
  ) {
    elementTypes.removeInRange(start, end + 1);
    elementTypes.insertAll(start, astMarkers.elementTypes.toArray(IElementType.EMPTY_ARRAY));
    packer.replace(start, end + 1, astMarkers.packer);
  }

  /** Forgets the error messages of markers {@code start..end} (inclusive). */
  private void removeMarkersFromMaps(int start, int end) {
    IntRange.closed(start, end).forEach(index -> {
      var currentId = id(index);
      if (hasError(index)) {
        errorMessages.remove(currentId);
      }
    });
  }

  /**
   * Reassigns ids of all markers outside {@code [insertedStart, insertedEnd)} densely starting at
   * {@code nextId}, moving their error messages along.
   */
  private void renumber(int insertedStart, int insertedEnd) {
    var renumberMap = new Int2IntOpenHashMap();
    // Moved messages are staged here: writing them straight back into errorMessages could
    // overwrite the message of an old id equal to the new one that has not been visited yet.
    var movedMessages = new Int2ObjectOpenHashMap<String>();
    IntStream.range(0, getSize()).forEach(index -> {
      // skip new elements
      if (index >= insertedStart && index < insertedEnd) return;
      var newId = renumberMap.computeIfAbsent(id(index), oldId -> {
        var currentId = nextId++;
        if (hasError(index)) {
          movedMessages.put(currentId, errorMessages.remove(oldId));
        }
        return currentId;
      });
      packer.setId(index, newId);
    });
    movedMessages.forEach((key, message) -> errorMessages.put(key, message));
  }

  void setChameleon(int lexemeIndex, @NotNull AtomicReference<T> chameleon) {
    chameleonsMap.put(lexemeIndex, chameleon);
  }

  void setMarkersCount(int i, int descCount) {
    packer.setMarkersCount(i, descCount);
  }

  void setLexemeInfo(int i, int lexemeCount, int relOffset) {
    packer.setLexemeInfo(i, lexemeCount, relOffset);
  }

  /** Appends a blank marker (all packed fields zero, element type ERROR_ELEMENT) and returns its index. */
  int pushBack() {
    var i = getSize();
    packer.pushBack();
    elementTypes.insert(i, TokenType.ERROR_ELEMENT);
    return i;
  }

  /**
   * Initializes marker {@code index}, which must still be {@link MarkerKind#Undone}.
   * A {@code null} error message clears the error flag; a {@code null} element type keeps the current one.
   */
  void setMarker(
    int index,
    int id,
    byte kind,
    boolean collapsed,
    @Nullable String errorMessage,
    @Nullable IElementType elementType
  ) {
    if (kind(index) != MarkerKind.Undone) throw new AssertionError();
    if (id + 1 > nextId) nextId = id + 1;
    packer.setInitialInfo(index, id, kind, collapsed, errorMessage != null);
    if (errorMessage != null) errorMessages.put(id, errorMessage);
    if (elementType != null) elementTypes.set(index, elementType);
  }

  /** Applies {@code mutator} to a copy of these markers and returns that copy; {@code this} is not modified. */
  @Override public @NotNull ASTMarkers<T> mutate(@NotNull Function<MutableContext<T>, Unit> mutator) {
    var ctx = new MutableContextImpl<T>(new ASTMarkersImpl<>(this));
    mutator.apply(ctx);
    return ctx.ast;
  }

  /** Debug dump: one line per marker, indented by nesting depth, followed by the chameleon map. */
  @Override public String toString() {
    var b = new StringBuilder();
    var depth = MutableIntValue.create(0);
    elementTypes.forEachIndexed((index, type) -> {
      if (kind(index) == MarkerKind.End) depth.decrement();
      ASTMarkers.repeat(depth.get(), i -> b.append(" "));
      b.append(type).append(" ").append(kind(index)).append(" ");
      b.append("e=").append(hasError(index)).append(" ");
      b.append("c=").append(collapsed(index)).append(" ");
      b.append("lo=").append(lexemeRelOffset(index)).append(" ");
      b.append("lc=").append(lexemeCount(index)).append(" ");
      b.append("mc=").append(markersCount(index)).append("\n");
      if (kind(index) == MarkerKind.Start) depth.increment();
    });
    b.append("{").append("\n");
    chameleonsMap.forEach((t, u) -> b.append(t).append(" -> ").append(u.get()).append("\n"));
    b.append("}").append("\n");
    return b.toString();
  }

  record LexemeInfo(int relOffset, int count) {}

  /** Bit-packed per-marker storage that switches from 2 to 4 ints per marker when a value overflows. */
  private static class Packer {
    private static final int KIND_MASK = 3;
    private static final int MAX_SHORT_ID_VALUE = Character.MAX_VALUE;
    private static final int MAX_SHORT_LEXEME_VALUE = Character.MAX_VALUE;
    private static final int MAX_SHORT_MARKERS_COUNT = Character.MAX_VALUE >>> 4;
    private static final int MAX_LONG_MARKERS_COUNT = Integer.MAX_VALUE >>> 3;

    // Layout, fields listed from the least significant bit upwards:
    // short mode, 2 ints per marker:
    //   int 0: kind(2) | collapsed(1) | hasError(1) | markersCount(12) | id(16)
    //   int 1: lexemeRelOffset(16) | lexemeCount(16)
    // long mode, 4 ints per marker:
    //   int 0: kind(2) | collapsed(1) | hasError(1) | markersCount(28)
    //   int 1: id(32), int 2: lexemeRelOffset(32), int 3: lexemeCount(32)
    private final MutableIntList ints;
    private boolean longMode = false;

    private boolean shortMode() {
      return !longMode;
    }

    public Packer() {
      ints = MutableIntArrayList.create(DEFAULT_CAPACITY);
    }

    private Packer(@NotNull Packer origin) {
      longMode = origin.longMode;
      ints = MutableIntArrayList.from(origin.ints);
    }

    /** Index of the first int of marker {@code i}. */
    private int index(int i) {
      return longMode ? i * 4 : i * 2;
    }

    byte kind(int i) {
      return (byte) (ints.get(index(i)) & KIND_MASK);
    }

    @NotNull LexemeInfo lexemeInfo(int i) {
      var index = index(i);
      if (longMode) return new LexemeInfo(ints.get(index + 2), ints.get(index + 3));
      else {
        var packed = ints.get(index + 1);
        return new LexemeInfo(packed & MAX_SHORT_LEXEME_VALUE, packed >>> 16);
      }
    }

    boolean collapsed(int i) {
      return (ints.get(index(i)) & 4) == 4;
    }

    boolean hasError(int i) {
      return (ints.get(index(i)) & 8) == 8;
    }

    int id(int i) {
      var index = index(i);
      return longMode ? ints.get(index + 1) : ints.get(index) >>> 16;
    }

    int markersCount(int i) {
      return (ints.get(index(i)) >>> 4) & (longMode ? MAX_LONG_MARKERS_COUNT : MAX_SHORT_MARKERS_COUNT);
    }

    int size() {
      return ints.size() / (longMode ? 4 : 2);
    }

    void setLexemeInfo(int i, int count, int relOffset) {
      if (shortMode() && (relOffset > MAX_SHORT_LEXEME_VALUE || count > MAX_SHORT_LEXEME_VALUE)) grow();
      var index = index(i);
      if (longMode) {
        ints.set(index + 2, relOffset);
        ints.set(index + 3, count);
      } else {
        ints.set(index + 1, relOffset | (count << 16));
      }
    }

    void setMarkersCount(int i, int count) {
      if (shortMode() && count > MAX_SHORT_MARKERS_COUNT) grow();
      // build the message only on failure; this is called for every marker
      if (count > MAX_LONG_MARKERS_COUNT)
        ASTMarkers.check(false, "markers count " + count + " is bigger than " + MAX_LONG_MARKERS_COUNT);
      var index = index(i);
      // The field is 12 bits wide in short mode but 28 bits in long mode: clearing only the short
      // width would leave stale high bits of a previous long-mode count in place.
      var fieldMask = (longMode ? MAX_LONG_MARKERS_COUNT : MAX_SHORT_MARKERS_COUNT) << 4;
      ints.set(index, (ints.get(index) & ~fieldMask) | (count << 4));
    }

    /** Overwrites the first int of marker {@code i} (markers count reset to 0) and sets its id. */
    void setInitialInfo(int i, int id, byte kind, boolean collapsed, boolean hasErrors) {
      var collapsedInt = collapsed ? 4 : 0;
      var hasErrorInt = hasErrors ? 8 : 0;
      if (shortMode() && id > MAX_SHORT_ID_VALUE) grow();
      var index = index(i);
      ints.set(index, (int) kind + collapsedInt + hasErrorInt);
      setId(i, id);
    }

    void setId(int i, int id) {
      // a short-mode id has only 16 bits; widen instead of silently truncating
      if (shortMode() && id > MAX_SHORT_ID_VALUE) grow();
      var index = index(i);
      if (longMode) ints.set(index + 1, id);
      // keep the low 16 bits (kind, flags, markers count), replace the id half
      else ints.set(index, (ints.get(index) & ((int) Character.MAX_VALUE)) | (id << 16));
    }

    /** Replaces markers {@code [start, end)} with all markers of {@code packer}, unifying modes first. */
    void replace(int start, int end, @NotNull Packer packer) {
      var source = packer;
      if (longMode && source.shortMode()) {
        // widen a private copy: the argument belongs to another ASTMarkers and must stay untouched
        source = source.copy();
        source.grow();
      }
      if (shortMode() && source.longMode) grow();
      ints.removeInRange(index(start), index(end));
      ints.insertAll(index(start), source.ints.toArray());
    }

    /** Converts every marker from the 2-int to the 4-int layout in place. */
    private void grow() {
      ASTMarkers.check(shortMode(), "Already in long mode");
      var end = size() - 1;
      ASTMarkers.repeat(ints.size(), i -> ints.append(0));
      longMode = true;
      // walk backwards so that no short-mode marker is overwritten before it has been read
      ASTMarkers.downTo(end, 0, i -> {
        var firstInt = ints.get(i * 2);
        var secondInt = ints.get(i * 2 + 1);
        setInitialInfo(
          i,
          firstInt >>> 16,
          (byte) (firstInt & 3),
          (firstInt & 4) == 4,
          (firstInt & 8) == 8);
        setMarkersCount(i, (firstInt >>> 4) & MAX_SHORT_MARKERS_COUNT);
        setLexemeInfo(i, secondInt >>> 16, secondInt & MAX_SHORT_LEXEME_VALUE);
      });
    }

    /** Appends one all-zero marker. */
    void pushBack() {
      ASTMarkers.repeat(longMode ? 4 : 2, i -> ints.append(0));
    }

    @NotNull Packer copy() {
      return new Packer(this);
    }
  }

  /** Mutation view handed to {@link ASTMarkersImpl#mutate}; every operation edits {@code ast} directly. */
  record MutableContextImpl<T>(@NotNull ASTMarkersImpl<T> ast) implements ASTMarkers.MutableContext<T> {
    @Override public void substitute(int i, int lexemeIndex, @NotNull ASTMarkers<T> astMarkers) {
      ast.substituteImpl(astMarkers, i, lexemeIndex);
    }

    @Override public void changeChameleons(@NotNull List<Pair<Integer, AtomicReference<T>>> pairs) {
      ast.chameleonsMap.clear();
      pairs.forEach(kv -> ast.chameleonsMap.put(kv.first, kv.second));
    }

    @Override public void changeLexCount(int startMarker, int endMarker, int lexCount) {
      var relOffset = ast.lexemeRelOffset(startMarker);
      ast.setLexemeInfo(startMarker, lexCount, relOffset);
      ast.setLexemeInfo(endMarker, lexCount, relOffset);
    }

    @Override public void changeMarkerCount(int startMarker, int endMarker, int markerCount) {
      ast.setMarkersCount(startMarker, markerCount);
      ast.setMarkersCount(endMarker, markerCount);
    }
  }
}
