package io.trino.operator.project;

import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.block.BlockAssertions;
import io.trino.operator.project.PageFieldsToInputParametersRewriter;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.LazyBlock;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.sql.ExpressionTestUtils;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TestingPlannerContext;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.SqlToRowExpressionTranslator;
import io.trino.sql.tree.Expression;
import io.trino.testing.TestingSession;
import io.trino.transaction.TransactionId;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.IntStream;

/* loaded from: input_file:io/trino/operator/project/TestPageFieldsToInputParametersRewriter.class */
public class TestPageFieldsToInputParametersRewriter {
    private static final TypeAnalyzer TYPE_ANALYZER = TypeAnalyzer.createTestingTypeAnalyzer(TestingPlannerContext.PLANNER_CONTEXT);
    private static final Session TEST_SESSION = TestingSession.testSessionBuilder().setTransactionId(TransactionId.create()).build();

    /* loaded from: input_file:io/trino/operator/project/TestPageFieldsToInputParametersRewriter$RowExpressionBuilder.class */
    private static class RowExpressionBuilder {
        private final Map<Symbol, Type> symbolTypes = new HashMap();
        private final Map<Symbol, Integer> sourceLayout = new HashMap();
        private final List<Type> types = new LinkedList();

        private RowExpressionBuilder() {
        }

        private static RowExpressionBuilder create() {
            return new RowExpressionBuilder();
        }

        private RowExpressionBuilder addSymbol(String str, Type type) {
            Symbol symbol = new Symbol(str);
            this.symbolTypes.put(symbol, type);
            this.sourceLayout.put(symbol, Integer.valueOf(this.types.size()));
            this.types.add(type);
            return this;
        }

        private RowExpression buildExpression(String str) {
            Expression createExpression = ExpressionTestUtils.createExpression(str, TestingPlannerContext.PLANNER_CONTEXT, TypeProvider.copyOf(this.symbolTypes));
            return SqlToRowExpressionTranslator.translate(createExpression, TestPageFieldsToInputParametersRewriter.TYPE_ANALYZER.getTypes(TestPageFieldsToInputParametersRewriter.TEST_SESSION, TypeProvider.copyOf(this.symbolTypes), createExpression), this.sourceLayout, TestingPlannerContext.PLANNER_CONTEXT.getMetadata(), TestingPlannerContext.PLANNER_CONTEXT.getFunctionManager(), TestPageFieldsToInputParametersRewriter.TEST_SESSION, true);
        }
    }

    @Test
    public void testEagerLoading() {
        RowExpressionBuilder addSymbol = RowExpressionBuilder.create().addSymbol("bigint0", BigintType.BIGINT).addSymbol("bigint1", BigintType.BIGINT);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("bigint0 + 5"), 1);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("CAST((bigint0 * 10) AS INT)"), 1);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("COALESCE((bigint0 % 2), bigint0)"), 1);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("bigint0 IN (1, 2, 3)"), 1);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("bigint0 > 0"), 1);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("bigint0 + 1 = 0"), 1);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("bigint0 BETWEEN 1 AND 10"), 1);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("CASE WHEN (bigint0 > 0) THEN bigint0 ELSE null END"), 1);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("CASE bigint0 WHEN 1 THEN 1 ELSE -bigint0 END"), 1);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("IF(bigint0 >= 150000, 0, 1)"), 1);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("IF(bigint0 >= 150000, bigint0, 0)"), 1);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("COALESCE(0, bigint0) + bigint0"), 1);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("bigint0 + (2 * bigint1)"), 2);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("NULLIF(bigint0, bigint1)"), 2);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("COALESCE(CEIL(bigint0 / bigint1), 0)"), 2);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("CASE WHEN (bigint0 > bigint1) THEN 1 ELSE 0 END"), 2);
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("CASE WHEN (bigint0 > 0) THEN bigint1 ELSE 0 END"), 2, ImmutableSet.of(0));
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("COALESCE(ROUND(bigint0), bigint1)"), 2, ImmutableSet.of(0));
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("bigint0 > 0 AND bigint1 > 0"), 2, ImmutableSet.of(0));
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("bigint0 > 0 OR bigint1 > 0"), 2, ImmutableSet.of(0));
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("bigint0 BETWEEN 0 AND bigint1"), 2, ImmutableSet.of(0));
        verifyEagerlyLoadedColumns(addSymbol.buildExpression("IF(bigint1 >= 150000, 0, bigint0)"), 2, ImmutableSet.of(0));
        RowExpressionBuilder addSymbol2 = RowExpressionBuilder.create().addSymbol("array_bigint0", new ArrayType(BigintType.BIGINT)).addSymbol("array_bigint1", new ArrayType(BigintType.BIGINT));
        verifyEagerlyLoadedColumns(addSymbol2.buildExpression("TRANSFORM(array_bigint0, x -> 1)"), 1, ImmutableSet.of());
        verifyEagerlyLoadedColumns(addSymbol2.buildExpression("TRANSFORM(array_bigint0, x -> 2 * x)"), 1, ImmutableSet.of());
        verifyEagerlyLoadedColumns(addSymbol2.buildExpression("ZIP_WITH(array_bigint0, array_bigint1, (x, y) -> 2 * x)"), 2, ImmutableSet.of());
    }

    private static void verifyEagerlyLoadedColumns(RowExpression rowExpression, int i) {
        verifyEagerlyLoadedColumns(rowExpression, i, (Set) IntStream.range(0, i).boxed().collect(ImmutableSet.toImmutableSet()));
    }

    private static void verifyEagerlyLoadedColumns(RowExpression rowExpression, int i, Set<Integer> set) {
        PageFieldsToInputParametersRewriter.Result rewritePageFieldsToInputParameters = PageFieldsToInputParametersRewriter.rewritePageFieldsToInputParameters(rowExpression);
        Block[] blockArr = new Block[i];
        for (int i2 = 0; i2 < i; i2++) {
            blockArr[i2] = lazyWrapper(BlockAssertions.createLongSequenceBlock(0, 100));
        }
        Page inputChannels = rewritePageFieldsToInputParameters.getInputChannels().getInputChannels(new Page(blockArr));
        for (int i3 = 0; i3 < i; i3++) {
            Assertions.assertThat(inputChannels.getBlock(i3).isLoaded()).isEqualTo(set.contains(Integer.valueOf(i3)));
        }
    }

    private static LazyBlock lazyWrapper(Block block) {
        int positionCount = block.getPositionCount();
        Objects.requireNonNull(block);
        return new LazyBlock(positionCount, block::getLoadedBlock);
    }
}
