package io.trino.operator.scalar;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.InternalFunctionBundle;
import io.trino.metadata.SqlFunction;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.TimeZoneKey;
import io.trino.spi.type.VarcharType;
import io.trino.testing.TestingSession;
import io.trino.util.StructuralTestUtil;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/**
 * Exercises lambda expressions through the {@code apply} and {@code invoke} scalar functions.
 * Covers parameter naming, null handling, nesting, row access, binding, coercion, result types,
 * and the errors raised when a lambda is passed to a function that does not accept one.
 *
 * <p>The session time zone is pinned to {@code Pacific/Kiritimati} so that
 * {@link #testSessionDependent()} has a deterministic expected value.
 */
public class TestLambdaExpression extends AbstractTestFunctions {
    /** Quoted name of the internal function that binds values to a lambda's leading parameters. */
    private static final String BIND = "\"$internal$bind\"";

    public TestLambdaExpression() {
        super(TestingSession.testSessionBuilder()
                .setTimeZoneKey(TimeZoneKey.getTimeZoneKey("Pacific/Kiritimati"))
                .build());
    }

    /** Registers {@code apply} and {@code invoke}, which every test in this class relies on. */
    @BeforeClass
    public void setUp() {
        SqlFunction[] lambdaFunctions = {ApplyFunction.APPLY_FUNCTION, InvokeFunction.INVOKE_FUNCTION};
        functionAssertions.addFunctions(new InternalFunctionBundle(lambdaFunctions));
    }

    @Test
    public void testBasic() {
        assertFunction("apply(5, x -> x + 1)", IntegerType.INTEGER, 6);
        // The expected value 6 implies RANDOM(1) contributes 0 here; this feeds the lambda a non-constant argument.
        assertFunction("apply(5 + RANDOM(1), x -> x + 1)", IntegerType.INTEGER, 6);
    }

    /** A delimited parameter name containing dots, spaces, quotes and escapes must round-trip. */
    @Test
    public void testParameterName() {
        String parameter = quote("a.b c; d ' \n \\n \"");
        assertFunction("apply(5, " + parameter + " -> " + parameter + " * 2)", IntegerType.INTEGER, 10);
    }

    @Test
    public void testNull() {
        // Arithmetic over a non-null value, an untyped NULL, and a typed NULL.
        assertFunction("apply(3, x -> x + 1)", IntegerType.INTEGER, 4);
        assertFunction("apply(NULL, x -> x + 1)", IntegerType.INTEGER, null);
        assertFunction("apply(CAST (NULL AS INTEGER), x -> x + 1)", IntegerType.INTEGER, null);
        // The same three inputs, tested with IS NULL.
        assertFunction("apply(3, x -> x IS NULL)", BooleanType.BOOLEAN, false);
        assertFunction("apply(NULL, x -> x IS NULL)", BooleanType.BOOLEAN, true);
        assertFunction("apply(CAST (NULL AS INTEGER), x -> x IS NULL)", BooleanType.BOOLEAN, true);
    }

    @Test
    public void testUnreferencedLambdaArgument() {
        assertFunction("apply(5, x -> 6)", IntegerType.INTEGER, 6);
    }

    @Test
    public void testLambdaWithoutArgument() {
        assertFunction("invoke(() -> 42)", IntegerType.INTEGER, 42);
    }

    /** The lambda body reads session state (the time zone set in the constructor). */
    @Test
    public void testSessionDependent() {
        assertFunction("apply('timezone: ', x -> x || current_timezone())", VarcharType.VARCHAR, "timezone: Pacific/Kiritimati");
    }

    @Test
    public void testInstanceFunction() {
        ArrayType integerArray = new ArrayType(IntegerType.INTEGER);
        assertFunction("apply(ARRAY[2], x -> concat(ARRAY [1], x))", integerArray, ImmutableList.of(1, 2));
    }

    /** Nested lambdas with distinct parameter names, then with every level shadowing {@code x}. */
    @Test
    public void testNestedLambda() {
        assertFunction("apply(11, x -> apply(x + 7, y -> apply(y * 3, z -> z * 5) + 1) * 2)", IntegerType.INTEGER, 542);
        assertFunction("apply(11, x -> apply(x + 7, x -> apply(x * 3, x -> x * 5) + 1) * 2)", IntegerType.INTEGER, 542);
    }

    @Test
    public void testRowAccess() {
        assertFunction("apply(CAST(ROW(1, 'a') AS ROW(x INTEGER, y VARCHAR)), r -> r[1])", IntegerType.INTEGER, 1);
        assertFunction("apply(CAST(ROW(1, 'a') AS ROW(x INTEGER, y VARCHAR)), r -> r[2])", VarcharType.VARCHAR, "a");
    }

    @Test
    public void testBind() {
        assertFunction("apply(90, " + BIND + "(9, (x, y) -> x + y))", IntegerType.INTEGER, 99);
        assertFunction("invoke(" + BIND + "(8, x -> x + 1))", IntegerType.INTEGER, 9);
        assertFunction("apply(900, " + BIND + "(90, 9, (x, y, z) -> x + y + z))", IntegerType.INTEGER, 999);
        assertFunction("invoke(" + BIND + "(90, 9, (x, y) -> x + y))", IntegerType.INTEGER, 99);
    }

    @Test
    public void testCoercion() {
        assertFunction("apply(90, x -> x + 9.0E0)", DoubleType.DOUBLE, 99.0);
        assertFunction("apply(90, " + BIND + "(9.0E0, (x, y) -> x + y))", DoubleType.DOUBLE, 99.0);
        assertFunction("invoke(" + BIND + "(8, x -> x + 1.0E0))", DoubleType.DOUBLE, 9.0);
    }

