package com.intellij.psi.builder;

import org.aya.ij.AyaModified;
import com.intellij.openapi.util.Pair;
import com.intellij.psi.tree.IElementType;
import kala.tuple.Unit;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.IntConsumer;

/**
 * Java rewrite of {@code ASTMarkers.kt}: a flat, index-addressed sequence of tree markers.
 *
 * <p>Tree nodes are encoded as {@link MarkerKind#Start}/{@link MarkerKind#End} pairs. {@link
 * MarkerKind#Error} markers appear between them as single, childless entries. For a {@code Start}
 * marker at index {@code i}, the matching {@code End} marker is at {@code i + markersCount(i)}. For
 * an {@code End} marker at index {@code j}, the matching {@code Start} is at {@code j -
 * markersCount(j)}. The default navigation methods below ({@link #prevSibling}, {@link
 * #nextSibling}, {@link #firstChild}, {@link #lastChild}) rely on this layout. They return {@code
 * -1} when the requested node does not exist.
 *
 * @param <T> type of the value held in each chameleon {@link AtomicReference}
 */
@AyaModified
public interface ASTMarkers<T> {
  /** Returns the number of markers; valid marker indices are {@code 0 .. getSize() - 1}. */
  int getSize();

  /** Returns the kind of marker {@code i}, expected to be one of the {@link MarkerKind} constants. */
  byte kind(int i);

  /** Returns the error message attached to marker {@code i}; may be {@code null}. */
  @Nullable String errorMessage(int i);

  /** Lexeme count recorded for marker {@code i} (exact meaning defined by implementations). */
  int lexemeCount(int i);

  /**
   * Lexeme offset recorded for marker {@code i}. NOTE(review): presumably relative to some
   * preceding position, per the name — confirm against implementations.
   */
  int lexemeRelOffset(int i);

  /** Returns whether marker {@code i} is flagged as collapsed. */
  boolean collapsed(int i);

  /**
   * Distance, in markers, between a {@code Start} marker and its matching {@code End}. For a
   * {@code Start} at {@code i} the {@code End} is at {@code i + markersCount(i)}. For an {@code
   * End} at {@code i} the {@code Start} is at {@code i - markersCount(i)}.
   */
  int markersCount(int i);

  /** Returns the element type of marker {@code i}. */
  @NotNull IElementType elementType(int i);

  /** Returns the chameleon reference associated with {@code lexemeIndex}. */
  @NotNull AtomicReference<T> chameleonAt(int lexemeIndex);

  /**
   * Returns the chameleon references, each paired with an integer key (presumably a lexeme index,
   * matching {@link #chameleonAt(int)} — TODO confirm).
   */
  @NotNull List<Pair<Integer, AtomicReference<T>>> chameleons();

  /**
   * Runs {@code mutator} against a {@link MutableContext} and returns the resulting markers.
   * NOTE(review): whether this copies or edits in place is left to the implementation.
   */
  @NotNull ASTMarkers<T> mutate(@NotNull Function<MutableContext<T>, Unit> mutator);

  /** Editing operations exposed to the callback passed to {@link ASTMarkers#mutate}. */
  interface MutableContext<T> {
    /** Substitutes {@code astMarkers} at marker {@code i} / lexeme {@code lexemeIndex}. */
    void substitute(int i, int lexemeIndex, @NotNull ASTMarkers<T> astMarkers);

    /** Changes the chameleon entries to {@code pairs}. */
    void changeChameleons(@NotNull List<Pair<Integer, AtomicReference<T>>> pairs);

    /** Changes the lexeme count recorded for the marker pair {@code startMarker}/{@code endMarker}. */
    void changeLexCount(int startMarker, int endMarker, int lexCount);

    /** Changes the marker count recorded for the marker pair {@code startMarker}/{@code endMarker}. */
    void changeMarkerCount(int startMarker, int endMarker, int markerCount);
  }

  /**
   * Marker kinds returned by {@link ASTMarkers#kind(int)}. This is a constant holder and is not
   * instantiable.
   *
   * <p>The navigation methods only handle {@code Start}, {@code End} and {@code Error}. Any other
   * kind, including {@code Undone}, makes them throw {@link IllegalStateException}.
   */
  final class MarkerKind {
    public static final byte Undone = 0;
    public static final byte Start = 1;
    public static final byte End = 2;
    public static final byte Error = 3;

    private MarkerKind() {}
  }

  /**
   * Returns the index of the sibling immediately before the node starting at {@code markerIndex}.
   *
   * @param markerIndex index of a {@code Start} or {@code Error} marker
   * @return the previous sibling's {@code Start}/{@code Error} index, or {@code -1} if {@code
   *     markerIndex} is {@code 0} or the preceding marker is the parent's {@code Start}
   * @throws IllegalStateException if the preceding marker has an unrecognized kind
   */
  default int prevSibling(int markerIndex) {
    if (markerIndex == 0) return -1;
    var prevMarkerIndex = markerIndex - 1;
    return switch (kind(prevMarkerIndex)) {
      // Preceded directly by the parent's Start: this node is the first child.
      case MarkerKind.Start -> -1;
      // Preceded by a sibling's End: jump back to that sibling's Start.
      case MarkerKind.End -> prevMarkerIndex - markersCount(prevMarkerIndex);
      case MarkerKind.Error -> prevMarkerIndex;
      default -> error("no else");
    };
  }

