/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.deployment;

import dev.langchain4j.agent.tool.JsonSchemaProperty;
import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolMemoryId;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import io.quarkiverse.langchain4j.deployment.AiServicesProcessor;
import io.quarkiverse.langchain4j.deployment.DotNames;
import io.quarkiverse.langchain4j.deployment.HashUtil;
import io.quarkiverse.langchain4j.deployment.JandexUtil;
import io.quarkiverse.langchain4j.deployment.ToolsMetadataBuildItem;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.prompt.Mappable;
import io.quarkiverse.langchain4j.runtime.tool.ToolInvoker;
import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.tool.ToolParametersObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.tool.ToolSpanWrapper;
import io.quarkiverse.langchain4j.runtime.tool.ToolSpecificationObjectSubstitution;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.ValidationPhaseBuildItem;
import io.quarkus.builder.item.BuildItem;
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.GeneratedClassGizmoAdaptor;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.BytecodeTransformerBuildItem;
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.deployment.builditem.GeneratedClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.deployment.recording.RecorderContext;
import io.quarkus.gizmo.ClassCreator;
import io.quarkus.gizmo.ClassOutput;
import io.quarkus.gizmo.ClassTransformer;
import io.quarkus.gizmo.FieldDescriptor;
import io.quarkus.gizmo.MethodCreator;
import io.quarkus.gizmo.MethodDescriptor;
import io.quarkus.gizmo.ResultHandle;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
import org.jboss.jandex.AnnotationValue;
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;
import org.jboss.jandex.IndexView;
import org.jboss.jandex.MethodInfo;
import org.jboss.jandex.MethodParameterInfo;
import org.jboss.jandex.Type;
import org.jboss.logging.Logger;
import org.objectweb.asm.ClassVisitor;

public class ToolProcessor {
    private static final Logger log = Logger.getLogger(AiServicesProcessor.class);
    private static final DotName TOOL = DotName.createSimple(Tool.class);
    private static final DotName TOOL_MEMORY_ID = DotName.createSimple(ToolMemoryId.class);
    private static final DotName P = DotName.createSimple(P.class);
    private static final MethodDescriptor METHOD_METADATA_CTOR = MethodDescriptor.ofConstructor(ToolInvoker.MethodMetadata.class, (Class[])new Class[]{Boolean.TYPE, Map.class, Integer.class});
    private static final MethodDescriptor HASHMAP_CTOR = MethodDescriptor.ofConstructor(HashMap.class, (Class[])new Class[0]);
    public static final MethodDescriptor MAP_PUT = MethodDescriptor.ofMethod(Map.class, (String)"put", Object.class, (Class[])new Class[]{Object.class, Object.class});

    @BuildStep
    public void telemetry(Capabilities capabilities, BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer) {
        boolean addOpenTelemetrySpan = capabilities.isPresent("io.quarkus.opentelemetry.tracer");
        if (addOpenTelemetrySpan) {
            additionalBeanProducer.produce((BuildItem)AdditionalBeanBuildItem.builder().addBeanClass(ToolSpanWrapper.class).build());
        }
    }

