package io.trino.operator;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.airlift.concurrent.Threads;
import io.trino.RowPagesBuilder;
import io.trino.SessionTestUtils;
import io.trino.block.BlockAssertions;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.AggregationOperator;
import io.trino.operator.aggregation.AggregatorFactory;
import io.trino.operator.aggregation.TestingAggregationFunction;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.ByteArrayBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.MaterializedResult;
import io.trino.testing.TestingTaskContext;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import org.assertj.core.api.Assertions;
import org.testng.Assert;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

/**
 * Tests for {@link AggregationOperator}: global (non-grouped) aggregations, mask channels
 * (including null and "dirty" null masks), distinct aggregations, and memory accounting.
 */
@Test(singleThreaded = true)
public class TestAggregationOperator {
    private static final TestingFunctionResolution FUNCTION_RESOLUTION = new TestingFunctionResolution();
    private static final TestingAggregationFunction LONG_AVERAGE = FUNCTION_RESOLUTION.getAggregateFunction("avg", TypeSignatureProvider.fromTypes(BigintType.BIGINT));
    private static final TestingAggregationFunction DOUBLE_SUM = FUNCTION_RESOLUTION.getAggregateFunction("sum", TypeSignatureProvider.fromTypes(DoubleType.DOUBLE));
    private static final TestingAggregationFunction LONG_SUM = FUNCTION_RESOLUTION.getAggregateFunction("sum", TypeSignatureProvider.fromTypes(BigintType.BIGINT));
    private static final TestingAggregationFunction REAL_SUM = FUNCTION_RESOLUTION.getAggregateFunction("sum", TypeSignatureProvider.fromTypes(RealType.REAL));
    private static final TestingAggregationFunction COUNT = FUNCTION_RESOLUTION.getAggregateFunction("count", ImmutableList.of());

    private ExecutorService executor;
    private ScheduledExecutorService scheduledExecutor;

    @BeforeMethod
    public void setUp() {
        executor = Executors.newCachedThreadPool(Threads.daemonThreadsNamed(getClass().getSimpleName() + "-%s"));
        scheduledExecutor = Executors.newScheduledThreadPool(2, Threads.daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s"));
    }

    @AfterMethod(alwaysRun = true)
    public void tearDown() {
        executor.shutdownNow();
        scheduledExecutor.shutdownNow();
    }

    /**
     * Verifies that null positions in the mask channel are excluded even when the byte stored
     * behind the null is non-zero ("dirty" nulls).
     */
    @Test
    public void testMaskWithDirtyNulls() {
        // Mask: positions 0 and 1 are null (position 1 hides a non-zero byte, 27);
        // position 2 is non-null 0 and position 3 is non-null 75. Only position 3 is selected.
        List<Page> input = ImmutableList.of(new Page(
                4,
                BlockAssertions.createLongsBlock(1, 2, 3, 4),
                new ByteArrayBlock(4, Optional.of(new boolean[] {true, true, false, false}), new byte[] {0, 27, 0, 75})));

        AggregationOperator.AggregationOperatorFactory operatorFactory = new AggregationOperator.AggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(COUNT.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.of(1))));

        DriverContext driverContext = createDriverContext();
        MaterializedResult expected = MaterializedResult.resultBuilder(driverContext.getSession(), BigintType.BIGINT)
                .row(1L)
                .build();

