package dev.speakeasyapi.sdk.utils;

import java.time.Instant;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.google.common.util.concurrent.MoreExecutors;
import dev.speakeasyapi.sdk.client.ISpeakeasyClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PatchMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;

import dev.speakeasyapi.sdk.SpeakeasyRequestResponseHandler;
import dev.speakeasyapi.sdk.client.SpeakeasyClient;

/**
 * Spring MVC {@link HandlerInterceptor} that records each handled request/response pair and hands it to a
 * {@link SpeakeasyRequestResponseHandler} together with an {@link ISpeakeasyClient}.
 *
 * <p>{@link #preHandle} stores a fresh {@link SpeakeasyMiddlewareController} and the request start time as
 * request attributes. {@link #afterCompletion} reads them back, resolves a path hint and schedules the
 * handler on an executor.
 *
 * <p>Environment variables read once, at construction time:
 * <ul>
 *   <li>{@code SPEAKEASY_SERVER_URL} - overrides the default server address.</li>
 *   <li>{@code SPEAKEASY_SERVER_SECURE} - the value {@code "false"} turns off secure gRPC.</li>
 *   <li>{@code SPEAKEASY_TEST_MODE} - the value {@code "true"} passes {@code disableIngest = true} to the
 *       default client and runs the handler inline on the calling thread.</li>
 * </ul>
 */
public class SpeakeasyInterceptor implements HandlerInterceptor {
    /** Request attribute holding the per-request {@link SpeakeasyMiddlewareController}. */
    public static final String ControllerKey = "speakeasyMiddlewareController";
    /** Request attribute holding the {@link Instant} at which {@link #preHandle} ran. */
    public static final String StartTimeKey = "speakeasyStartTime";

    private static final String DEFAULT_SERVER_URL = "grpc.prod.speakeasyapi.dev:443";

    private final Executor pool;
    private final ISpeakeasyClient client;
    private final String apiKey;
    private final String apiID;
    private final String versionID;
    private final String serverUrl;
    private final boolean secureGrpc;
    private final Logger logger = LoggerFactory.getLogger(SpeakeasyInterceptor.class);

    /**
     * Creates an interceptor.
     *
     * @param apiKey    API key, forwarded to the default client when {@code client} is {@code null}
     * @param apiID     API identifier, forwarded to the default client when {@code client} is {@code null}
     * @param versionID API version identifier, forwarded to the default client when {@code client} is {@code null}
     * @param client    client to use; when {@code null} a {@link SpeakeasyClient} is built from the other
     *                  arguments and the environment configuration
     */
    public SpeakeasyInterceptor(String apiKey, String apiID, String versionID, ISpeakeasyClient client) {
        this.apiKey = apiKey;
        this.apiID = apiID;
        this.versionID = versionID;

        String serverUrlOverride = System.getenv("SPEAKEASY_SERVER_URL");
        this.serverUrl = serverUrlOverride != null ? serverUrlOverride : DEFAULT_SERVER_URL;
        this.secureGrpc = !"false".equals(System.getenv("SPEAKEASY_SERVER_SECURE"));

        boolean testMode = "true".equals(System.getenv("SPEAKEASY_TEST_MODE"));
        // Test mode runs the handler inline (direct executor); otherwise it runs on a cached pool,
        // off the request thread. Only the executor that is actually used gets created.
        this.pool = testMode ? MoreExecutors.directExecutor() : Executors.newCachedThreadPool();

        // If no client was passed in, create a default client.
        this.client = client != null
                ? client
                : new SpeakeasyClient(apiKey, apiID, versionID, this.serverUrl, this.secureGrpc, testMode);
    }

    /**
     * Creates an interceptor backed by a default {@link SpeakeasyClient}.
     *
     * @param apiKey    API key
     * @param apiID     API identifier
     * @param versionID API version identifier
     */
    public SpeakeasyInterceptor(String apiKey, String apiID, String versionID) {
        this(apiKey, apiID, versionID, null);
    }