  /**
   * Returns the index of the sibling immediately after the node starting at {@code markerIndex}.
   *
   * @param markerIndex index of a {@code Start} or {@code Error} marker
   * @return the next sibling's {@code Start}/{@code Error} index, or {@code -1} if the node is the
   *     last child, or if its {@code End} is the final marker
   * @throws IllegalStateException if {@code markerIndex} is an {@code End} marker or has an
   *     unrecognized kind
   */
  default int nextSibling(int markerIndex) {
    return switch (kind(markerIndex)) {
      case MarkerKind.Start -> {
        var endIndex = markerIndex + markersCount(markerIndex);
        // The node's End is the last marker overall: nothing can follow it.
        if (endIndex == getSize() - 1) yield -1;

        var nextToEndIndex = endIndex + 1;
        yield switch (kind(nextToEndIndex)) {
          case MarkerKind.Start, MarkerKind.Error -> nextToEndIndex;
          // Followed by the parent's End: this node is the last child.
          case MarkerKind.End -> -1;
          default -> error("no else");
        };
      }

      case MarkerKind.Error -> {
        // Error markers occupy a single slot, so the sibling (if any) is the very next marker.
        var nextKind = kind(markerIndex + 1);
        if (nextKind == MarkerKind.End) yield -1;
        else yield markerIndex + 1;
      }

      case MarkerKind.End -> error("should never be at the end");
      default -> error("no else");
    };
  }

  /**
   * Returns the index of the last child of the node starting at {@code markerIndex}.
   *
   * @param markerIndex index of a {@code Start} or {@code Error} marker
   * @return the last child's {@code Start}/{@code Error} index, or {@code -1} if there are no
   *     children ({@code Error} markers never have children)
   * @throws IllegalStateException if {@code markerIndex} is an {@code End} marker or an
   *     unrecognized kind is encountered
   */
  default int lastChild(int markerIndex) {
    return switch (kind(markerIndex)) {
      case MarkerKind.End -> error("never at end");
      case MarkerKind.Error -> -1;
      case MarkerKind.Start -> {
        // Inspect the marker just before this node's matching End.
        var prevToEndIndex = markerIndex + markersCount(markerIndex) - 1;
        yield switch (kind(prevToEndIndex)) {
          // With well-nested markers, a Start right before the End is the node itself: no children.
          case MarkerKind.Start -> -1;
          // A child's End: jump back to that child's Start.
          case MarkerKind.End -> prevToEndIndex - markersCount(prevToEndIndex);
          case MarkerKind.Error -> prevToEndIndex;
          default -> error("no else");
        };
      }
      default -> error("no else");
    };
  }

  /**
   * Returns the index of the first child of the node starting at {@code markerIndex}.
   *
   * @param markerIndex index of a marker that is not the last one
   * @return the first child's {@code Start}/{@code Error} index, or {@code -1} if the next marker
   *     is an {@code End} (no children)
   * @throws AssertionError if {@code markerIndex >= getSize() - 1}
   * @throws IllegalStateException if the next marker has an unrecognized kind
   */
  default int firstChild(int markerIndex) {
    require(markerIndex < getSize() - 1, "at least there is an end");
    var nextMarkerIndex = markerIndex + 1;
    return switch (kind(nextMarkerIndex)) {
      case MarkerKind.Start, MarkerKind.Error -> nextMarkerIndex;
      case MarkerKind.End -> -1;
      default -> error("no else");
    };
  }

  /**
   * Precondition check modelled on Kotlin's {@code require()}.
   *
   * <p>Unlike Kotlin, which throws {@code IllegalArgumentException}, this throws {@link
   * AssertionError}. It is thrown explicitly, so it does not depend on {@code -ea}.
   *
   * @throws AssertionError with {@code msg} if {@code cond} is false
   */
  static void require(boolean cond, @NotNull String msg) {
    if (!cond) throw new AssertionError(msg);
  }

  /**
   * State check modelled on Kotlin's {@code check()}.
   *
   * <p>Unlike Kotlin, which throws {@code IllegalStateException}, this throws {@link
   * AssertionError}. It is thrown explicitly, so it does not depend on {@code -ea}.
   *
   * @throws AssertionError with {@code msg} if {@code cond} is false
   */
  static void check(boolean cond, @NotNull String msg) {
    if (!cond) throw new AssertionError(msg);
  }

  /**
   * Mimics Kotlin's {@code repeat()}: calls {@code consumer} with {@code 0 .. times - 1} in order.
   * It does nothing when {@code times <= 0}.
   */
  static void repeat(int times, @NotNull IntConsumer consumer) {
    for (var i = 0; i < times; i++) consumer.accept(i);
  }

  /**
   * Mimics iterating Kotlin's {@code upper downTo lower}: calls {@code consumer} with {@code upper,
   * upper - 1, ..., lower}. Both bounds are inclusive, and nothing happens when {@code upper <
   * lower}.
   */
  static void downTo(int upper, int lower, @NotNull IntConsumer consumer) {
    for (var i = upper; i >= lower; i--) consumer.accept(i);
  }

  /**
   * Mimics Kotlin's {@code error()}: always throws, and never returns normally.
   *
   * <p>The generic return type lets it be used as an expression, e.g. as a {@code switch} arm.
   *
   * @throws IllegalStateException always, with message {@code error}
   */
  static <T> T error(@NotNull String error) {
    throw new IllegalStateException(error);
  }
}
