package io.trino.operator.output;

import com.google.common.collect.ImmutableList;
import io.airlift.concurrent.Threads;
import io.airlift.slice.Slice;
import io.airlift.units.DataSize;
import io.trino.SessionTestUtils;
import io.trino.block.BlockAssertions;
import io.trino.execution.StageId;
import io.trino.execution.TaskId;
import io.trino.execution.buffer.OutputBufferStateMachine;
import io.trino.execution.buffer.PagesSerdeFactory;
import io.trino.execution.buffer.PartitionedOutputBuffer;
import io.trino.execution.buffer.PipelinedOutputBuffers;
import io.trino.jmh.Benchmarks;
import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.memory.context.SimpleLocalMemoryContext;
import io.trino.operator.BenchmarkWindowOperator;
import io.trino.operator.BucketPartitionFunction;
import io.trino.operator.DriverContext;
import io.trino.operator.PageTestUtils;
import io.trino.operator.PrecomputedHashGenerator;
import io.trino.operator.output.PartitionedOutputOperator;
import io.trino.spi.Page;
import io.trino.spi.QueryId;
import io.trino.spi.block.Block;
import io.trino.spi.block.RowBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.block.TestingBlockEncodingSerde;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.spi.type.VarcharType;
import io.trino.sql.planner.HashBucketFunction;
import io.trino.sql.planner.TestTableScanNodePartitioning;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.TestingTaskContext;
import io.trino.type.BlockTypeOperators;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import org.junit.jupiter.api.Test;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

/**
 * JMH benchmark for {@link PartitionedOutputOperator}. Each invocation builds a fresh operator, feeds it the same
 * pre-generated page {@link BenchmarkData#getPageCount()} times and then finishes it.
 * <p>
 * Pages handed to the output buffer are passed to a JMH {@link Blackhole} (see
 * {@link BenchmarkData.TestingPartitionedOutputBuffer}) and are never stored.
 */
