/*
 * Decompiled with CFR 0.152.
 */
package org.tio.http.mcp.server;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tio.http.common.HttpRequest;
import org.tio.http.common.HttpResponse;
import org.tio.http.common.HttpResponseStatus;
import org.tio.http.jsonrpc.JsonRpcMessage;
import org.tio.http.jsonrpc.JsonRpcNotification;
import org.tio.http.jsonrpc.JsonRpcRequest;
import org.tio.http.jsonrpc.JsonRpcResponse;
import org.tio.http.mcp.schema.McpCallToolRequest;
import org.tio.http.mcp.schema.McpCallToolResult;
import org.tio.http.mcp.schema.McpImplementation;
import org.tio.http.mcp.schema.McpInitializeRequest;
import org.tio.http.mcp.schema.McpInitializeResult;
import org.tio.http.mcp.schema.McpListToolsResult;
import org.tio.http.mcp.schema.McpRoot;
import org.tio.http.mcp.schema.McpServerCapabilities;
import org.tio.http.mcp.schema.McpTool;
import org.tio.http.mcp.server.McpPromptSpecification;
import org.tio.http.mcp.server.McpResourceSpecification;
import org.tio.http.mcp.server.McpResourceTemplateSpecification;
import org.tio.http.mcp.server.McpServerSession;
import org.tio.http.mcp.server.McpToolSpecification;
import org.tio.http.sse.SseEmitter;
import org.tio.utils.hutool.StrUtil;
import org.tio.utils.json.JsonUtil;

/**
 * MCP (Model Context Protocol) server built on t-io's HTTP + SSE support.
 *
 * <p>Flow:
 * <ol>
 *   <li>A client opens {@link #sseEndpoint(HttpRequest)}.</li>
 *   <li>Once the SSE response has been sent, a new {@link McpServerSession} is registered under a
 *       random id.</li>
 *   <li>An {@code endpoint} event tells the client where to POST its JSON-RPC messages
 *       ({@code messageEndpoint?sessionId=...}).</li>
 *   <li>Those POSTs are handled by {@link #sseMessageEndpoint(HttpRequest)}. Responses are pushed
 *       back through the session rather than in the HTTP response body.</li>
 * </ol>
 *
 * <p>The request dispatcher in this class handles {@code initialize}, {@code ping},
 * {@code tools/list} and {@code tools/call}. Resources, resource templates, prompts and
 * roots-change handlers can be registered, but this class does not read them when dispatching.
 *
 * <p>The registration methods mutate plain {@code ArrayList}/{@code HashMap} collections without
 * synchronization, so they are presumably meant to be called before requests are served. The
 * session map is a {@link ConcurrentHashMap}.
 *
 * <p>NOTE(review): sessions are added in {@link #sseEndpoint(HttpRequest)} but never removed in
 * this class, so closed SSE connections remain in the map. Confirm whether cleanup happens
 * elsewhere.
 */
public class McpServer {
    private static final Logger log = LoggerFactory.getLogger(McpServer.class);
    /** SSE event name used to tell the client which URL to POST messages to. */
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";
    /** Default path of the SSE stream endpoint. */
    public static final String DEFAULT_SSE_ENDPOINT = "/sse";
    /** Default path of the JSON-RPC message endpoint. */
    public static final String DEFAULT_MESSAGE_ENDPOINT = "/sse/message";
    private static final McpImplementation DEFAULT_SERVER_INFO = new McpImplementation("mcp-server", "1.0.0");
    /** Active sessions keyed by the id handed to the client in the {@code endpoint} event. */
    private final ConcurrentHashMap<String, McpServerSession> sessions = new ConcurrentHashMap<>();
    private final String sseEndpoint;
    private final String messageEndpoint;
    private McpImplementation serverInfo = DEFAULT_SERVER_INFO;
    /** Reported verbatim in the {@code initialize} result; null unless {@link #capabilities} was called. */
    private McpServerCapabilities serverCapabilities;
    private final List<McpToolSpecification> tools = new ArrayList<McpToolSpecification>();
    private final Map<String, McpResourceSpecification> resources = new HashMap<String, McpResourceSpecification>();
    private final Map<String, McpResourceTemplateSpecification> resourceTemplates = new HashMap<String, McpResourceTemplateSpecification>();
    private final Map<String, McpPromptSpecification> prompts = new HashMap<String, McpPromptSpecification>();
    private final List<BiConsumer<McpServerSession, List<McpRoot>>> rootsChangeHandlers = new ArrayList<BiConsumer<McpServerSession, List<McpRoot>>>();

