package org.javacs.lsp;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import org.javacs.lsp.adapters.EnumTypeAdapter;
import org.jetbrains.annotations.Nullable;

import java.io.*;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Optional;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.logging.Level;
import java.util.logging.Logger;

public class LSP {
  private static final String CONTENT_LENGTH = "Content-Length:";
  private static final Gson gson = new GsonBuilder()
    .registerTypeAdapterFactory(new EnumTypeAdapter.Factory())
    .create();

  static String nextRequest(BufferedReader client) throws IOException {
    int contentLength = -1;
    while (true) {
      var line = client.readLine();
      if (line == null) throw new IOException("client closed");
      if (line.isEmpty()) return nextMessage(client, contentLength);
      if (line.toLowerCase().startsWith(CONTENT_LENGTH.toLowerCase())) try {
        contentLength = Integer.parseInt(line.substring(CONTENT_LENGTH.length()).trim());
      } catch (NumberFormatException ignored) {
        LOG.log(Level.SEVERE, "Unable to parse content-length header: " + line);
      }
    }
  }

  static String nextMessage(BufferedReader client, int length) throws IOException {
    if (length < 0) throw new IOException("Unexpected length: " + length);
    var buffer = new char[4096];
    var builder = new StringBuilder();
    int remaining = length;
    while (remaining > 0) {
      int needs = Math.min(remaining, buffer.length);
      int read = client.read(buffer, 0, needs);
      remaining -= read;
      builder.append(buffer, 0, read);
    }
    // Eat whitespaces
    // "Have observed problems with extra \r\n sequences from VSCode;"
    int skipped = 0;
    for (; skipped < length; skipped++) {
      var ch = builder.charAt(0);
      if (!Character.isWhitespace(ch))
        break;
      builder.deleteCharAt(0);
    }
    // re-read skipped chars
    int read = client.read(buffer, 0, skipped);
    builder.append(buffer, 0, read);

    if (read != skipped) throw new IOException(
      "Cannot re-read chars for skipped whitespaces, expected " + skipped + ", but read " + read);
    return builder.toString();
  }

  static Message parseMessage(String token) {
    return gson.fromJson(token, Message.class);
  }

  private static final Charset UTF_8 = StandardCharsets.UTF_8;

