package io.trino.operator.scalar;

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Primitives;
import io.airlift.bytecode.Access;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.ParameterizedType;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.ForLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.airlift.bytecode.instruction.VariableInstruction;
import io.trino.metadata.SqlScalarFunction;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.block.ArrayValueBuilder;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BufferedArrayValueBuilder;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.Signature;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.TypeSignatureParameter;
import io.trino.sql.gen.CallSiteBinder;
import io.trino.sql.gen.LambdaMetafactoryGenerator;
import io.trino.sql.gen.SqlTypeBytecodeExpression;
import io.trino.sql.gen.lambda.UnaryFunctionInterface;
import io.trino.type.FunctionType;
import io.trino.type.UnknownType;
import io.trino.util.CompilerUtils;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.Optional;

/* loaded from: input_file:io/trino/operator/scalar/ArrayTransformFunction.class */
/**
 * Scalar function {@code transform(array(T), function(T, U)) -> array(U)}: applies a lambda to
 * each element of an array and returns the array of the lambda's results.
 *
 * <p>Every call to {@code specialize} compiles a small class at runtime whose static
 * {@code transform} method walks the input block, passes each element (boxed, or {@code null})
 * to the lambda and appends the result to the output array.
 */
public final class ArrayTransformFunction extends SqlScalarFunction {
    /** Handle to {@code BufferedArrayValueBuilder.createBuffered(ArrayType)}; bound to the return type per specialization. */
    private static final MethodHandle CREATE_STATE;

    /** The singleton instance to register with the function catalog. */
    public static final ArrayTransformFunction ARRAY_TRANSFORM_FUNCTION;

    /** Declares the {@code transform} signature; use {@link #ARRAY_TRANSFORM_FUNCTION} instead. */
    private ArrayTransformFunction() {
        super(FunctionMetadata.scalarBuilder()
                .signature(Signature.builder()
                        .name("transform")
                        .typeVariable("T")
                        .typeVariable("U")
                        .returnType(TypeSignature.arrayType(new TypeSignature("U")))
                        .argumentType(TypeSignature.arrayType(new TypeSignature("T")))
                        .argumentType(TypeSignature.functionType(new TypeSignature("T"), new TypeSignature("U")))
                        .build())
                // NOTE(review): presumably marked nondeterministic because the supplied lambda may be
                .nondeterministic()
                .description("Apply lambda to each element of the array")
                .build());
    }