    /** Creates a server using {@link #DEFAULT_SSE_ENDPOINT} and {@link #DEFAULT_MESSAGE_ENDPOINT}. */
    public McpServer() {
        this(DEFAULT_SSE_ENDPOINT, DEFAULT_MESSAGE_ENDPOINT);
    }

    /**
     * Creates a server with custom endpoint paths.
     *
     * @param sseEndpoint     path of the SSE stream; blank falls back to {@link #DEFAULT_SSE_ENDPOINT}
     * @param messageEndpoint path of the message endpoint; blank falls back to
     *                        {@link #DEFAULT_MESSAGE_ENDPOINT}
     */
    public McpServer(String sseEndpoint, String messageEndpoint) {
        this.sseEndpoint = StrUtil.isBlank(sseEndpoint) ? DEFAULT_SSE_ENDPOINT : sseEndpoint;
        this.messageEndpoint = StrUtil.isBlank(messageEndpoint) ? DEFAULT_MESSAGE_ENDPOINT : messageEndpoint;
    }

    /**
     * Sets the server implementation info returned from {@code initialize}.
     *
     * @param serverInfo server name/version
     * @return this server, for chaining
     * @throws NullPointerException if {@code serverInfo} is null
     */
    public McpServer serverInfo(McpImplementation serverInfo) {
        Objects.requireNonNull(serverInfo, "Server info must not be null");
        this.serverInfo = serverInfo;
        return this;
    }

    /**
     * Sets the server implementation info returned from {@code initialize}.
     *
     * @param name    server name, must not be blank
     * @param version server version, must not be blank
     * @return this server, for chaining
     * @throws IllegalArgumentException if either argument is blank
     */
    public McpServer serverInfo(String name, String version) {
        if (StrUtil.isBlank(name)) {
            throw new IllegalArgumentException("Server info name must not be blank");
        }
        if (StrUtil.isBlank(version)) {
            throw new IllegalArgumentException("Server info version must not be blank");
        }
        this.serverInfo = new McpImplementation(name, version);
        return this;
    }

    /**
     * Sets the capabilities reported in the {@code initialize} result.
     *
     * @param serverCapabilities capabilities object
     * @return this server, for chaining
     * @throws NullPointerException if {@code serverCapabilities} is null
     */
    public McpServer capabilities(McpServerCapabilities serverCapabilities) {
        Objects.requireNonNull(serverCapabilities, "Server capabilities must not be null");
        this.serverCapabilities = serverCapabilities;
        return this;
    }

    /**
     * Registers a single tool and its call handler.
     *
     * @param tool    tool descriptor (its name is matched against {@code tools/call} requests)
     * @param handler invoked with the calling session and the call arguments
     * @return this server, for chaining
     * @throws NullPointerException if either argument is null
     */
    public McpServer tool(McpTool tool, BiFunction<McpServerSession, Map<String, Object>, McpCallToolResult> handler) {
        Objects.requireNonNull(tool, "Tool must not be null");
        Objects.requireNonNull(handler, "Handler must not be null");
        this.tools.add(new McpToolSpecification(tool, handler));
        return this;
    }

    /**
     * Registers several tool specifications.
     *
     * @param toolSpecifications specifications to append
     * @return this server, for chaining
     * @throws NullPointerException if the list is null
     */
    public McpServer tools(List<McpToolSpecification> toolSpecifications) {
        Objects.requireNonNull(toolSpecifications, "Tool handlers list must not be null");
        this.tools.addAll(toolSpecifications);
        return this;
    }

    /**
     * Registers several tool specifications.
     *
     * @param toolSpecifications specifications to append
     * @return this server, for chaining
     * @throws NullPointerException if the array is null
     */
    public McpServer tools(McpToolSpecification ... toolSpecifications) {
        Objects.requireNonNull(toolSpecifications, "Tool handlers list must not be null");
        this.tools.addAll(Arrays.asList(toolSpecifications));
        return this;
    }

    /**
     * Registers resources keyed by the caller-supplied map keys.
     *
     * @param resourceSpecifications resources to add (existing keys are overwritten)
     * @return this server, for chaining
     */
    public McpServer resources(Map<String, McpResourceSpecification> resourceSpecifications) {
        Objects.requireNonNull(resourceSpecifications, "Resource handlers map must not be null");
        this.resources.putAll(resourceSpecifications);
        return this;
    }

    /**
     * Registers resources keyed by their resource URI.
     *
     * @param resourceSpecifications resources to add (same URI overwrites)
     * @return this server, for chaining
     */
    public McpServer resources(List<McpResourceSpecification> resourceSpecifications) {
        Objects.requireNonNull(resourceSpecifications, "Resource handlers list must not be null");
        for (McpResourceSpecification resource : resourceSpecifications) {
            this.resources.put(resource.getResource().getUri(), resource);
        }
        return this;
    }