  private static void writeClient(OutputStream client, String messageText) {
    var messageBytes = messageText.getBytes(UTF_8);
    var headerText = String.format("%s %d\r\n\r\n", CONTENT_LENGTH, messageBytes.length);
    var headerBytes = headerText.getBytes(UTF_8);
    try {
      client.write(headerBytes);
      client.write(messageBytes);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  static String toJson(Object message) {
    return gson.toJson(message);
  }

  static void respond(OutputStream client, int requestId, Object params) {
    if (params instanceof ResponseError error) {
      throw new RuntimeException("Errors should be sent using LSP.error(...), " + error.message);
    }
    if (params instanceof Optional<?> option) {
      params = option.orElse(null);
    }
    var jsonText = toJson(params);
    var messageText = String.format("{\"jsonrpc\":\"2.0\",\"id\":%d,\"result\":%s}", requestId, jsonText);
    writeClient(client, messageText);
  }

  static void error(OutputStream client, int requestId, ResponseError error) {
    var jsonText = toJson(error);
    var messageText = String.format("{\"jsonrpc\":\"2.0\",\"id\":%d,\"error\":%s}", requestId, jsonText);
    writeClient(client, messageText);
  }

  private record ClientHandler(OutputStream client) implements InvocationHandler {
    private void notifyClient(String method, Object params) {
      if (params instanceof Optional<?> option) {
        params = option.orElse(null);
      }
      var jsonText = toJson(params);
      var messageText = String.format("{\"jsonrpc\":\"2.0\",\"method\":\"%s\",\"params\":%s}", method, jsonText);
      writeClient(client, messageText);
    }

    @Override public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
      // TODO: server-to-client request?
      var notification = method.getAnnotation(LspNotification.class);
      // this is not a JsonRPC method, let it go.
      if (notification == null) {
        if (method.isDefault()) return InvocationHandler.invokeDefault(proxy, method, args);
        LOG.log(Level.WARNING, "Language client method " + method.getName() + " is neither a JsonRPC method nor a default method");
        return null;
      }
      var rpcMethod = notification.value();
      var params = method.getParameterCount() > 0 ? args[0] : null;
      notifyClient(rpcMethod, params);
      return null;
    }
  }

  private static <T extends LanguageClient> T createClientProxy(Class<T> clientClass, OutputStream client) {
    //noinspection unchecked
    return (T) Proxy.newProxyInstance(
      clientClass.getClassLoader(),
      new Class[]{clientClass},
      new ClientHandler(client));
  }

  public static <T extends LanguageClient> void connect(
    Class<T> clientClass,
    Function<T, LanguageServer> serverFactory,
    InputStream receive, OutputStream send
  ) {
    var reader = new BufferedReader(new InputStreamReader(receive));
    var client = createClientProxy(clientClass, send);
    var server = serverFactory.apply(client);
    var pending = new ArrayBlockingQueue<Message>(10);
    var handlerCache = new HashMap<String, Handler>();

    // Read messages and process cancellations on a separate thread
    var readerThread = new Thread(new Runnable() {
      void peek(Message message) {
        if (message.method.equals("$/cancelRequest")) {
          var params = gson.fromJson(message.params, CancelParams.class);
          var removed = pending.removeIf(r -> r.id != null && r.id.equals(params.id));
          if (removed) LOG.info(String.format("Cancelled request %d, which had not yet started", params.id));
          else LOG.info(String.format("Cannot cancel request %d because it has already started", params.id));
        }
      }

      private boolean kill() {
        LOG.info("Read stream has been closed, putting kill message onto queue...");
        try {
          pending.put(Message.EOF);
          return true;
        } catch (Exception e) {
          LOG.log(Level.SEVERE, "Failed to put kill message onto queue, will try again...", e);
          return false;
        }
      }

      @Override public void run() {
        LOG.info("Placing incoming messages on queue...");

        while (true) {
          try {
            var token = nextRequest(reader);
            var message = parseMessage(token);
            peek(message);
            pending.put(message);
          } catch (IOException e) {
            LOG.log(Level.SEVERE, e.getMessage(), e);
            if (kill()) return;
          } catch (Exception e) {
            LOG.log(Level.SEVERE, e.getMessage(), e);
          }
        }
      }
    });
    readerThread.setName("LSP Reader Thread");
    readerThread.setDaemon(true);
    readerThread.start();

    // Process messages on main thread
    LOG.info("Reading messages from queue...");
    var hasAsyncWork = false;
    processMessages:
    while (true) {
      Message r;
      try {
        // Take a break periodically
        r = pending.poll(200, TimeUnit.MILLISECONDS);
      } catch (Exception e) {
        LOG.log(Level.SEVERE, e.getMessage(), e);
        continue;
      }
      // If receive has been closed, exit
      if (r == Message.EOF) {
        LOG.warning("Stream from client has been closed, exiting...");
        break processMessages;
      }
      // If poll(..) failed, loop again
      if (r == null) {
        if (hasAsyncWork) {
          server.doAsyncWork();
          hasAsyncWork = false;
        }
        continue;
      }
      // Otherwise, process the new message
      hasAsyncWork = true;
      try {
        if (r.method.equals("$/cancelRequest")) continue;
        // ^ handled in peek(..)

        var hdl = handlerCache.get(r.method);
        if (hdl == null) {
          hdl = findHandler(server.getClass(), r.method);
          if (hdl != null) handlerCache.put(r.method, hdl);
        }
        if (hdl == null) {
          LOG.warning(String.format("Don't know what to do with method `%s`", r.method));
          continue;
        }

        var param = hdl.paramType == null ? null : gson.fromJson(r.params, hdl.paramType);
        var result = param == null ? hdl.method.invoke(server) : hdl.method.invoke(server, param);
        if (hdl.isRequest) respond(send, r.id, result);

      } catch (Exception e) {
        LOG.log(Level.SEVERE, e.getMessage(), e);

        if (r.id != null) {
          boolean handled = false;

          // Handling the exception thrown by the handler
          if (e instanceof InvocationTargetException ex) {
            var cause = ex.getTargetException();

            if (cause instanceof ResponseErrorException respErr) {
              error(send, r.id, new ResponseError(respErr.errorCode(), respErr.getMessage(), respErr.data()));
              handled = true;
            }
          }

          if (! handled)
            error(send, r.id, new ResponseError(ErrorCodes.InternalError, e.getMessage(), null));
        }
      }
    }
  }

  private static Handler findHandler(Class<?> server, String rpcMethod) {
    if (server == null) return null;
    var handler = findHandler(server.getSuperclass(), rpcMethod);
    if (handler != null) return handler;
    for (var interf : server.getInterfaces()) {
      handler = findHandler(interf, rpcMethod);
      if (handler != null) return handler;
    }
    for (var method : server.getDeclaredMethods()) {
      handler = checkMethod(rpcMethod, method);
      if (handler != null) return handler;
    }
    return null;
  }

  @Nullable private static Handler checkMethod(String rpcMethod, Method method) {
    var req = method.getAnnotation(LspRequest.class);
    var not = method.getAnnotation(LspNotification.class);
    if (req == null && not == null) return null;
    if (req != null && not != null) {
      LOG.log(Level.WARNING, "Method " + method.getName() + " tries to handle both notification and request");
      return null;
    }
    var paramTypes = method.getParameterTypes();
    var param = paramTypes.length > 0 ? paramTypes[0] : null;
    if (req != null && req.value().equals(rpcMethod)) return new Handler(method, param, true);
    if (not != null && not.value().equals(rpcMethod)) return new Handler(method, param, false);
    return null;
  }

  private record Handler(Method method, Class<?> paramType, boolean isRequest) {}

  private static final Logger LOG = Logger.getLogger("main");
}