    /**
     * Binds the function to concrete element types.
     *
     * <p>The array argument is never null ({@code NEVER_NULL}), the lambda is passed as a
     * {@link UnaryFunctionInterface}, and the result is never null ({@code FAIL_ON_NULL}).
     * Each invocation gets a {@link BufferedArrayValueBuilder} created for the bound return type.
     *
     * @param boundSignature signature whose first argument and return type are array types
     * @return the specialized implementation backed by generated bytecode
     */
    @Override // io.trino.metadata.SqlScalarFunction
    protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) {
        Type inputElementType = ((ArrayType) boundSignature.getArgumentTypes().get(0)).getElementType();
        ArrayType returnType = (ArrayType) boundSignature.getReturnType();
        return new ChoicesSpecializedSqlScalarFunction(
                boundSignature,
                InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                ImmutableList.of(
                        InvocationConvention.InvocationArgumentConvention.NEVER_NULL,
                        InvocationConvention.InvocationArgumentConvention.FUNCTION),
                ImmutableList.of(UnaryFunctionInterface.class),
                generateTransform(inputElementType, returnType.getElementType()),
                Optional.of(CREATE_STATE.bindTo(returnType)));
    }

    /**
     * Generates and loads a class exposing
     * {@code static Block transform(BufferedArrayValueBuilder, Block, UnaryFunctionInterface)}.
     *
     * @param inputType element type of the input array
     * @param outputType element type of the result array
     * @return a handle to the generated static {@code transform} method
     * @throws TrinoException with {@code GENERIC_INTERNAL_ERROR} if the generated method cannot be looked up
     */
    private static MethodHandle generateTransform(Type inputType, Type outputType) {
        CallSiteBinder binder = new CallSiteBinder();
        ClassDefinition definition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName("ArrayTransform"),
                ParameterizedType.type(Object.class));
        definition.declareDefaultConstructor(Access.a(Access.PRIVATE));

        MethodDefinition transformValue = generateTransformValueInner(definition, binder, inputType, outputType);

        Parameter valueBuilder = Parameter.arg("arrayValueBuilder", BufferedArrayValueBuilder.class);
        Parameter block = Parameter.arg("block", Block.class);
        Parameter function = Parameter.arg(FunctionType.NAME, UnaryFunctionInterface.class);
        MethodDefinition transform = definition.declareMethod(
                Access.a(Access.PUBLIC, Access.STATIC),
                "transform",
                ParameterizedType.type(Block.class),
                ImmutableList.of(valueBuilder, block, function));

        // Wrap transformValue as an ArrayValueBuilder with block and function captured; the remaining
        // elementBuilder argument is presumably supplied by BufferedArrayValueBuilder.build.
        BytecodeExpression elementWriter = LambdaMetafactoryGenerator.generateMetafactory(
                ArrayValueBuilder.class,
                transformValue,
                ImmutableList.of(block, function));
        transform.getBody().append(valueBuilder
                .invoke("build", Block.class, block.invoke("getPositionCount", int.class), elementWriter)
                .ret());

        try {
            Class<?> generatedClass = CompilerUtils.defineClass(
                    definition,
                    Object.class,
                    binder.getBindings(),
                    ArrayTransformFunction.class.getClassLoader());
            return MethodHandles.lookup().findStatic(
                    generatedClass,
                    "transform",
                    MethodType.methodType(Block.class, BufferedArrayValueBuilder.class, Block.class, UnaryFunctionInterface.class));
        } catch (ReflectiveOperationException e) {
            throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, e);
        }
    }

    /**
     * Declares {@code private static void transformValue(Block block, UnaryFunctionInterface function,
     * BlockBuilder elementBuilder)} on {@code definition}. For every position of {@code block} it
     * loads the element (or null), applies {@code function}, and appends the result (or null) to
     * {@code elementBuilder}.
     *
     * @param definition class receiving the new method
     * @param binder binder used to embed the SQL {@link Type} constants
     * @param inputType element type of the input block
     * @param outputType element type written to {@code elementBuilder}
     * @return the declared method
     */
    private static MethodDefinition generateTransformValueInner(ClassDefinition definition, CallSiteBinder binder, Type inputType, Type outputType) {
        // Elements travel through the lambda as Objects, so primitives are boxed.
        Class<?> inputJavaType = Primitives.wrap(inputType.getJavaType());
        Class<?> outputJavaType = Primitives.wrap(outputType.getJavaType());

        Parameter block = Parameter.arg("block", Block.class);
        Parameter function = Parameter.arg(FunctionType.NAME, UnaryFunctionInterface.class);
        Parameter elementBuilder = Parameter.arg("elementBuilder", BlockBuilder.class);
        MethodDefinition method = definition.declareMethod(
                Access.a(Access.PRIVATE, Access.STATIC),
                "transformValue",
                ParameterizedType.type(void.class),
                ImmutableList.of(block, function, elementBuilder));

        BytecodeBlock body = method.getBody();
        Scope scope = method.getScope();
        Variable positionCount = scope.declareVariable(int.class, "positionCount");
        Variable position = scope.declareVariable(int.class, "position");
        Variable inputElement = scope.declareVariable(inputJavaType, "inputElement");
        Variable outputElement = scope.declareVariable(outputJavaType, "outputElement");

        body.append(positionCount.set(block.invoke("getPositionCount", int.class)));

        // The input helper runs before the output helper so type constants are bound in a stable order.
        BytecodeBlock loopBody = new BytecodeBlock()
                .append(loadInputElement(binder, inputType, inputJavaType, block, position, inputElement))
                .append(outputElement.set(function
                        .invoke("apply", Object.class, inputElement.cast(Object.class))
                        .cast(outputJavaType)))
                .append(writeOutputElement(binder, outputType, outputJavaType, elementBuilder, outputElement));

        body.append(new ForLoop()
                .initialize(position.set(BytecodeExpressions.constantInt(0)))
                .condition(BytecodeExpressions.lessThan(position, positionCount))
                .update(VariableInstruction.incrementVariable(position, (byte) 1))
                .body(loopBody));
        body.ret();
        return method;
    }

    /** Emits code that stores element {@code position} of {@code block} into {@code inputElement}, or null if it is null. */
    private static BytecodeBlock loadInputElement(CallSiteBinder binder, Type inputType, Class<?> inputJavaType, Parameter block, Variable position, Variable inputElement) {
        if (inputType.equals(UnknownType.UNKNOWN)) {
            // UNKNOWN elements are always treated as null here; the block is not read.
            return new BytecodeBlock().append(inputElement.set(BytecodeExpressions.constantNull(inputJavaType)));
        }
        return new BytecodeBlock().append(new IfStatement()
                .condition(block.invoke("isNull", boolean.class, position))
                .ifTrue(inputElement.set(BytecodeExpressions.constantNull(inputJavaType)))
                .ifFalse(inputElement.set(SqlTypeBytecodeExpression.constantType(binder, inputType)
                        .getValue(block, position)
                        .cast(inputJavaType))));
    }

    /** Emits code that appends {@code outputElement} to {@code elementBuilder}, appending null when it is null. */
    private static BytecodeBlock writeOutputElement(CallSiteBinder binder, Type outputType, Class<?> outputJavaType, Parameter elementBuilder, Variable outputElement) {
        // appendNull returns the builder; pop it to keep the stack balanced.
        BytecodeExpression appendNull = elementBuilder.invoke("appendNull", BlockBuilder.class).pop();
        if (outputType.equals(UnknownType.UNKNOWN)) {
            // UNKNOWN results are always written as null, whatever the lambda returned.
            return new BytecodeBlock().append(appendNull);
        }
        return new BytecodeBlock().append(new IfStatement()
                .condition(BytecodeExpressions.equal(outputElement, BytecodeExpressions.constantNull(outputJavaType)))
                .ifTrue(appendNull)
                // Unbox to the type's native Java type before writing.
                .ifFalse(SqlTypeBytecodeExpression.constantType(binder, outputType)
                        .writeValue(elementBuilder, outputElement.cast(outputType.getJavaType()))));
    }

    static {
        // CREATE_STATE must be initialized before the singleton is constructed.
        try {
            CREATE_STATE = MethodHandles.lookup().findStatic(
                    BufferedArrayValueBuilder.class,
                    "createBuffered",
                    MethodType.methodType(BufferedArrayValueBuilder.class, ArrayType.class));
            ARRAY_TRANSFORM_FUNCTION = new ArrayTransformFunction();
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }
}