    @BuildStep
    @Record(value=ExecutionTime.STATIC_INIT)
    public void handleTools(CombinedIndexBuildItem indexBuildItem, ToolsRecorder recorder, RecorderContext recorderContext, BuildProducer<BytecodeTransformerBuildItem> transformerProducer, BuildProducer<GeneratedClassBuildItem> generatedClassProducer, BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer, BuildProducer<ValidationPhaseBuildItem.ValidationErrorBuildItem> validation, BuildProducer<ToolsMetadataBuildItem> toolsMetadataProducer) {
        recorderContext.registerSubstitution(ToolSpecification.class, ToolSpecificationObjectSubstitution.Serialized.class, ToolSpecificationObjectSubstitution.class);
        recorderContext.registerSubstitution(ToolParameters.class, ToolParametersObjectSubstitution.Serialized.class, ToolParametersObjectSubstitution.class);
        IndexView index = indexBuildItem.getIndex();
        Collection instances = index.getAnnotations(TOOL);
        HashMap<String, List<ToolMethodCreateInfo>> metadata = new HashMap<String, List<ToolMethodCreateInfo>>();
        ArrayList<String> generatedInvokerClasses = new ArrayList<String>();
        ArrayList<String> generatedArgumentMapperClasses = new ArrayList<String>();
        if (!instances.isEmpty()) {
            GeneratedClassGizmoAdaptor classOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true);
            HashMap<DotName, List> methodsPerClass = new HashMap<DotName, List>();
            for (AnnotationInstance instance : instances) {
                if (instance.target().kind() != AnnotationTarget.Kind.METHOD) continue;
                MethodInfo methodInfo = instance.target().asMethod();
                ClassInfo classInfo = methodInfo.declaringClass();
                if (classInfo.isInterface() || Modifier.isAbstract(classInfo.flags())) {
                    validation.produce((BuildItem)new ValidationPhaseBuildItem.ValidationErrorBuildItem(new Throwable[]{new IllegalStateException("@Tool is only supported on non-abstract classes, all other usages are ignored. Offending method is '" + methodInfo.declaringClass().name().toString() + "#" + methodInfo.name() + "'")}));
                    continue;
                }
                DotName declaringClassName = classInfo.name();
                methodsPerClass.computeIfAbsent(declaringClassName, n -> new ArrayList()).add(methodInfo);
            }
            boolean validationErrorFound = false;
            HashMap<String, ClassInfo> discoveredTools = new HashMap<String, ClassInfo>();
            for (Map.Entry entry : methodsPerClass.entrySet()) {
                DotName className = (DotName)entry.getKey();
                List toolMethods = (List)entry.getValue();
                ArrayList<MethodInfo> privateMethods = new ArrayList<MethodInfo>();
                for (MethodInfo toolMethod : toolMethods) {
                    if (discoveredTools.containsKey(toolMethod.name())) {
                        validation.produce((BuildItem)new ValidationPhaseBuildItem.ValidationErrorBuildItem(new Throwable[]{new IllegalStateException("A tool with the name '" + toolMethod.name() + "' from class '" + className + "' is already declared in class '" + discoveredTools.get(toolMethod.name()) + "'. Tools method name must be unique.")}));
                        validationErrorFound = true;
                        continue;
                    }
                    discoveredTools.put(toolMethod.name(), toolMethod.declaringClass());
                    if (!Modifier.isPrivate(toolMethod.flags())) continue;
                    privateMethods.add(toolMethod);
                }
                if (!privateMethods.isEmpty()) {
                    transformerProducer.produce((BuildItem)new BytecodeTransformerBuildItem(className.toString(), (BiFunction)new RemovePrivateFromMethodsVisitor(privateMethods)));
                }
                if (validationErrorFound) {
                    return;
                }
                for (MethodInfo toolMethod : toolMethods) {
                    AnnotationInstance instance = toolMethod.annotation(TOOL);
                    AnnotationValue nameValue = instance.value("name");
                    AnnotationValue descriptionValue = instance.value();
                    String toolName = ToolProcessor.getToolName(nameValue, toolMethod);
                    String toolDescription = this.getToolDescription(descriptionValue);
                    ToolSpecification.Builder builder = ToolSpecification.builder().name(toolName).description(toolDescription);
                    MethodParameterInfo memoryIdParameter = null;
                    for (MethodParameterInfo parameter : toolMethod.parameters()) {
                        if (parameter.hasAnnotation(TOOL_MEMORY_ID)) {
                            memoryIdParameter = parameter;
                            continue;
                        }
                        builder.addParameter(parameter.name(), this.toJsonSchemaProperties(parameter, index));
                    }
                    Map<String, Integer> nameToParamPosition = toolMethod.parameters().stream().collect(Collectors.toMap(MethodParameterInfo::name, i -> i.position()));
                    String methodSignature = ToolProcessor.createUniqueSignature(toolMethod);
                    String invokerClassName = ToolProcessor.generateInvoker(toolMethod, (ClassOutput)classOutput, nameToParamPosition, memoryIdParameter != null ? Short.valueOf(memoryIdParameter.position()) : null, methodSignature);
                    generatedInvokerClasses.add(invokerClassName);
                    String argumentMapperClassName = this.generateArgumentMapper(toolMethod, (ClassOutput)classOutput, methodSignature);
                    generatedArgumentMapperClasses.add(argumentMapperClassName);
                    ToolSpecification toolSpecification = builder.build();
                    ToolMethodCreateInfo methodCreateInfo = new ToolMethodCreateInfo(toolMethod.name(), invokerClassName, toolSpecification, argumentMapperClassName);
                    metadata.computeIfAbsent(className.toString(), c -> new ArrayList()).add(methodCreateInfo);
                }
            }
        }
        if (!generatedInvokerClasses.isEmpty()) {
            reflectiveClassProducer.produce((BuildItem)ReflectiveClassBuildItem.builder((String[])((String[])generatedInvokerClasses.toArray(String[]::new))).constructors(true).build());
        }
        if (!generatedArgumentMapperClasses.isEmpty()) {
            reflectiveClassProducer.produce((BuildItem)ReflectiveClassBuildItem.builder((String[])((String[])generatedArgumentMapperClasses.toArray(String[]::new))).fields(true).constructors(true).build());
        }
        toolsMetadataProducer.produce((BuildItem)new ToolsMetadataBuildItem(metadata));
        recorder.setMetadata(metadata);
    }

    private static String createUniqueSignature(MethodInfo toolMethod) {
        StringBuilder sigBuilder = new StringBuilder();
        sigBuilder.append(toolMethod.name()).append(toolMethod.returnType().name().toString());
        for (MethodParameterInfo t : toolMethod.parameters()) {
            sigBuilder.append(t.type().name().toString());
        }
        return sigBuilder.toString();
    }

    private static String getToolName(AnnotationValue nameValue, MethodInfo methodInfo) {
        if (nameValue == null) {
            return methodInfo.name();
        }
        String annotationValue = nameValue.asString();
        if (annotationValue.isEmpty()) {
            return methodInfo.name();
        }
        return annotationValue;
    }

    private String getToolDescription(AnnotationValue descriptionValue) {
        if (descriptionValue == null) {
            return "";
        }
        return String.join((CharSequence)"\n", descriptionValue.asStringArray());
    }

    private static String generateInvoker(MethodInfo methodInfo, ClassOutput classOutput, Map<String, Integer> nameToParamPosition, Short memoryIdParamPosition, String methodSignature) {
        String implClassName = methodInfo.declaringClass().name() + "$$QuarkusInvoker$" + methodInfo.name() + "_" + HashUtil.sha1(methodSignature);
        try (ClassCreator classCreator = ClassCreator.builder().classOutput(classOutput).className(implClassName).interfaces(new Class[]{ToolInvoker.class}).build();){
            boolean toolReturnsVoid;
            ResultHandle result;
            MethodCreator invokeMc = classCreator.getMethodCreator(MethodDescriptor.ofMethod((Object)implClassName, (String)"invoke", Object.class, (Object[])new Object[]{Object.class, Object[].class}));
            if (methodInfo.parametersCount() > 0) {
                ArrayList<ResultHandle> argumentHandles = new ArrayList<ResultHandle>(methodInfo.parametersCount());
                for (int i = 0; i < methodInfo.parametersCount(); ++i) {
                    argumentHandles.add(invokeMc.readArrayValue(invokeMc.getMethodParam(1), i));
                }
                ResultHandle[] targetMethodHandles = argumentHandles.toArray(new ResultHandle[0]);
                result = invokeMc.invokeVirtualMethod(MethodDescriptor.of((MethodInfo)methodInfo), invokeMc.getMethodParam(0), targetMethodHandles);
            } else {
                result = invokeMc.invokeVirtualMethod(MethodDescriptor.of((MethodInfo)methodInfo), invokeMc.getMethodParam(0), new ResultHandle[0]);
            }
            boolean bl = toolReturnsVoid = methodInfo.returnType().kind() == Type.Kind.VOID;
            if (toolReturnsVoid) {
                invokeMc.returnValue(invokeMc.load("Success"));
            } else {
                invokeMc.returnValue(result);
            }
            MethodCreator methodMetadataMc = classCreator.getMethodCreator(MethodDescriptor.ofMethod((Object)implClassName, (String)"methodMetadata", ToolInvoker.MethodMetadata.class, (Object[])new Object[0]));
            ResultHandle nameToParamPositionHandle = methodMetadataMc.newInstance(HASHMAP_CTOR, new ResultHandle[0]);
            for (Map.Entry<String, Integer> entry : nameToParamPosition.entrySet()) {
                methodMetadataMc.invokeInterfaceMethod(MAP_PUT, nameToParamPositionHandle, new ResultHandle[]{methodMetadataMc.load(entry.getKey()), methodMetadataMc.load(entry.getValue().intValue())});
            }
            ResultHandle resultHandle = methodMetadataMc.newInstance(METHOD_METADATA_CTOR, new ResultHandle[]{methodMetadataMc.load(toolReturnsVoid), nameToParamPositionHandle, memoryIdParamPosition != null ? methodMetadataMc.load(Integer.valueOf(memoryIdParamPosition.shortValue()).intValue()) : methodMetadataMc.loadNull()});
            methodMetadataMc.returnValue(resultHandle);
        }
        return implClassName;
    }

    private String generateArgumentMapper(MethodInfo methodInfo, ClassOutput classOutput, String methodSignature) {
        String implClassName = methodInfo.declaringClass().name() + "$$QuarkusToolArgumentMapper$" + methodInfo.name() + "_" + HashUtil.sha1(methodSignature);
        try (ClassCreator classCreator = ClassCreator.builder().classOutput(classOutput).className(implClassName).interfaces(new Class[]{Mappable.class}).build();){
            ArrayList<FieldDescriptor> fieldDescriptors = new ArrayList<FieldDescriptor>();
            for (MethodParameterInfo parameter : methodInfo.parameters()) {
                FieldDescriptor fieldDescriptor = FieldDescriptor.of((String)implClassName, (String)parameter.name(), (String)parameter.type().name().toString());
                fieldDescriptors.add(fieldDescriptor);
                classCreator.getFieldCreator(fieldDescriptor).setModifiers(1);
            }
            MethodCreator mc = classCreator.getMethodCreator(MethodDescriptor.ofMethod((Object)implClassName, (String)"obtainFieldValuesMap", Map.class, (Object[])new Object[0]));
            ResultHandle mapHandle = mc.newInstance(MethodDescriptor.ofConstructor(HashMap.class, (Class[])new Class[0]), new ResultHandle[0]);
            for (FieldDescriptor field : fieldDescriptors) {
                ResultHandle fieldValue = mc.readInstanceField(field, mc.getThis());
                mc.invokeInterfaceMethod(MAP_PUT, mapHandle, new ResultHandle[]{mc.load(field.getName()), fieldValue});
            }
            mc.returnValue(mapHandle);
        }
        return implClassName;
    }

    private Iterable<JsonSchemaProperty> toJsonSchemaProperties(MethodParameterInfo parameter, IndexView index) {
        JsonSchemaProperty description;
        Type type = parameter.type();
        DotName typeName = parameter.type().name();
        AnnotationInstance pInstance = parameter.annotation(P);
        JsonSchemaProperty jsonSchemaProperty = description = pInstance == null ? null : JsonSchemaProperty.description((String)pInstance.value().asString());
        if (DotNames.STRING.equals((Object)typeName) || DotNames.CHARACTER.equals((Object)typeName) || DotNames.PRIMITIVE_CHAR.equals((Object)typeName)) {
            return this.removeNulls(JsonSchemaProperty.STRING, description);
        }
        if (DotNames.BOOLEAN.equals((Object)typeName) || DotNames.PRIMITIVE_BOOLEAN.equals((Object)typeName)) {
            return this.removeNulls(JsonSchemaProperty.BOOLEAN, description);
        }
        if (DotNames.BYTE.equals((Object)typeName) || DotNames.PRIMITIVE_BYTE.equals((Object)typeName) || DotNames.SHORT.equals((Object)typeName) || DotNames.PRIMITIVE_SHORT.equals((Object)typeName) || DotNames.INTEGER.equals((Object)typeName) || DotNames.PRIMITIVE_INT.equals((Object)typeName) || DotNames.LONG.equals((Object)typeName) || DotNames.PRIMITIVE_LONG.equals((Object)typeName) || DotNames.BIG_INTEGER.equals((Object)typeName)) {
            return this.removeNulls(JsonSchemaProperty.INTEGER, description);
        }
        if (DotNames.FLOAT.equals((Object)typeName) || DotNames.PRIMITIVE_FLOAT.equals((Object)typeName) || DotNames.DOUBLE.equals((Object)typeName) || DotNames.PRIMITIVE_DOUBLE.equals((Object)typeName) || DotNames.BIG_DECIMAL.equals((Object)typeName)) {
            return this.removeNulls(JsonSchemaProperty.NUMBER, description);
        }
        if (type.kind() == Type.Kind.ARRAY || DotNames.LIST.equals((Object)typeName) || DotNames.SET.equals((Object)typeName)) {
            return this.removeNulls(JsonSchemaProperty.ARRAY, description);
        }
        if (this.isEnum(type, index)) {
            return this.removeNulls(JsonSchemaProperty.STRING, JsonSchemaProperty.enums((Object[])ToolProcessor.enumConstants(type)), description);
        }
        return this.removeNulls(JsonSchemaProperty.OBJECT, description);
    }

    private Iterable<JsonSchemaProperty> removeNulls(JsonSchemaProperty ... properties) {
        return Arrays.stream(properties).filter(Objects::nonNull).collect(Collectors.toList());
    }

    private boolean isEnum(Type returnType, IndexView index) {
        if (returnType.kind() != Type.Kind.CLASS) {
            return false;
        }
        ClassInfo maybeEnum = index.getClassByName(returnType.name());
        return maybeEnum != null && maybeEnum.isEnum();
    }

    private static Object[] enumConstants(Type type) {
        return JandexUtil.load(type, Thread.currentThread().getContextClassLoader()).getEnumConstants();
    }

    private static class RemovePrivateFromMethodsVisitor
    implements BiFunction<String, ClassVisitor, ClassVisitor> {
        private final List<MethodInfo> privateMethods;

        private RemovePrivateFromMethodsVisitor(List<MethodInfo> privateMethods) {
            this.privateMethods = privateMethods;
        }

        @Override
        public ClassVisitor apply(String className, ClassVisitor classVisitor) {
            ClassTransformer transformer = new ClassTransformer(className);
            for (MethodInfo method : this.privateMethods) {
                transformer.modifyMethod(MethodDescriptor.of((MethodInfo)method)).removeModifiers(2);
            }
            return transformer.applyTo(classVisitor);
        }
    }
}