    /**
     * Registers resources keyed by their resource URI.
     *
     * @param resourceSpecifications resources to add (same URI overwrites)
     * @return this server, for chaining
     */
    public McpServer resources(McpResourceSpecification ... resourceSpecifications) {
        Objects.requireNonNull(resourceSpecifications, "Resource handlers list must not be null");
        for (McpResourceSpecification resource : resourceSpecifications) {
            this.resources.put(resource.getResource().getUri(), resource);
        }
        return this;
    }

    /**
     * Registers resource templates keyed by their URI template.
     *
     * @param resourceTemplates templates to add (same URI template overwrites)
     * @return this server, for chaining
     */
    public McpServer resourceTemplates(List<McpResourceTemplateSpecification> resourceTemplates) {
        Objects.requireNonNull(resourceTemplates, "Resource templates must not be null");
        for (McpResourceTemplateSpecification resource : resourceTemplates) {
            this.resourceTemplates.put(resource.getResource().getUriTemplate(), resource);
        }
        return this;
    }

    /**
     * Registers resource templates keyed by their URI template.
     *
     * @param resourceTemplates templates to add (same URI template overwrites)
     * @return this server, for chaining
     */
    public McpServer resourceTemplates(McpResourceTemplateSpecification ... resourceTemplates) {
        Objects.requireNonNull(resourceTemplates, "Resource templates must not be null");
        for (McpResourceTemplateSpecification resource : resourceTemplates) {
            this.resourceTemplates.put(resource.getResource().getUriTemplate(), resource);
        }
        return this;
    }

    /**
     * Registers prompts keyed by the caller-supplied map keys.
     *
     * @param prompts prompts to add (existing keys are overwritten)
     * @return this server, for chaining
     */
    public McpServer prompts(Map<String, McpPromptSpecification> prompts) {
        Objects.requireNonNull(prompts, "Prompts map must not be null");
        this.prompts.putAll(prompts);
        return this;
    }

    /**
     * Registers prompts keyed by their prompt name.
     *
     * @param prompts prompts to add (same name overwrites)
     * @return this server, for chaining
     */
    public McpServer prompts(List<McpPromptSpecification> prompts) {
        Objects.requireNonNull(prompts, "Prompts list must not be null");
        for (McpPromptSpecification prompt : prompts) {
            this.prompts.put(prompt.getPrompt().getName(), prompt);
        }
        return this;
    }

    /**
     * Registers prompts keyed by their prompt name.
     *
     * @param prompts prompts to add (same name overwrites)
     * @return this server, for chaining
     */
    public McpServer prompts(McpPromptSpecification ... prompts) {
        Objects.requireNonNull(prompts, "Prompts list must not be null");
        for (McpPromptSpecification prompt : prompts) {
            this.prompts.put(prompt.getPrompt().getName(), prompt);
        }
        return this;
    }

    /**
     * Registers a roots-change handler.
     *
     * @param handler handler to append
     * @return this server, for chaining
     */
    public McpServer rootsChangeHandler(BiConsumer<McpServerSession, List<McpRoot>> handler) {
        Objects.requireNonNull(handler, "Consumer must not be null");
        this.rootsChangeHandlers.add(handler);
        return this;
    }

    /**
     * Registers several roots-change handlers.
     *
     * @param handlers handlers to append
     * @return this server, for chaining
     */
    public McpServer rootsChangeHandlers(List<BiConsumer<McpServerSession, List<McpRoot>>> handlers) {
        Objects.requireNonNull(handlers, "Handlers list must not be null");
        this.rootsChangeHandlers.addAll(handlers);
        return this;
    }

    /**
     * Registers several roots-change handlers.
     *
     * @param handlers handlers to append
     * @return this server, for chaining
     */
    @SafeVarargs
    public final McpServer rootsChangeHandlers(BiConsumer<McpServerSession, List<McpRoot>> ... handlers) {
        Objects.requireNonNull(handlers, "Handlers list must not be null");
        return this.rootsChangeHandlers(Arrays.asList(handlers));
    }

