/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.runtime.devui;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.query.Metadata;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.devui.json.ChatMessagePojo;
import io.quarkiverse.langchain4j.runtime.devui.json.ChatResultPojo;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutorFactory;
import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
import io.quarkus.arc.All;
import io.quarkus.arc.Arc;
import io.quarkus.logging.Log;
import jakarta.enterprise.context.control.ActivateRequestContext;
import java.lang.annotation.Annotation;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

@ActivateRequestContext
public class ChatJsonRPCService {
    private final ChatLanguageModel model;
    private final ChatMemoryProvider memoryProvider;
    private RetrievalAugmentor retrievalAugmentor;
    private final List<ToolSpecification> toolSpecifications;
    private final Map<String, ToolExecutor> toolExecutors;
    private final AtomicReference<ChatMemory> currentMemory = new AtomicReference();
    private final AtomicLong currentMemoryId = new AtomicLong();

    public ChatJsonRPCService(@All List<ChatLanguageModel> models, @All List<Supplier<RetrievalAugmentor>> retrievalAugmentorSuppliers, @All List<RetrievalAugmentor> retrievalAugmentors, ChatMemoryProvider memoryProvider, QuarkusToolExecutorFactory toolExecutorFactory) {
        this.model = models.get(0);
        this.retrievalAugmentor = null;
        for (Supplier<RetrievalAugmentor> supplier : retrievalAugmentorSuppliers) {
            this.retrievalAugmentor = supplier.get();
            if (this.retrievalAugmentor == null) continue;
            break;
        }
        if (this.retrievalAugmentor == null) {
            Iterator<Supplier<RetrievalAugmentor>> iterator = retrievalAugmentors.iterator();
            while (iterator.hasNext()) {
                RetrievalAugmentor augmentorFromCdi;
                this.retrievalAugmentor = augmentorFromCdi = (RetrievalAugmentor)iterator.next();
                if (this.retrievalAugmentor == null) continue;
                break;
            }
        }
        this.memoryProvider = memoryProvider;
        Map<String, List<ToolMethodCreateInfo>> toolsMetadata = ToolsRecorder.getMetadata();
        if (toolsMetadata != null) {
            this.toolExecutors = new HashMap<String, ToolExecutor>();
            this.toolSpecifications = new ArrayList<ToolSpecification>();
            for (Map.Entry<String, List<ToolMethodCreateInfo>> entry : toolsMetadata.entrySet()) {
                for (ToolMethodCreateInfo methodCreateInfo : entry.getValue()) {
                    Object objectWithTool = null;
                    try {
                        objectWithTool = Arc.container().select(Thread.currentThread().getContextClassLoader().loadClass(entry.getKey()), new Annotation[0]).get();
                    }
                    catch (ClassNotFoundException e) {
                        throw new RuntimeException(e);
                    }
                    QuarkusToolExecutor.Context executorContext = new QuarkusToolExecutor.Context(objectWithTool, methodCreateInfo.getInvokerClassName(), methodCreateInfo.getMethodName(), methodCreateInfo.getArgumentMapperClassName());
                    this.toolExecutors.put(methodCreateInfo.getToolSpecification().name(), toolExecutorFactory.create(executorContext));
                    this.toolSpecifications.add(methodCreateInfo.getToolSpecification());
                }
            }
        } else {
            this.toolSpecifications = List.of();
            this.toolExecutors = Map.of();
        }
    }

    public String reset(String systemMessage) {
        if (this.currentMemory.get() != null) {
            this.currentMemory.get().clear();
        }
        long memoryId = ThreadLocalRandom.current().nextLong();
        this.currentMemoryId.set(memoryId);
        ChatMemory memory = this.memoryProvider.get((Object)memoryId);
        this.currentMemory.set(memory);
        if (systemMessage != null && !systemMessage.isEmpty()) {
            memory.add((ChatMessage)new SystemMessage(systemMessage));
        }
        return "OK";
    }

    public ChatResultPojo chat(String message, boolean ragEnabled) {
        ChatMemory memory = this.currentMemory.get();
        if (memory == null) {
            this.reset("");
            memory = this.currentMemory.get();
        }
        List chatMemoryBackup = memory.messages();
        try {
            if (this.retrievalAugmentor != null && ragEnabled) {
                UserMessage userMessage = UserMessage.from((String)message);
                Metadata metadata = Metadata.from((UserMessage)userMessage, (Object)this.currentMemoryId.get(), (List)memory.messages());
                memory.add((ChatMessage)this.retrievalAugmentor.augment(userMessage, metadata));
            } else {
                memory.add((ChatMessage)new UserMessage(message));
            }
            if (this.toolSpecifications.isEmpty()) {
                modelResponse = this.model.generate(memory.messages());
                memory.add((ChatMessage)modelResponse.content());
            } else {
                modelResponse = this.executeWithTools(memory);
            }
            List<ChatMessagePojo> response = ChatMessagePojo.listFromMemory(memory);
            return new ChatResultPojo(response, null);
        }
        catch (Throwable t) {
            memory.clear();
            chatMemoryBackup.forEach(arg_0 -> ((ChatMemory)memory).add(arg_0));
            Log.warn((Object)t);
            return new ChatResultPojo(null, t.getMessage());
        }
    }

    public Response<AiMessage> executeWithTools(ChatMemory memory) {
        int MAX_SEQUENTIAL_TOOL_EXECUTIONS;
        Response response = this.model.generate(memory.messages(), this.toolSpecifications);
        int executionsLeft = MAX_SEQUENTIAL_TOOL_EXECUTIONS = 20;
        while (true) {
            if (executionsLeft-- == 0) {
                throw new RuntimeException("Something is wrong, exceeded " + MAX_SEQUENTIAL_TOOL_EXECUTIONS + " sequential tool executions");
            }
            AiMessage aiMessage = (AiMessage)response.content();
            memory.add((ChatMessage)aiMessage);
            if (!aiMessage.hasToolExecutionRequests()) break;
            for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
                ToolExecutor toolExecutor = this.toolExecutors.get(toolExecutionRequest.name());
                String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, (Object)this.currentMemoryId.get());
                ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from((ToolExecutionRequest)toolExecutionRequest, (String)toolExecutionResult);
                memory.add((ChatMessage)toolExecutionResultMessage);
            }
            response = this.model.generate(memory.messages(), this.toolSpecifications);
        }
        return Response.from((Object)((AiMessage)response.content()), (TokenUsage)new TokenUsage(), (FinishReason)response.finishReason());
    }
}