    /** Checks lambda result types across integer, double, boolean, varchar, array and map inputs. */
    @Test
    public void testTypeCombinations() {
        VarcharType unboundedVarchar = VarcharType.createUnboundedVarcharType();
        VarcharType varchar3 = VarcharType.createVarcharType(3);

        // integer input
        assertFunction("apply(25, x -> x + 1)", IntegerType.INTEGER, 26);
        assertFunction("apply(25, x -> x + 1.0E0)", DoubleType.DOUBLE, 26.0);
        assertFunction("apply(25, x -> x = 25)", BooleanType.BOOLEAN, true);
        assertFunction("apply(25, x -> to_base(x, 16))", VarcharType.createVarcharType(64), "19");
        assertFunction("apply(25, x -> ARRAY[x + 1])", new ArrayType(IntegerType.INTEGER), ImmutableList.of(26));

        // double input
        assertFunction("apply(25.6E0, x -> CAST(x AS BIGINT))", BigintType.BIGINT, 26L);
        assertFunction("apply(25.6E0, x -> x + 1.0E0)", DoubleType.DOUBLE, 26.6);
        assertFunction("apply(25.6E0, x -> x = 25.6E0)", BooleanType.BOOLEAN, true);
        assertFunction("apply(25.6E0, x -> CAST(x AS VARCHAR))", unboundedVarchar, "2.56E1");
        assertFunction("apply(25.6E0, x -> MAP(ARRAY[x + 1], ARRAY[true]))", StructuralTestUtil.mapType(DoubleType.DOUBLE, BooleanType.BOOLEAN), ImmutableMap.of(26.6, true));

        // boolean input
        assertFunction("apply(true, x -> if(x, 25, 26))", IntegerType.INTEGER, 25);
        assertFunction("apply(false, x -> if(x, 25.6E0, 28.9E0))", DoubleType.DOUBLE, 28.9);
        assertFunction("apply(true, x -> not x)", BooleanType.BOOLEAN, false);
        assertFunction("apply(false, x -> CAST(x AS VARCHAR))", unboundedVarchar, "false");
        assertFunction("apply(true, x -> ARRAY[x])", new ArrayType(BooleanType.BOOLEAN), ImmutableList.of(true));

        // varchar input
        assertFunction("apply('41', x -> from_base(x, 16))", BigintType.BIGINT, 65L);
        assertFunction("apply('25.6E0', x -> CAST(x AS DOUBLE))", DoubleType.DOUBLE, 25.6);
        assertFunction("apply('abc', x -> 'abc' = x)", BooleanType.BOOLEAN, true);
        assertFunction("apply('abc', x -> x || x)", unboundedVarchar, "abcabc");
        assertFunction("apply('123', x -> ROW(x, CAST(x AS INTEGER), x > '0'))", RowType.anonymous(ImmutableList.of(varchar3, IntegerType.INTEGER, BooleanType.BOOLEAN)), ImmutableList.of("123", 123, true));

        // array input
        assertFunction("apply(ARRAY['abc', NULL, '123'], x -> from_base(x[3], 10))", BigintType.BIGINT, 123L);
        assertFunction("apply(ARRAY['abc', NULL, '123'], x -> CAST(x[3] AS DOUBLE))", DoubleType.DOUBLE, 123.0);
        assertFunction("apply(ARRAY['abc', NULL, '123'], x -> x[2] IS NULL)", BooleanType.BOOLEAN, true);
        assertFunction("apply(ARRAY['abc', NULL, '123'], x -> x[2])", varchar3, null);

        // map input
        assertFunction("apply(MAP(ARRAY['abc', 'def'], ARRAY[123, 456]), x -> map_keys(x))", new ArrayType(varchar3), ImmutableList.of("abc", "def"));
    }

    /** Passing a lambda to a function without a function-typed parameter must fail resolution. */
    @Test
    public void testFunctionParameter() {
        assertFunctionNotFound("count(x -> x)", "line 1:1: Unexpected parameters (<function>) for function count. Expected: count(), count(t) T");
        assertFunctionNotFound("max(x -> x)", "line 1:1: Unexpected parameters (<function>) for function max. Expected: max(t) T:orderable, max(e, bigint) E:orderable");
        assertFunctionNotFound("sqrt(x -> x)", "line 1:1: Unexpected parameters (<function>) for function sqrt. Expected: sqrt(double)");
        assertFunctionNotFound("sqrt(x -> x, 123, x -> x)", "line 1:1: Unexpected parameters (<function>, integer, <function>) for function sqrt. Expected: sqrt(double)");
        assertFunctionNotFound("pow(x -> x, 123)", "line 1:1: Unexpected parameters (<function>, integer) for function pow. Expected: pow(double, double)");
        assertFunctionNotFound("pow(123, x -> x)", "line 1:1: Unexpected parameters (integer, <function>) for function pow. Expected: pow(double, double)");
    }

    /** Asserts that {@code projection} fails with FUNCTION_NOT_FOUND and the given message. */
    private void assertFunctionNotFound(String projection, String message) {
        assertInvalidFunction(projection, StandardErrorCode.FUNCTION_NOT_FOUND, message);
    }

    /** Wraps {@code identifier} in double quotes, doubling any embedded double quote. */
    private static String quote(String identifier) {
        String escaped = identifier.replace("\"", "\"\"");
        return new StringBuilder(escaped.length() + 2)
                .append('"')
                .append(escaped)
                .append('"')
                .toString();
    }
}