    /**
     * Opens the SSE stream for a new client.
     *
     * <p>Once the response packet has been sent successfully, a session with a fresh nano-id is
     * registered and an {@value #ENDPOINT_EVENT_TYPE} event with the message URL is emitted.
     *
     * @param request the incoming HTTP request
     * @return the SSE response to return to t-io
     */
    public HttpResponse sseEndpoint(HttpRequest request) {
        HttpResponse httpResponse = new HttpResponse(request);
        SseEmitter emitter = SseEmitter.getEmitter(request, httpResponse);
        httpResponse.setPacketListener((context, packet, isSentSuccess) -> {
            // Only register the session after the SSE handshake actually reached the client.
            if (isSentSuccess) {
                String sessionId = StrUtil.getNanoId();
                this.sessions.put(sessionId, new McpServerSession(sessionId, emitter));
                emitter.send(ENDPOINT_EVENT_TYPE, this.messageEndpoint + "?sessionId=" + sessionId);
            }
        });
        return httpResponse;
    }

    /**
     * Handles a JSON-RPC message POSTed by a client to the message endpoint.
     *
     * <p>Requests are dispatched and the response is pushed over the session's SSE stream, not
     * returned in the HTTP body. Notifications are only logged. Other message kinds are ignored.
     *
     * @param request HTTP request whose {@code sessionId} parameter names an open session
     * @return an empty response on success, or 404 when the session id is missing or unknown
     * @throws IllegalArgumentException if the body is not a recognizable JSON-RPC message, or a
     *                                  {@code tools/call} names an unknown tool
     */
    public HttpResponse sseMessageEndpoint(HttpRequest request) {
        String sessionId = request.getParam("sessionId");
        HttpResponse response = new HttpResponse(request);
        if (StrUtil.isBlank(sessionId)) {
            response.setStatus(HttpResponseStatus.C404);
            response.setBody("Session ID missing in message endpoint".getBytes(StandardCharsets.UTF_8));
            return response;
        }
        McpServerSession session = this.sessions.get(sessionId);
        if (session == null) {
            response.setStatus(HttpResponseStatus.C404);
            response.setBody("Session is null".getBytes(StandardCharsets.UTF_8));
            log.error("Session is null sessionId:{}", sessionId);
            return response;
        }
        JsonRpcMessage jsonRpcMessage = deserializeJsonRpcMessage(request.getBody());
        if (jsonRpcMessage instanceof JsonRpcRequest) {
            JsonRpcRequest rpcRequest = (JsonRpcRequest) jsonRpcMessage;
            JsonRpcResponse rpcResponse = this.handleIncomingRequest(session, rpcRequest);
            if (rpcResponse != null) {
                session.sendMessage(rpcResponse);
            } else {
                // No handler for this method; avoid pushing a null payload down the SSE stream.
                log.warn("Unsupported JSON-RPC method:{} sessionId:{}", rpcRequest.getMethod(), sessionId);
            }
        } else if (jsonRpcMessage instanceof JsonRpcNotification) {
            log.info("JsonRpcNotification:{}", jsonRpcMessage);
        }
        return response;
    }

    /** Calls {@link McpServerSession#sendHeartbeat()} on every registered session. */
    public void sendHeartbeat() {
        for (McpServerSession session : this.sessions.values()) {
            session.sendHeartbeat();
        }
    }

    /** @return the configured message endpoint path */
    public String getMessageEndpoint() {
        return this.messageEndpoint;
    }

    /** @return the configured SSE endpoint path */
    public String getSseEndpoint() {
        return this.sseEndpoint;
    }

    /**
     * Dispatches a JSON-RPC request by method name.
     *
     * @return the response to send, or {@code null} when the method is not handled here
     */
    private JsonRpcResponse handleIncomingRequest(McpServerSession session, JsonRpcRequest request) {
        String method = request.getMethod();
        if (method == null) {
            // A string switch would NPE on null; treat it as an unsupported method.
            return null;
        }
        switch (method) {
            case "initialize":
                return newResponse(request, this.buildInitializeResult(request));
            case "ping":
                return newResponse(request, Collections.emptyMap());
            case "tools/list":
                return newResponse(request, this.buildListToolsResult());
            case "tools/call":
                return newResponse(request, this.callTool(session, request));
            default:
                return null;
        }
    }

    /** Creates a JSON-RPC 2.0 response that echoes the request id and carries {@code result}. */
    private static JsonRpcResponse newResponse(JsonRpcRequest request, Object result) {
        JsonRpcResponse jsonRpcResponse = new JsonRpcResponse();
        jsonRpcResponse.setJsonrpc("2.0");
        jsonRpcResponse.setId(request.getId());
        jsonRpcResponse.setResult(result);
        return jsonRpcResponse;
    }

