package io.trino.operator;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.slice.XxHash64;
import io.trino.array.LongBigArray;
import io.trino.jmh.Benchmarks;
import io.trino.operator.BenchmarkWindowOperator;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.Block;
import io.trino.spi.block.LongArrayBlock;
import io.trino.spi.type.AbstractLongType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.spi.type.VarcharType;
import io.trino.sql.gen.JoinCompiler;
import io.trino.sql.planner.TestTableScanNodePartitioning;
import io.trino.type.BlockTypeOperators;
import it.unimi.dsi.fastutil.HashCommon;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.profile.GCProfiler;
import org.openjdk.jmh.runner.RunnerException;

// NOTE(review): the iteration/fork counts below reference constants from unrelated test/benchmark
// classes; this looks like a decompiler substituting equal-valued constants for literals.
// Confirm the intended literal values and inline them to drop the cross-class coupling.
@Warmup(iterations = TestTableScanNodePartitioning.BUCKET_COUNT, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@State(Scope.Thread)
@Measurement(iterations = TestTableScanNodePartitioning.BUCKET_COUNT, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Fork(BenchmarkWindowOperator.Context.NUMBER_OF_GROUP_COLUMNS)
@BenchmarkMode({Mode.AverageTime})
/**
 * JMH benchmarks for {@link MultiChannelGroupByHash} and {@link BigintGroupByHash}, plus two
 * hand-written open-addressing hash tables ({@link #baseline} and {@link #baselineBigArray})
 * that serve as lower bounds for single-column BIGINT grouping.
 */
public class BenchmarkGroupByHash {
    /** Number of input rows generated per benchmark invocation. */
    private static final int POSITIONS = 10000000;
    /** Initial expected size handed to the group-by hash implementations. */
    private static final int EXPECTED_SIZE = 10000;
    /** Number of distinct keys; kept as a String so it can be used in {@code @Param}. */
    private static final String GROUP_COUNT_STRING = "3000000";
    private static final int GROUP_COUNT = Integer.parseInt(GROUP_COUNT_STRING);
    /** Sentinel marking an unused slot in the baseline hash tables (generated keys are non-negative). */
    private static final long EMPTY_SLOT = -1L;
    private static final TypeOperators TYPE_OPERATORS = new TypeOperators();
    private static final BlockTypeOperators TYPE_OPERATOR_FACTORY = new BlockTypeOperators(TYPE_OPERATORS);

    /** Single-column BIGINT input used by the hand-written baseline benchmarks. */
    @State(Scope.Thread)
    public static class BaselinePagesData {

        @Param({"1"})
        private int channelCount = 1;

        @Param({"false"})
        private boolean hashEnabled;

        // Default lets setup() run outside JMH (otherwise nextInt(0) would throw).
        @Param({BenchmarkGroupByHash.GROUP_COUNT_STRING})
        private int groupCount = GROUP_COUNT;

        private List<Page> pages;

        @Setup
        public void setup() {
            pages = createBigintPages(POSITIONS, groupCount, channelCount, hashEnabled);
        }

        public List<Page> getPages() {
            return pages;
        }

        public int getGroupCount() {
            return groupCount;
        }
    }

    /** Multi-column VARCHAR or BIGINT input, optionally with a precomputed hash column. */
    @State(Scope.Thread)
    public static class BenchmarkData {

        @Param({"true", "false"})
        private boolean hashEnabled;

        @Param({"1", "5", "10", "15", "20"})
        private int channelCount = 1;

        @Param({BenchmarkGroupByHash.GROUP_COUNT_STRING})
        private int groupCount = GROUP_COUNT;

        @Param({"VARCHAR", "BIGINT"})
        private String dataType = "VARCHAR";

        private List<Page> pages;
        private Optional<Integer> hashChannel;
        private List<Type> types;
        private int[] channels;

        @Setup
        public void setup() {
            switch (dataType) {
                case "VARCHAR":
                    types = Collections.nCopies(channelCount, VarcharType.VARCHAR);
                    pages = createVarcharPages(POSITIONS, groupCount, channelCount, hashEnabled);
                    break;
                case "BIGINT":
                    types = Collections.nCopies(channelCount, BigintType.BIGINT);
                    pages = createBigintPages(POSITIONS, groupCount, channelCount, hashEnabled);
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported dataType");
            }
            // When enabled, the hash column is appended right after the data columns.
            hashChannel = hashEnabled ? Optional.of(channelCount) : Optional.empty();
            channels = new int[channelCount];
            for (int channel = 0; channel < channelCount; channel++) {
                channels[channel] = channel;
            }
        }

        public List<Page> getPages() {
            return pages;
        }

        public Optional<Integer> getHashChannel() {
            return hashChannel;
        }

        public List<Type> getTypes() {
            return types;
        }

        public int[] getChannels() {
            return channels;
        }
    }

    /** Single-column BIGINT input for {@link BigintGroupByHash}. */
    @State(Scope.Thread)
    public static class SingleChannelBenchmarkData {

        @Param({"1"})
        private int channelCount = 1;

        @Param({"true", "false"})
        private boolean hashEnabled = true;

        private List<Page> pages;
        private List<Type> types;

        @Setup
        public void setup() {
            pages = createBigintPages(POSITIONS, GROUP_COUNT, channelCount, hashEnabled);
            types = Collections.nCopies(1, BigintType.BIGINT);
        }

        public List<Page> getPages() {
            return pages;
        }

        public List<Type> getTypes() {
            return types;
        }

        public boolean getHashEnabled() {
            return hashEnabled;
        }
    }

    @Benchmark
    @OperationsPerInvocation(POSITIONS)
    public Object groupByHashPreCompute(BenchmarkData data) {
        GroupByHash groupByHash = new MultiChannelGroupByHash(data.getTypes(), data.getChannels(), data.getHashChannel(), EXPECTED_SIZE, false, getJoinCompiler(), TYPE_OPERATOR_FACTORY, UpdateMemory.NOOP);
        addInputPagesToHash(groupByHash, data.getPages());
        return buildGroupPages(groupByHash);
    }

    /** Computes a hash for every row with {@link InterpretedHashGenerator} and appends it as a column. */
    @Benchmark
    @OperationsPerInvocation(POSITIONS)
    public List<Page> benchmarkHashPosition(BenchmarkData data) {
        InterpretedHashGenerator hashGenerator = new InterpretedHashGenerator(data.getTypes(), data.getChannels(), TYPE_OPERATOR_FACTORY);
        ImmutableList.Builder<Page> results = ImmutableList.builderWithExpectedSize(data.getPages().size());
        for (Page page : data.getPages()) {
            long[] hashes = new long[page.getPositionCount()];
            for (int position = 0; position < page.getPositionCount(); position++) {
                hashes[position] = hashGenerator.hashPosition(position, page);
            }
            results.add(page.appendColumn(new LongArrayBlock(page.getPositionCount(), Optional.empty(), hashes)));
        }
        return results.build();
    }

    @Benchmark
    @OperationsPerInvocation(POSITIONS)
    public Object addPagePreCompute(BenchmarkData data) {
        GroupByHash groupByHash = new MultiChannelGroupByHash(data.getTypes(), data.getChannels(), data.getHashChannel(), EXPECTED_SIZE, false, getJoinCompiler(), TYPE_OPERATOR_FACTORY, UpdateMemory.NOOP);
        addInputPagesToHash(groupByHash, data.getPages());
        return buildGroupPages(groupByHash);
    }

    @Benchmark
    @OperationsPerInvocation(POSITIONS)
    public Object bigintGroupByHash(SingleChannelBenchmarkData data) {
        GroupByHash groupByHash = new BigintGroupByHash(0, data.getHashEnabled(), EXPECTED_SIZE, UpdateMemory.NOOP);
        addInputPagesToHash(groupByHash, data.getPages());
        return buildGroupPages(groupByHash);
    }

    /**
     * Baseline: counts distinct values in channel 0 using a {@code long[]} open-addressing table
     * with linear probing, indexing directly by the value's low bits.
     *
     * @return the number of distinct values seen
     */
    @Benchmark
    @OperationsPerInvocation(POSITIONS)
    public long baseline(BaselinePagesData data) {
        int hashSize = HashCommon.arraySize(data.getGroupCount(), 0.9f);
        int mask = hashSize - 1;
        long[] table = new long[hashSize];
        Arrays.fill(table, EMPTY_SLOT);

        long groupIds = 0;
        for (Page page : data.getPages()) {
            Block block = page.getBlock(0);
            int positionCount = block.getPositionCount();
            for (int position = 0; position < positionCount; position++) {
                long value = block.getLong(position, 0);
                int slot = (int) (value & mask);
                // Wrap around with the mask so probing never runs past the end of the table.
                while (table[slot] != EMPTY_SLOT && table[slot] != value) {
                    slot = (slot + 1) & mask;
                }
                if (table[slot] == EMPTY_SLOT) {
                    table[slot] = value;
                    groupIds++;
                }
            }
        }
        return groupIds;
    }

    /**
     * Baseline: like {@link #baseline} but backed by a {@link LongBigArray} and slotted by
     * {@link XxHash64} of the value.
     *
     * @return the number of distinct values seen
     */
    @Benchmark
    @OperationsPerInvocation(POSITIONS)
    public long baselineBigArray(BaselinePagesData data) {
        int hashSize = HashCommon.arraySize(data.getGroupCount(), 0.9f);
        int mask = hashSize - 1;
        LongBigArray table = new LongBigArray(EMPTY_SLOT);
        table.ensureCapacity(hashSize);

        long groupIds = 0;
        for (Page page : data.getPages()) {
            Block block = page.getBlock(0);
            int positionCount = block.getPositionCount();
            for (int position = 0; position < positionCount; position++) {
                long value = BigintType.BIGINT.getLong(block, position);
                int slot = ((int) XxHash64.hash(value)) & mask;
                // Hashed slots collide, so wrapping is required to stay within capacity.
                while (table.get(slot) != EMPTY_SLOT && table.get(slot) != value) {
                    slot = (slot + 1) & mask;
                }
                if (table.get(slot) == EMPTY_SLOT) {
                    table.set(slot, value);
                    groupIds++;
                }
            }
        }
        return groupIds;
    }

    /**
     * Feeds every page into {@code groupByHash}, driving each page's {@link Work} until it
     * reports completion before moving on to the next page.
     */
    private static void addInputPagesToHash(GroupByHash groupByHash, List<Page> pages) {
        for (Page page : pages) {
            // Obtain the work once and resume it; re-calling addPage would skip/duplicate pages.
            Work<?> work = groupByHash.addPage(page);
            boolean finished;
            do {
                finished = work.process();
            }
            while (!finished);
        }
    }

    /**
     * Materializes every group of {@code groupByHash} into output pages, flushing the builder
     * whenever it fills up. The whole list is returned so JMH consumes all generated output.
     */
    private static List<Page> buildGroupPages(GroupByHash groupByHash) {
        ImmutableList.Builder<Page> pages = ImmutableList.builder();
        PageBuilder pageBuilder = new PageBuilder(groupByHash.getTypes());
        for (int groupId = 0; groupId < groupByHash.getGroupCount(); groupId++) {
            pageBuilder.declarePosition();
            groupByHash.appendValuesTo(groupId, pageBuilder, 0);
            if (pageBuilder.isFull()) {
                pages.add(pageBuilder.build());
                pageBuilder.reset();
            }
        }
        // The loop may have just flushed; only emit a trailing page that actually holds rows.
        if (!pageBuilder.isEmpty()) {
            pages.add(pageBuilder.build());
        }
        return pages.build();
    }

    /**
     * Generates {@code positionCount} rows where every one of the {@code channelCount} BIGINT
     * columns holds the same random key in {@code [0, groupCount)}; when {@code hashEnabled},
     * an extra BIGINT column with {@link AbstractLongType#hash} of the key is appended.
     */
    private static List<Page> createBigintPages(int positionCount, int groupCount, int channelCount, boolean hashEnabled) {
        List<Type> types = Collections.nCopies(channelCount, BigintType.BIGINT);
        if (hashEnabled) {
            types = ImmutableList.<Type>builder().addAll(types).add(BigintType.BIGINT).build();
        }

        ImmutableList.Builder<Page> pages = ImmutableList.builder();
        PageBuilder pageBuilder = new PageBuilder(types);
        for (int position = 0; position < positionCount; position++) {
            int value = ThreadLocalRandom.current().nextInt(groupCount);
            pageBuilder.declarePosition();
            for (int channel = 0; channel < channelCount; channel++) {
                BigintType.BIGINT.writeLong(pageBuilder.getBlockBuilder(channel), value);
            }
            if (hashEnabled) {
                BigintType.BIGINT.writeLong(pageBuilder.getBlockBuilder(channelCount), AbstractLongType.hash(value));
            }
            if (pageBuilder.isFull()) {
                pages.add(pageBuilder.build());
                pageBuilder.reset();
            }
        }
        pages.add(pageBuilder.build());
        return pages.build();
    }

    /**
     * Generates {@code positionCount} rows where every one of the {@code channelCount} VARCHAR
     * columns holds the same 4-byte big-endian encoding of a random key in {@code [0, groupCount)};
     * when {@code hashEnabled}, an extra BIGINT column with {@link XxHash64} of the value is appended.
     */
    private static List<Page> createVarcharPages(int positionCount, int groupCount, int channelCount, boolean hashEnabled) {
        List<Type> types = Collections.nCopies(channelCount, VarcharType.VARCHAR);
        if (hashEnabled) {
            types = ImmutableList.<Type>builder().addAll(types).add(BigintType.BIGINT).build();
        }

        ImmutableList.Builder<Page> pages = ImmutableList.builder();
        PageBuilder pageBuilder = new PageBuilder(types);
        for (int position = 0; position < positionCount; position++) {
            int key = ThreadLocalRandom.current().nextInt(groupCount);
            // Wrap the backing array, not the ByteBuffer: after putInt() the buffer's position is at
            // its end, so wrapping the buffer itself may expose zero remaining bytes.
            Slice value = Slices.wrappedBuffer(ByteBuffer.allocate(Integer.BYTES).putInt(key).array());
            pageBuilder.declarePosition();
            for (int channel = 0; channel < channelCount; channel++) {
                VarcharType.VARCHAR.writeSlice(pageBuilder.getBlockBuilder(channel), value);
            }
            if (hashEnabled) {
                BigintType.BIGINT.writeLong(pageBuilder.getBlockBuilder(channelCount), XxHash64.hash(value));
            }
            if (pageBuilder.isFull()) {
                pages.add(pageBuilder.build());
                pageBuilder.reset();
            }
        }
        pages.add(pageBuilder.build());
        return pages.build();
    }

    private static JoinCompiler getJoinCompiler() {
        return new JoinCompiler(TYPE_OPERATORS);
    }

    /** Smoke-runs a few benchmarks directly, then launches the full JMH suite with the GC profiler. */
    public static void main(String[] args) throws RunnerException {
        BenchmarkData data = new BenchmarkData();
        data.setup();
        new BenchmarkGroupByHash().groupByHashPreCompute(data);
        new BenchmarkGroupByHash().addPagePreCompute(data);

        SingleChannelBenchmarkData singleChannelData = new SingleChannelBenchmarkData();
        singleChannelData.setup();
        new BenchmarkGroupByHash().bigintGroupByHash(singleChannelData);

        Benchmarks.benchmark(BenchmarkGroupByHash.class)
                .withOptions(optionsBuilder -> optionsBuilder
                        .addProfiler(GCProfiler.class)
                        .jvmArgs("-Xmx10g"))
                .run();
    }
}