@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Fork(2)
@Warmup(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@BenchmarkMode(Mode.AverageTime)
public class BenchmarkPartitionedOutputOperator {
    private static final PositionsAppenderFactory POSITIONS_APPENDER_FACTORY = new PositionsAppenderFactory(new BlockTypeOperators());

    /**
     * Per-thread benchmark state: the JMH parameters, the generated input page and its column types, and the
     * factory methods used to build a fresh {@link PartitionedOutputOperator} for every invocation.
     */
    @State(Scope.Thread)
    public static class BenchmarkData {
        private static final int DEFAULT_POSITION_COUNT = 8192;
        private static final DataSize MAX_PARTITION_BUFFER_SIZE = DataSize.of(256, DataSize.Unit.MEGABYTE);
        private static final ExecutorService EXECUTOR = Executors.newCachedThreadPool(Threads.daemonThreadsNamed("BenchmarkPartitionedOutputOperator-executor-%s"));
        private static final ScheduledExecutorService SCHEDULER = Executors.newScheduledThreadPool(1, Threads.daemonThreadsNamed("BenchmarkPartitionedOutputOperator-scheduledExecutor-%s"));

        @Param({"true", "false"})
        private boolean enableCompression;

        @Param({"2", "16", "256"})
        private int partitionCount = 256;

        @Param({"1", "2"})
        private int channelCount = 1;

        @Param({"8192"})
        private int positionCount = DEFAULT_POSITION_COUNT;

        @Param({"BIGINT", "BIGINT_PARTITION_CHANNEL_SKEWED", "DICTIONARY_BIGINT", "RLE_BIGINT", "BIGINT_PARTITION_CHANNEL_20_PERCENT",
                "BIGINT_PARTITION_CHANNEL_DICTIONARY_20_PERCENT", "BIGINT_PARTITION_CHANNEL_DICTIONARY_50_PERCENT",
                "BIGINT_PARTITION_CHANNEL_DICTIONARY_80_PERCENT", "BIGINT_PARTITION_CHANNEL_DICTIONARY_100_PERCENT",
                "BIGINT_PARTITION_CHANNEL_DICTIONARY_100_PERCENT_MINUS_1", "BIGINT_PARTITION_CHANNEL_RLE", "BIGINT_PARTITION_CHANNEL_RLE_NULL",
                "LONG_DECIMAL", "DICTIONARY_LONG_DECIMAL", "INTEGER", "DICTIONARY_INTEGER", "SMALLINT", "DICTIONARY_SMALLINT",
                "BOOLEAN", "DICTIONARY_BOOLEAN", "VARCHAR", "DICTIONARY_VARCHAR", "ARRAY_BIGINT", "ARRAY_VARCHAR", "ARRAY_ARRAY_BIGINT",
                "MAP_BIGINT_BIGINT", "MAP_BIGINT_MAP_BIGINT_BIGINT", "ROW_BIGINT_BIGINT", "ROW_ARRAY_BIGINT_ARRAY_BIGINT", "ROW_RLE_BIGINT_BIGINT"})
        private TestType type = TestType.BIGINT;

        @Param({"0", "0.2"})
        private float nullRate = 0.2f;

        // Derived from the parameters above in setupData().
        private OptionalInt nullChannel;
        private List<Type> types;
        private int pageCount;
        private Page dataPage;
        private Blackhole blackhole;

        /**
         * Describes one input shape: the data column type, how many times the page is fed to the operator per
         * invocation, and the generator that builds the page.
         * <p>
         * Each generator receives {@code channelCount} copies of the column type. It must produce a page with one
         * extra trailing channel, because {@link #setupData} appends a BIGINT type and the operator partitions on
         * the last channel. For the {@code PageTestUtils} generators this trailing channel is presumably a
         * precomputed hash of channel 0 (NOTE(review): confirm against PageTestUtils).
         */
        public enum TestType {
            BIGINT(BigintType.BIGINT, 5000),
            // NOTE(review): the literal 2 presumably limits the partition channel to 2 distinct values (skew) - confirm.
            BIGINT_PARTITION_CHANNEL_SKEWED(BigintType.BIGINT, 5000, (types, positionCount, nullRate) -> page(
                    positionCount,
                    types.size(),
                    () -> BlockAssertions.createRandomBlockForType(BigintType.BIGINT, positionCount, nullRate),
                    BlockAssertions.createRandomLongsBlock(positionCount, 2))),
            DICTIONARY_BIGINT(BigintType.BIGINT, 5000, PageTestUtils::createRandomDictionaryPage),
            RLE_BIGINT(BigintType.BIGINT, 3000, PageTestUtils::createRandomRlePage),
            // Partition channel cycles through positionCount / 5 distinct values.
            BIGINT_PARTITION_CHANNEL_20_PERCENT(BigintType.BIGINT, 3000, (types, positionCount, nullRate) -> page(
                    positionCount,
                    types.size(),
                    () -> BlockAssertions.createRandomBlockForType(BigintType.BIGINT, positionCount, nullRate),
                    BlockAssertions.createLongsBlock((Iterable<Long>) LongStream.range(0L, positionCount)
                            .mapToObj(position -> position % (positionCount / 5))
                            .collect(ImmutableList.toImmutableList())))),
            BIGINT_PARTITION_CHANNEL_DICTIONARY_20_PERCENT(BigintType.BIGINT, 3000,
                    (types, positionCount, nullRate) -> createDictionaryPartitionChannelPage(types, positionCount, nullRate, positionCount / 5)),
            BIGINT_PARTITION_CHANNEL_DICTIONARY_50_PERCENT(BigintType.BIGINT, 3000,
                    (types, positionCount, nullRate) -> createDictionaryPartitionChannelPage(types, positionCount, nullRate, positionCount / 2)),
            BIGINT_PARTITION_CHANNEL_DICTIONARY_80_PERCENT(BigintType.BIGINT, 3000,
                    (types, positionCount, nullRate) -> createDictionaryPartitionChannelPage(types, positionCount, nullRate, (int) (positionCount * 0.8d))),
            BIGINT_PARTITION_CHANNEL_DICTIONARY_100_PERCENT(BigintType.BIGINT, 3000,
                    (types, positionCount, nullRate) -> createDictionaryPartitionChannelPage(types, positionCount, nullRate, positionCount)),
            BIGINT_PARTITION_CHANNEL_DICTIONARY_100_PERCENT_MINUS_1(BigintType.BIGINT, 3000,
                    (types, positionCount, nullRate) -> createDictionaryPartitionChannelPage(types, positionCount, nullRate, positionCount - 1)),
            // Partition channel repeats the single value 42.
            BIGINT_PARTITION_CHANNEL_RLE(BigintType.BIGINT, 5000, (types, positionCount, nullRate) -> page(
                    positionCount,
                    types.size(),
                    () -> BlockAssertions.createRandomBlockForType(BigintType.BIGINT, positionCount, nullRate),
                    BlockAssertions.createRepeatedValuesBlock(42L, positionCount))),
            // Partition channel is a run-length encoded NULL.
            BIGINT_PARTITION_CHANNEL_RLE_NULL(BigintType.BIGINT, 20, (types, positionCount, nullRate) -> page(
                    positionCount,
                    types.size(),
                    () -> BlockAssertions.createRandomBlockForType(BigintType.BIGINT, positionCount, nullRate),
                    RunLengthEncodedBlock.create(BlockAssertions.createLongsBlock((Long) null), positionCount))),
            LONG_DECIMAL(DecimalType.createDecimalType(19), 5000),
            DICTIONARY_LONG_DECIMAL(DecimalType.createDecimalType(19), 5000, PageTestUtils::createRandomDictionaryPage),
            INTEGER(IntegerType.INTEGER, 5000),
            DICTIONARY_INTEGER(IntegerType.INTEGER, 5000, PageTestUtils::createRandomDictionaryPage),
            SMALLINT(SmallintType.SMALLINT, 5000),
            DICTIONARY_SMALLINT(SmallintType.SMALLINT, 5000, PageTestUtils::createRandomDictionaryPage),
            BOOLEAN(BooleanType.BOOLEAN, 5000),
            DICTIONARY_BOOLEAN(BooleanType.BOOLEAN, 5000, PageTestUtils::createRandomDictionaryPage),
            VARCHAR(VarcharType.VARCHAR, 5000),
            DICTIONARY_VARCHAR(VarcharType.VARCHAR, 5000, PageTestUtils::createRandomDictionaryPage),
            ARRAY_BIGINT(new ArrayType(BigintType.BIGINT), 1000),
            ARRAY_VARCHAR(new ArrayType(VarcharType.VARCHAR), 1000),
            ARRAY_ARRAY_BIGINT(new ArrayType(new ArrayType(BigintType.BIGINT)), 1000),
            MAP_BIGINT_BIGINT(createMapType(BigintType.BIGINT, BigintType.BIGINT), 1000),
            MAP_BIGINT_MAP_BIGINT_BIGINT(createMapType(BigintType.BIGINT, createMapType(BigintType.BIGINT, BigintType.BIGINT)), 1000),
            ROW_BIGINT_BIGINT(rowTypeWithDefaultFieldNames(ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT)), 1000),
            ROW_ARRAY_BIGINT_ARRAY_BIGINT(rowTypeWithDefaultFieldNames(ImmutableList.of(new ArrayType(BigintType.BIGINT), new ArrayType(BigintType.BIGINT))), 1000),
            ROW_RLE_BIGINT_BIGINT(rowTypeWithDefaultFieldNames(ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT)), 1000,
                    (types, positionCount, nullRate) -> PageTestUtils.createPage(
                            types,
                            positionCount,
                            Optional.of(ImmutableList.of(0)),
                            types.stream()
                                    .map(ignored -> createRleRowBlock(positionCount, nullRate))
                                    .collect(ImmutableList.toImmutableList())));

            private final Type type;
            private final int pageCount;
            private final PageGenerator pageGenerator;

            /**
             * Builds the benchmark input page from the data column types, the row count and the null rate.
             */
            @FunctionalInterface
            public interface PageGenerator {
                Page createPage(List<Type> types, int positionCount, float nullRate);
            }

            TestType(Type type, int pageCount) {
                this(type, pageCount, PageTestUtils::createRandomPage);
            }

            TestType(Type type, int pageCount, PageGenerator pageGenerator) {
                this.type = Objects.requireNonNull(type, "type is null");
                this.pageCount = pageCount;
                this.pageGenerator = Objects.requireNonNull(pageGenerator, "pageGenerator is null");
            }

            public PageGenerator getPageGenerator() {
                return pageGenerator;
            }

            public int getPageCount() {
                return pageCount;
            }

            public OptionalInt getNullChannel() {
                return OptionalInt.empty();
            }

            /**
             * Returns {@code channelCount} copies of this test type's column type.
             */
            public List<Type> getTypes(int channelCount) {
                return Collections.nCopies(channelCount, type);
            }

            /**
             * Builds a page of random data blocks of the first type, followed by a BIGINT partition channel created by
             * {@code BlockAssertions.createLongDictionaryBlock(0, positionCount, dictionarySize)}.
             * The last argument is presumably the dictionary size (NOTE(review): confirm).
             */
            public static Page createDictionaryPartitionChannelPage(List<Type> types, int positionCount, float nullRate, int dictionarySize) {
                return page(
                        positionCount,
                        types.size(),
                        () -> BlockAssertions.createRandomBlockForType(types.get(0), positionCount, nullRate),
                        BlockAssertions.createLongDictionaryBlock(0, positionCount, dictionarySize));
            }

            /**
             * Builds a {@code ROW(BIGINT, BIGINT)} block of {@code positionCount} rows.
             * <p>
             * When {@code nullRate > 0}, a random set of rows is marked null. Both field blocks are sized to the number
             * of non-null rows. NOTE(review): this assumes {@code RowBlock.fromFieldBlocks} expects null rows to be
             * omitted from the field blocks; confirm against the RowBlock version in use.
             * The first field is a run-length encoded constant; the second is random longs that may themselves contain nulls.
             */
            private static Block createRleRowBlock(int positionCount, float nullRate) {
                boolean[] rowIsNull = null;
                int nullPositionCount = 0;
                if (nullRate > 0.0f) {
                    rowIsNull = new boolean[positionCount];
                    Set<Integer> nullPositions = BlockAssertions.chooseNullPositions(positionCount, nullRate);
                    for (int nullPosition : nullPositions) {
                        rowIsNull[nullPosition] = true;
                    }
                    nullPositionCount = nullPositions.size();
                }
                int nonNullPositionCount = positionCount - nullPositionCount;
                return RowBlock.fromFieldBlocks(positionCount, Optional.ofNullable(rowIsNull), new Block[] {
                        RunLengthEncodedBlock.create(BlockAssertions.createLongsBlock(-65128734213L), nonNullPositionCount),
                        BlockAssertions.createRandomLongsBlock(nonNullPositionCount, nullRate)});
            }
        }

        /**
         * Partitioned output buffer whose {@link #enqueue(int, List)} never stores pages. If a {@link Blackhole} was
         * supplied, the pages are passed to it; otherwise they are dropped.
         */
        public static class TestingPartitionedOutputBuffer extends PartitionedOutputBuffer {
            private final Blackhole blackhole;

            public TestingPartitionedOutputBuffer(String taskInstanceId, OutputBufferStateMachine stateMachine, PipelinedOutputBuffers outputBuffers, DataSize maxBufferSize, Supplier<LocalMemoryContext> memoryContextSupplier, Executor executor, Blackhole blackhole) {
                super(taskInstanceId, stateMachine, outputBuffers, maxBufferSize, memoryContextSupplier, executor);
                this.blackhole = blackhole;
            }

            @Override
            public void enqueue(int partition, List<Slice> pages) {
                // blackhole is null when run outside JMH (verifyAddPage, pollute)
                if (blackhole != null) {
                    blackhole.consume(pages);
                }
            }
        }

        public int getPageCount() {
            return pageCount;
        }

        public void setPageCount(int pageCount) {
            this.pageCount = pageCount;
        }

        public void setPartitionCount(int partitionCount) {
            this.partitionCount = partitionCount;
        }

        public void setPositionCount(int positionCount) {
            this.positionCount = positionCount;
        }

        public void setType(TestType testType) {
            this.type = Objects.requireNonNull(testType, "type is null");
        }

        public Page getDataPage() {
            return dataPage;
        }

        @Setup
        public void setup(Blackhole blackhole) {
            setupData(blackhole);
            pollute();
        }

        /**
         * Generates the input page for the current parameters and records the operator's input types.
         * Those types are the data column types followed by the trailing BIGINT partition channel.
         */
        private void setupData(Blackhole blackhole) {
            this.blackhole = blackhole;
            List<Type> dataTypes = type.getTypes(channelCount);
            dataPage = type.getPageGenerator().createPage(dataTypes, positionCount, nullRate);
            pageCount = type.getPageCount();
            nullChannel = type.getNullChannel();
            // Explicit type witness: without it the builder infers ImmutableList<Object>, which is not a List<Type>.
            types = ImmutableList.<Type>builder().addAll(dataTypes).add(BigintType.BIGINT).build();
        }

        /**
         * Builds a page of {@code channelCount} blocks taken from {@code blockSupplier}, followed by {@code partitionChannel}.
         */
        public static Page page(int positionCount, int channelCount, Supplier<Block> blockSupplier, Block partitionChannel) {
            Block[] blocks = new Block[channelCount + 1];
            for (int channel = 0; channel < channelCount; channel++) {
                blocks[channel] = blockSupplier.get();
            }
            blocks[channelCount] = partitionChannel;
            return new Page(positionCount, blocks);
        }

        /**
         * Creates the output buffer with one output buffer id per partition and an effectively unlimited size.
         */
        private PartitionedOutputBuffer createPartitionedOutputBuffer() {
            PipelinedOutputBuffers buffers = PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED);
            for (int partition = 0; partition < partitionCount; partition++) {
                buffers = buffers.withBuffer(new PipelinedOutputBuffers.OutputBufferId(partition), partition);
            }
            return createPartitionedBuffer(buffers.withNoMoreBufferIds(), DataSize.of(Long.MAX_VALUE, DataSize.Unit.BYTE));
        }

        /**
         * Builds a fresh operator that hashes on the last input channel and writes into a new testing output buffer.
         */
        private PartitionedOutputOperator createPartitionedOutputOperator() {
            BucketPartitionFunction partitionFunction = new BucketPartitionFunction(
                    new HashBucketFunction(new PrecomputedHashGenerator(0), partitionCount),
                    IntStream.range(0, partitionCount).toArray());
            PagesSerdeFactory serdeFactory = new PagesSerdeFactory(new TestingBlockEncodingSerde(), enableCompression);

            PartitionedOutputOperator.PartitionedOutputFactory operatorFactory = new PartitionedOutputOperator.PartitionedOutputFactory(
                    partitionFunction,
                    ImmutableList.of(types.size() - 1), // the partition channel is the last channel
                    ImmutableList.of(Optional.empty()),
                    false,
                    OptionalInt.empty(),
                    createPartitionedOutputBuffer(),
                    MAX_PARTITION_BUFFER_SIZE,
                    POSITIONS_APPENDER_FACTORY,
                    Optional.empty(),
                    AggregatedMemoryContext.newSimpleAggregatedMemoryContext(),
                    0,
                    Optional.empty());
            return (PartitionedOutputOperator) operatorFactory
                    .createOutputOperator(0, new PlanNodeId("plan-node-0"), types, Function.identity(), serdeFactory)
                    .createOperator(createDriverContext());
        }

        private DriverContext createDriverContext() {
            return TestingTaskContext.builder(EXECUTOR, SCHEDULER, SessionTestUtils.TEST_SESSION)
                    .build()
                    .addPipelineContext(0, true, true, false)
                    .addDriverContext();
        }

        private TestingPartitionedOutputBuffer createPartitionedBuffer(PipelinedOutputBuffers buffers, DataSize dataSize) {
            TaskId taskId = new TaskId(new StageId(new QueryId("query"), 0), 0, 0);
            return new TestingPartitionedOutputBuffer(
                    "task-instance-id",
                    new OutputBufferStateMachine(taskId, SCHEDULER),
                    buffers,
                    dataSize,
                    () -> new SimpleLocalMemoryContext(AggregatedMemoryContext.newSimpleAggregatedMemoryContext(), "test"),
                    SCHEDULER,
                    blackhole);
        }
    }

    /**
     * Benchmark body: builds a fresh operator, adds the data page {@code getPageCount()} times and finishes it.
     */
    @Benchmark
    public void addPage(BenchmarkData data) {
        PartitionedOutputOperator operator = data.createPartitionedOutputOperator();
        for (int i = 0; i < data.getPageCount(); i++) {
            operator.addInput(data.getDataPage());
        }
        operator.finish();
    }

    @Test
    public void verifyAddPage() {
        BenchmarkData data = new BenchmarkData();
        data.setup(null);
        new BenchmarkPartitionedOutputOperator().addPage(data);
    }

    /**
     * Returns a row type whose fields are named {@code field0}, {@code field1}, ... in order.
     */
    private static RowType rowTypeWithDefaultFieldNames(List<Type> fieldTypes) {
        List<RowType.Field> fields = new ArrayList<>(fieldTypes.size());
        for (int i = 0; i < fieldTypes.size(); i++) {
            fields.add(new RowType.Field(Optional.of("field" + i), fieldTypes.get(i)));
        }
        return RowType.from(Collections.unmodifiableList(fields));
    }

    private static MapType createMapType(Type keyType, Type valueType) {
        return new MapType(keyType, valueType, new TypeOperators());
    }

    /**
     * Runs {@link #addPage} over several column types before measurement. For each type it runs two shapes:
     * one page with the default settings, and 50 small (256-row) pages over 256 partitions.
     * NOTE(review): presumably meant to pollute JIT type profiles so the measured type is not unfairly
     * specialized - confirm.
     */
    private static void pollute() {
        List<BenchmarkData.TestType> testTypes = List.of(
                BenchmarkData.TestType.BIGINT,
                BenchmarkData.TestType.DICTIONARY_BIGINT,
                BenchmarkData.TestType.RLE_BIGINT,
                BenchmarkData.TestType.LONG_DECIMAL,
                BenchmarkData.TestType.INTEGER,
                BenchmarkData.TestType.SMALLINT,
                BenchmarkData.TestType.BOOLEAN,
                BenchmarkData.TestType.VARCHAR,
                BenchmarkData.TestType.ARRAY_BIGINT);
        BenchmarkPartitionedOutputOperator benchmark = new BenchmarkPartitionedOutputOperator();
        for (BenchmarkData.TestType testType : testTypes) {
            // Field-initializer defaults, a single page.
            BenchmarkData singlePage = new BenchmarkData();
            singlePage.setType(testType);
            singlePage.setupData(null);
            singlePage.setPageCount(1);
            benchmark.addPage(singlePage);

            // Small pages spread over many partitions, fed repeatedly.
            BenchmarkData manySmallPages = new BenchmarkData();
            manySmallPages.setType(testType);
            manySmallPages.setPartitionCount(256);
            manySmallPages.setPositionCount(256);
            manySmallPages.setupData(null);
            manySmallPages.setPageCount(50);
            benchmark.addPage(manySmallPages);
        }
    }

    public static void main(String[] args) throws Exception {
        Benchmarks.benchmark(BenchmarkPartitionedOutputOperator.class)
                .withOptions(optionsBuilder -> optionsBuilder.jvmArgs("-Xmx16g"))
                .run();
    }
}