    /** Builds the {@code initialize} result, echoing the client's requested protocol version. */
    private McpInitializeResult buildInitializeResult(JsonRpcRequest request) {
        McpInitializeRequest initializeRequest = (McpInitializeRequest) JsonUtil.convertValue((Object) request.getParams(), McpInitializeRequest.class);
        McpInitializeResult result = new McpInitializeResult();
        // Params may be absent; report no protocol version rather than failing with an NPE.
        result.setProtocolVersion(initializeRequest == null ? null : initializeRequest.getProtocolVersion());
        result.setCapabilities(this.serverCapabilities);
        result.setServerInfo(this.serverInfo);
        return result;
    }

    /** Lists the descriptors of all registered tools, in registration order. */
    private McpListToolsResult buildListToolsResult() {
        ArrayList<McpTool> toolList = new ArrayList<McpTool>(this.tools.size());
        for (McpToolSpecification toolSpecification : this.tools) {
            toolList.add(toolSpecification.getTool());
        }
        McpListToolsResult toolsResult = new McpListToolsResult();
        toolsResult.setTools(toolList);
        return toolsResult;
    }

    /**
     * Invokes the first registered tool whose name matches the request.
     *
     * @throws IllegalArgumentException if params are missing or no tool has the requested name
     * @throws IllegalStateException    if the matched tool handler returns {@code null}
     */
    private McpCallToolResult callTool(McpServerSession session, JsonRpcRequest request) {
        McpCallToolRequest callToolRequest = (McpCallToolRequest) JsonUtil.convertValue((Object) request.getParams(), McpCallToolRequest.class);
        if (callToolRequest == null) {
            throw new IllegalArgumentException("Missing params for tools/call");
        }
        String name = callToolRequest.getName();
        for (McpToolSpecification toolSpecification : this.tools) {
            if (!toolSpecification.getTool().getName().equals(name)) {
                continue;
            }
            Map<String, Object> toolArguments = getCallToolArguments(callToolRequest.getArguments());
            McpCallToolResult toolResult = toolSpecification.getCall().apply(session, toolArguments);
            // Distinguish a misbehaving handler from an unknown tool name.
            if (toolResult == null) {
                throw new IllegalStateException("Tool " + name + " returned a null result");
            }
            return toolResult;
        }
        throw new IllegalArgumentException("Cannot find tool with name " + name);
    }

    /**
     * Classifies a raw JSON body as a request, notification, or response.
     *
     * <p>Rules: {@code method} plus {@code id} means a request; {@code method} without {@code id}
     * means a notification; {@code result} or {@code error} means a response.
     *
     * @throws IllegalArgumentException if the body is empty, not a JSON object, or matches no rule
     */
    private static JsonRpcMessage deserializeJsonRpcMessage(byte[] requestBody) {
        if (requestBody == null || requestBody.length == 0) {
            throw new IllegalArgumentException("Cannot deserialize JsonRpcMessage: empty request body");
        }
        // JSON is UTF-8 (RFC 8259); do not depend on the platform default charset.
        String jsonText = new String(requestBody, StandardCharsets.UTF_8);
        log.debug("Received JSON message: {}", jsonText);
        Map<?, ?> map = (Map<?, ?>) JsonUtil.readValue(requestBody, Map.class);
        if (map == null) {
            throw new IllegalArgumentException("Cannot deserialize JsonRpcMessage: " + jsonText);
        }
        boolean hasMethod = map.containsKey("method");
        if (hasMethod && map.containsKey("id")) {
            return (JsonRpcMessage) JsonUtil.convertValue((Object) map, JsonRpcRequest.class);
        }
        if (hasMethod) {
            return (JsonRpcMessage) JsonUtil.convertValue((Object) map, JsonRpcNotification.class);
        }
        if (map.containsKey("result") || map.containsKey("error")) {
            return (JsonRpcMessage) JsonUtil.convertValue((Object) map, JsonRpcResponse.class);
        }
        throw new IllegalArgumentException("Cannot deserialize JsonRpcMessage: " + jsonText);
    }

    /**
     * Normalizes {@code tools/call} arguments to a map.
     *
     * @return {@code null} for null or blank-string arguments; the map itself if already a map;
     *         otherwise the result of converting the value to a map via {@link JsonUtil}
     */
    @SuppressWarnings("unchecked") // JSON objects deserialize to Map<String, Object>
    private static Map<String, Object> getCallToolArguments(Object arguments) {
        if (arguments == null) {
            return null;
        }
        if (arguments instanceof Map) {
            return (Map<String, Object>) arguments;
        }
        if (arguments instanceof String && StrUtil.isBlank((String) arguments)) {
            return null;
        }
        return (Map<String, Object>) JsonUtil.convertValue(arguments, Map.class);
    }
}