    /**
     * Stores a new {@link SpeakeasyMiddlewareController} under {@link #ControllerKey} and the current
     * instant under {@link #StartTimeKey}.
     *
     * @return always {@code true}, so request processing continues
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
        request.setAttribute(ControllerKey, new SpeakeasyMiddlewareController());
        request.setAttribute(StartTimeKey, Instant.now());

        return true;
    }

    /**
     * Builds a {@link SpeakeasyRequestResponseHandler} for the finished exchange and hands it to the executor.
     *
     * <p>The path hint is taken from the controller if one was set there. Otherwise it is derived from the
     * handler method's mapping annotations. For handlers that are not a {@link HandlerMethod} (for example,
     * static-resource handlers) it falls back to the request URI.
     */
    @Override
    public void afterCompletion(HttpServletRequest req, HttpServletResponse res, Object handler, Exception ex) {
        SpeakeasyMiddlewareController controller = (SpeakeasyMiddlewareController) req.getAttribute(ControllerKey);
        Instant startTime = (Instant) req.getAttribute(StartTimeKey);
        RequestResponseCaptureWatcher watcher = (RequestResponseCaptureWatcher) req
                .getAttribute(SpeakeasyRequestWrapper.speakeasyRequestResponseWatcherAttribute);

        String pathHint = resolvePathHint(controller, req, handler);

        pool.execute(new SpeakeasyRequestResponseHandler(this.client, this.logger,
                req, res, watcher, startTime, Instant.now(), pathHint, controller.getCustomerID()));
    }

    /** Picks the explicit controller hint, else the annotation-derived hint, else the raw request URI. */
    private static String resolvePathHint(SpeakeasyMiddlewareController controller, HttpServletRequest req,
            Object handler) {
        String explicitHint = controller.getPathHint();
        if (StringUtils.hasText(explicitHint)) {
            return explicitHint;
        }
        if (handler instanceof HandlerMethod) {
            return getPathHint((HandlerMethod) handler);
        }
        // Not an annotated controller method, so there is no mapping template to report.
        // NOTE(review): request URI includes the servlet context path; confirm this is the desired fallback.
        return req.getRequestURI();
    }

    /**
     * Builds a path hint by joining the controller-level and method-level mapping paths.
     * Only the first path of a multi-path mapping is used.
     */
    private static String getPathHint(HandlerMethod hm) {
        return getControllerPath(hm) + getMethodPath(hm);
    }

    /**
     * Returns the controller's class-level {@code @RequestMapping} path with a leading slash and no
     * trailing slash, or {@code ""} when there is no usable class-level mapping.
     */
    private static String getControllerPath(HandlerMethod hm) {
        // getBeanType() unwraps CGLIB proxies; @RequestMapping is not @Inherited, so the proxy
        // subclass returned by getBean().getClass() would not carry it.
        RequestMapping mapping = hm.getBeanType().getAnnotation(RequestMapping.class);
        if (mapping == null) {
            return "";
        }

        // Class#getAnnotation returns the raw annotation (no @AliasFor merging), so path() must be
        // consulted explicitly when value() is empty.
        String path = firstPath(mapping.value(), mapping.path());
        if (!path.startsWith("/")) {
            path = "/" + path;
        }
        // Strip after adding the leading slash so that "/" and "" collapse to "" rather than
        // producing a double slash when the method path is appended.
        if (path.endsWith("/")) {
            path = path.substring(0, path.length() - 1);
        }
        return path;
    }

    /**
     * Returns the first method-level mapping path with a leading slash. The annotations are checked in the
     * order {@code @RequestMapping}, {@code @GetMapping}, {@code @PostMapping}, {@code @PutMapping},
     * {@code @DeleteMapping}, {@code @PatchMapping}. Returns {@code "/"} when none yields a path.
     */
    private static String getMethodPath(HandlerMethod hm) {
        String path = rawMethodPath(hm);
        return path.startsWith("/") ? path : "/" + path;
    }

    /** Returns the first path declared by the first matching method-level mapping annotation, or {@code ""}. */
    private static String rawMethodPath(HandlerMethod hm) {
        RequestMapping request = hm.getMethodAnnotation(RequestMapping.class);
        if (request != null) {
            return firstPath(request.value(), request.path());
        }
        GetMapping get = hm.getMethodAnnotation(GetMapping.class);
        if (get != null) {
            return firstPath(get.value(), get.path());
        }
        PostMapping post = hm.getMethodAnnotation(PostMapping.class);
        if (post != null) {
            return firstPath(post.value(), post.path());
        }
        PutMapping put = hm.getMethodAnnotation(PutMapping.class);
        if (put != null) {
            return firstPath(put.value(), put.path());
        }
        DeleteMapping delete = hm.getMethodAnnotation(DeleteMapping.class);
        if (delete != null) {
            return firstPath(delete.value(), delete.path());
        }
        PatchMapping patch = hm.getMethodAnnotation(PatchMapping.class);
        if (patch != null) {
            return firstPath(patch.value(), patch.path());
        }
        return "";
    }

    /** Returns the first entry of {@code value}, else the first entry of {@code path}, else {@code ""}. */
    private static String firstPath(String[] value, String[] path) {
        if (value != null && value.length > 0) {
            return value[0];
        }
        if (path != null && path.length > 0) {
            return path[0];
        }
        return "";
    }
}
