/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.runtime.aiservice;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.input.structured.StructuredPrompt;
import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
import dev.langchain4j.model.moderation.Moderation;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServiceTokenStream;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.ServiceOutputParser;
import dev.langchain4j.service.TokenStream;
import io.quarkiverse.langchain4j.audit.Audit;
import io.quarkiverse.langchain4j.audit.AuditService;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.ManagedContext;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.infrastructure.Infrastructure;
import io.smallrye.mutiny.subscription.MultiEmitter;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.Consumer;
import java.util.function.Function;
import org.jboss.logging.Logger;

public class AiServiceMethodImplementationSupport {
    private static final Logger log = Logger.getLogger(AiServiceMethodImplementationSupport.class);
    private static final int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 10;

    public Object implement(Input input) {
        QuarkusAiServiceContext context = input.context;
        AiServiceMethodCreateInfo createInfo = input.createInfo;
        Object[] methodArgs = input.methodArgs;
        AuditService auditService = context.auditService;
        Audit audit = null;
        if (auditService != null) {
            audit = auditService.create(new Audit.CreateInfo(createInfo.getInterfaceName(), createInfo.getMethodName(), methodArgs, createInfo.getMemoryIdParamPosition()));
        }
        try {
            Object result = AiServiceMethodImplementationSupport.doImplement(createInfo, methodArgs, context, audit);
            if (audit != null) {
                audit.onCompletion(result);
                auditService.complete(audit);
            }
            return result;
        }
        catch (Exception e) {
            log.errorv((Throwable)e, "Execution of {0}#{1} failed", (Object)createInfo.getInterfaceName(), (Object)createInfo.getMethodName());
            if (audit != null) {
                audit.onFailure(e);
                auditService.complete(audit);
            }
            throw e;
        }
    }