        OperatorAssertion.assertOperatorEquals(operatorFactory, driverContext, input, expected);
    }

    /**
     * Verifies that a distinct aggregation honours its mask channel when the mask is a
     * null-only block, a plain boolean block, or a run-length-encoded null.
     */
    @Test
    public void testDistinctMaskWithNulls() {
        AggregatorFactory distinctSum = LONG_SUM.createDistinctAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.of(1));
        DriverContext driverContext = createDriverContext();
        AggregationOperator.AggregationOperatorFactory operatorFactory = new AggregationOperator.AggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(distinctSum));

        // Every position is null even though the stored bytes are 1, so this mask selects nothing.
        Block allNullMask = new ByteArrayBlock(4, Optional.of(new boolean[] {true, true, true, true}), new byte[] {1, 1, 1, 1});

        List<Page> input = ImmutableList.of(
                // Fully masked out by nulls.
                new Page(4, BlockAssertions.createLongsBlock(1, 2, 3, 4), allNullMask),
                // Fully selected: distinct values {10, 11} contribute 21.
                new Page(4, BlockAssertions.createLongsBlock(10, 11, 10, 11), BlockAssertions.createBooleansBlock(true, true, true, true)),
                // RLE of the (null) first mask position: also fully masked out.
                new Page(4, BlockAssertions.createLongsBlock(5, 6, 7, 8), RunLengthEncodedBlock.create(allNullMask.getSingleValueBlock(0), 4)));

        MaterializedResult expected = MaterializedResult.resultBuilder(driverContext.getSession(), BigintType.BIGINT)
                .row(21L)
                .build();

        OperatorAssertion.assertOperatorEquals(operatorFactory, driverContext, input, expected);
    }

    /**
     * Runs several global aggregations over one sequence page and checks both the results and
     * that no memory is left reserved on the driver context afterwards.
     */
    @Test
    public void testAggregation() {
        TestingAggregationFunction countVarchar = FUNCTION_RESOLUTION.getAggregateFunction("count", TypeSignatureProvider.fromTypes(VarcharType.VARCHAR));
        TestingAggregationFunction maxVarchar = FUNCTION_RESOLUTION.getAggregateFunction("max", TypeSignatureProvider.fromTypes(VarcharType.VARCHAR));

        // Channels: 0 varchar, 1 bigint, 2 varchar, 3 bigint, 4 real, 5 double, 6 varchar;
        // 100 rows, one sequence per channel starting at the listed value.
        List<Page> input = RowPagesBuilder.rowPagesBuilder(VarcharType.VARCHAR, BigintType.BIGINT, VarcharType.VARCHAR, BigintType.BIGINT, RealType.REAL, DoubleType.DOUBLE, VarcharType.VARCHAR)
                .addSequencePage(100, 0, 0, 300, 500, 400, 500, 500)
                .build();

        AggregationOperator.AggregationOperatorFactory operatorFactory = new AggregationOperator.AggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(
                        COUNT.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty()),
                        LONG_SUM.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(1), OptionalInt.empty()),
                        LONG_AVERAGE.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(1), OptionalInt.empty()),
                        maxVarchar.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(2), OptionalInt.empty()),
                        countVarchar.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty()),
                        LONG_SUM.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(3), OptionalInt.empty()),
                        REAL_SUM.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(4), OptionalInt.empty()),
                        DOUBLE_SUM.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(5), OptionalInt.empty()),
                        maxVarchar.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(6), OptionalInt.empty())));

        DriverContext driverContext = createDriverContext();
        MaterializedResult expected = MaterializedResult.resultBuilder(
                        driverContext.getSession(),
                        BigintType.BIGINT, BigintType.BIGINT, DoubleType.DOUBLE, VarcharType.VARCHAR, BigintType.BIGINT,
                        BigintType.BIGINT, RealType.REAL, DoubleType.DOUBLE, VarcharType.VARCHAR)
                .row(100L, 4950L, 49.5d, "399", 100L, 54950L, 44950.0f, 54950.0d, "599")
                .build();

        OperatorAssertion.assertOperatorEquals(operatorFactory, driverContext, input, expected);
        Assert.assertEquals(driverContext.getMemoryUsage(), 0L);
    }

    /**
     * Verifies that the operator reserves memory while holding aggregation state and that all of
     * it is released once the operator is closed.
     */
    @Test
    public void testMemoryTracking() throws Exception {
        Page input = Iterables.getOnlyElement(RowPagesBuilder.rowPagesBuilder(BigintType.BIGINT).addSequencePage(100, 0).build());

        AggregationOperator.AggregationOperatorFactory operatorFactory = new AggregationOperator.AggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(LONG_SUM.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty())));
        DriverContext driverContext = createDriverContext();

        // try-with-resources closes the operator exactly once, even if an assertion inside fails.
        try (Operator operator = operatorFactory.createOperator(driverContext)) {
            Assert.assertTrue(operator.needsInput());
            operator.addInput(input);
            Assertions.assertThat(driverContext.getMemoryUsage()).isGreaterThan(0L);

            OperatorAssertion.toPages(operator, Collections.<Page>emptyIterator());
        }

        // Checked only after close(), which must release every reservation.
        Assert.assertEquals(driverContext.getMemoryUsage(), 0L);
    }

    /** Creates a fresh driver context on this test's executors, using the shared test session. */
    private DriverContext createDriverContext() {
        return TestingTaskContext.createTaskContext(executor, scheduledExecutor, SessionTestUtils.TEST_SESSION)
                .addPipelineContext(0, true, true, false)
                .addDriverContext();
    }
}