    private static Object doImplement(AiServiceMethodCreateInfo createInfo, Object[] methodArgs, final QuarkusAiServiceContext context, Audit audit) {
        List<ChatMessage> messages;
        Optional<SystemMessage> systemMessage = AiServiceMethodImplementationSupport.prepareSystemMessage(createInfo, methodArgs);
        UserMessage userMessage = AiServiceMethodImplementationSupport.prepareUserMessage(context, createInfo, methodArgs);
        if (audit != null) {
            audit.initialMessages(systemMessage, userMessage);
        }
        final Object memoryId = AiServiceMethodImplementationSupport.memoryId(createInfo, methodArgs, context.chatMemoryProvider != null);
        if (context.retrievalAugmentor != null) {
            List chatMemory = context.hasChatMemory() ? context.chatMemory(memoryId).messages() : null;
            Metadata metadata = Metadata.from((UserMessage)userMessage, (Object)memoryId, (List)chatMemory);
            userMessage = context.retrievalAugmentor.augment(userMessage, metadata);
        }
        String outputFormatInstructions = createInfo.getUserMessageInfo().getOutputFormatInstructions();
        userMessage = UserMessage.from((String)(userMessage.text() + outputFormatInstructions));
        if (context.hasChatMemory()) {
            ChatMemory chatMemory = context.chatMemory(memoryId);
            if (systemMessage.isPresent()) {
                chatMemory.add((ChatMessage)systemMessage.get());
            }
            chatMemory.add((ChatMessage)userMessage);
        }
        if (context.hasChatMemory()) {
            messages = context.chatMemory(memoryId).messages();
        } else {
            messages = new ArrayList();
            systemMessage.ifPresent(messages::add);
            messages.add((ChatMessage)userMessage);
        }
        Class<?> returnType = createInfo.getReturnType();
        if (returnType.equals(TokenStream.class)) {
            return new AiServiceTokenStream(messages, (AiServiceContext)context, memoryId);
        }
        if (returnType.equals(Multi.class)) {
            return Multi.createFrom().emitter((Consumer)new Consumer<MultiEmitter<? super String>>(){

                @Override
                public void accept(final MultiEmitter<? super String> em) {
                    new AiServiceTokenStream(messages, (AiServiceContext)context, memoryId).onNext(arg_0 -> em.emit(arg_0)).onComplete((Consumer)new Consumer<Response<AiMessage>>(){

                        @Override
                        public void accept(Response<AiMessage> message) {
                            em.complete();
                        }
                    }).onError(arg_0 -> em.fail(arg_0)).start();
                }
            });
        }
        Future<Moderation> moderationFuture = AiServiceMethodImplementationSupport.triggerModerationIfNeeded(context, createInfo, messages);
        log.debug((Object)"Attempting to obtain AI response");
        Response response = context.toolSpecifications == null ? context.chatModel.generate(messages) : context.chatModel.generate(messages, context.toolSpecifications);
        log.debug((Object)"AI response obtained");
        if (audit != null) {
            audit.addLLMToApplicationMessage((Response<AiMessage>)response);
        }
        TokenUsage tokenUsageAccumulator = response.tokenUsage();
        AiServices.verifyModerationIfNeeded(moderationFuture);
        int executionsLeft = 10;
        while (true) {
            if (executionsLeft-- == 0) {
                throw Exceptions.runtime((String)"Something is wrong, exceeded %s sequential tool executions", (Object[])new Object[]{10});
            }
            AiMessage aiMessage = (AiMessage)response.content();
            if (context.hasChatMemory()) {
                context.chatMemory(memoryId).add((ChatMessage)response.content());
            }
            if (!aiMessage.hasToolExecutionRequests()) break;
            ChatMemory chatMemory = context.chatMemory(memoryId);
            for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
                log.debugv("Attempting to execute tool {0}", (Object)toolExecutionRequest);
                ToolExecutor toolExecutor = (ToolExecutor)context.toolExecutors.get(toolExecutionRequest.name());
                if (toolExecutor == null) {
                    throw Exceptions.runtime((String)"Tool executor %s not found", (Object[])new Object[]{toolExecutionRequest.name()});
                }
                String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, memoryId);
                log.debugv("Result of {0} is '{1}'", (Object)toolExecutionRequest, (Object)toolExecutionResult);
                ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from((ToolExecutionRequest)toolExecutionRequest, (String)toolExecutionResult);
                if (audit != null) {
                    audit.addApplicationToLLMMessage(toolExecutionResultMessage);
                }
                chatMemory.add((ChatMessage)toolExecutionResultMessage);
            }
            log.debug((Object)"Attempting to obtain AI response");
            response = context.chatModel.generate(chatMemory.messages(), context.toolSpecifications);
            log.debug((Object)"AI response obtained");
            if (audit != null) {
                audit.addLLMToApplicationMessage((Response<AiMessage>)response);
            }
            tokenUsageAccumulator = tokenUsageAccumulator.add(response.tokenUsage());
        }
        response = Response.from((Object)((AiMessage)response.content()), (TokenUsage)tokenUsageAccumulator, (FinishReason)response.finishReason());
        return ServiceOutputParser.parse((Response)response, returnType);
    }

    private static Future<Moderation> triggerModerationIfNeeded(final AiServiceContext context, AiServiceMethodCreateInfo createInfo, final List<ChatMessage> messages) {
        Future<Moderation> moderationFuture = null;
        if (createInfo.isRequiresModeration()) {
            log.debug((Object)"Moderation is required and it will be executed in the background");
            ExecutorService defaultExecutor = (ExecutorService)Infrastructure.getDefaultExecutor();
            moderationFuture = defaultExecutor.submit(new Callable<Moderation>(){

                @Override
                public Moderation call() {
                    List messagesToModerate = AiServices.removeToolMessages((List)messages);
                    log.debug((Object)"Attempting to moderate messages");
                    Moderation result = (Moderation)context.moderationModel.moderate(messagesToModerate).content();
                    log.debug((Object)"Moderation completed");
                    return result;
                }
            });
        }
        return moderationFuture;
    }

    private static Optional<SystemMessage> prepareSystemMessage(AiServiceMethodCreateInfo createInfo, Object[] methodArgs) {
        if (createInfo.getSystemMessageInfo().isEmpty()) {
            return Optional.empty();
        }
        AiServiceMethodCreateInfo.TemplateInfo systemMessageInfo = createInfo.getSystemMessageInfo().get();
        HashMap<String, Object> templateParams = new HashMap<String, Object>();
        Map<String, Integer> nameToParamPosition = systemMessageInfo.getNameToParamPosition();
        for (Map.Entry<String, Integer> entry : nameToParamPosition.entrySet()) {
            templateParams.put(entry.getKey(), methodArgs[entry.getValue()]);
        }
        Prompt prompt = PromptTemplate.from((String)systemMessageInfo.getText()).apply(templateParams);
        return Optional.of(prompt.toSystemMessage());
    }

    private static UserMessage prepareUserMessage(AiServiceContext context, AiServiceMethodCreateInfo createInfo, Object[] methodArgs) {
        AiServiceMethodCreateInfo.UserMessageInfo userMessageInfo = createInfo.getUserMessageInfo();
        String userName = null;
        if (userMessageInfo.getUserNameParamPosition().isPresent()) {
            userName = methodArgs[userMessageInfo.getUserNameParamPosition().get()].toString();
        }
        if (userMessageInfo.getTemplate().isPresent()) {
            AiServiceMethodCreateInfo.TemplateInfo templateInfo = userMessageInfo.getTemplate().get();
            HashMap<String, Object> templateParams = new HashMap<String, Object>();
            Map<String, Integer> nameToParamPosition = templateInfo.getNameToParamPosition();
            for (Map.Entry<String, Integer> entry : nameToParamPosition.entrySet()) {
                Object value = AiServiceMethodImplementationSupport.transformTemplateParamValue(methodArgs[entry.getValue()]);
                templateParams.put(entry.getKey(), value);
            }
            Prompt prompt = PromptTemplate.from((String)templateInfo.getText()).apply(templateParams);
            return AiServiceMethodImplementationSupport.createUserMessage(userName, prompt.text());
        }
        if (userMessageInfo.getParamPosition().isPresent()) {
            Integer paramIndex = userMessageInfo.getParamPosition().get();
            Object argValue = methodArgs[paramIndex];
            if (argValue == null) {
                throw new IllegalArgumentException("Unable to construct UserMessage for class '" + context.aiServiceClass.getName() + "' because parameter with index " + paramIndex + " is null");
            }
            return AiServiceMethodImplementationSupport.createUserMessage(userName, AiServiceMethodImplementationSupport.toString(argValue));
        }
        throw new IllegalStateException("Unable to construct UserMessage for class '" + context.aiServiceClass.getName() + "'. Please contact the maintainers");
    }

    private static UserMessage createUserMessage(String name, String text) {
        if (name == null) {
            return UserMessage.userMessage((String)text);
        }
        return UserMessage.userMessage((String)name, (String)text);
    }

    private static Object transformTemplateParamValue(Object value) {
        if (value.getClass().isArray()) {
            return Arrays.toString((Object[])value);
        }
        return value;
    }

    private static Object memoryId(AiServiceMethodCreateInfo createInfo, Object[] methodArgs, boolean hasChatMemoryProvider) {
        ManagedContext requestContext;
        ArcContainer container;
        if (createInfo.getMemoryIdParamPosition().isPresent()) {
            return methodArgs[createInfo.getMemoryIdParamPosition().get()];
        }
        if (hasChatMemoryProvider && (container = Arc.container()) != null && (requestContext = container.requestContext()).isActive()) {
            return requestContext.getState();
        }
        return "default";
    }

    private static String toString(Object arg) {
        if (arg.getClass().isArray()) {
            return AiServiceMethodImplementationSupport.arrayToString(arg);
        }
        if (arg.getClass().isAnnotationPresent(StructuredPrompt.class)) {
            return StructuredPromptProcessor.toPrompt((Object)arg).text();
        }
        return arg.toString();
    }

    private static String arrayToString(Object arg) {
        StringBuilder sb = new StringBuilder("[");
        int length = Array.getLength(arg);
        for (int i = 0; i < length; ++i) {
            sb.append(AiServiceMethodImplementationSupport.toString(Array.get(arg, i)));
            if (i >= length - 1) continue;
            sb.append(", ");
        }
        sb.append("]");
        return sb.toString();
    }

    public static class Input {
        final QuarkusAiServiceContext context;
        final AiServiceMethodCreateInfo createInfo;
        final Object[] methodArgs;

        public Input(QuarkusAiServiceContext context, AiServiceMethodCreateInfo createInfo, Object[] methodArgs) {
            this.context = context;
            this.createInfo = createInfo;
            this.methodArgs = methodArgs;
        }
    }

    public static interface Wrapper {
        public Object wrap(Input var1, Function<Input, Object> var2);
    }
}

