/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.util.collection.unsafe.sort;

import com.google.common.primitives.Ints;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;

/**
 * LSD (least-significant-byte-first) radix sort over 64-bit values stored in a {@link LongArray}.
 *
 * <p>Each sort performs one counting-sort pass per byte position in
 * {@code [startByteIndex, endByteIndex]}. Records ping-pong between an input region and an
 * equally sized output region of the same array, so the caller must provide that much scratch
 * space. A pre-pass detects byte positions whose value is identical across every record, and
 * those positions are skipped entirely.
 */
public class RadixSort {

    /**
     * Sorts the first {@code numRecords} longs of {@code array} by the bytes in
     * {@code [startByteIndex, endByteIndex]} (byte 0 is the least significant byte).
     *
     * <p>Indices {@code [0, 2 * numRecords)} of {@code array} are overwritten. The sorted values
     * end up in either the lower or the upper half of that range; the returned index says which.
     * Each pass places records of the same bucket in their input order, so the sort is stable.
     *
     * @param array          the values to sort, followed by {@code numRecords} longs of scratch
     * @param numRecords     number of longs to sort, starting at index 0
     * @param startByteIndex first (least significant) byte position to sort on; must be >= 0
     * @param endByteIndex   last (most significant) byte position to sort on; must be <= 7
     * @param desc           whether to produce descending rather than ascending order
     * @param signed         whether the byte at {@code endByteIndex} is treated as signed, i.e.
     *                       buckets 128-255 are ordered before buckets 0-127 (for ascending)
     * @return the index in {@code array} at which the sorted values begin
     *         (either 0 or {@code numRecords})
     */
    public static int sort(
            LongArray array, long numRecords, int startByteIndex, int endByteIndex,
            boolean desc, boolean signed) {
        assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
        assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
        assert endByteIndex > startByteIndex;
        assert numRecords * 2L <= array.size();
        long inIndex = 0L;
        long outIndex = numRecords;
        if (numRecords > 0L) {
            long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
            for (int i = startByteIndex; i <= endByteIndex; ++i) {
                if (counts[i] == null) {
                    // Every record has the same value at this byte, so the pass would be a no-op.
                    continue;
                }
                sortAtByte(array, numRecords, counts[i], i, inIndex, outIndex,
                        desc, signed && i == endByteIndex);
                // The output region of this pass becomes the input region of the next one.
                long tmp = inIndex;
                inIndex = outIndex;
                outIndex = tmp;
            }
        }
        return Ints.checkedCast(inIndex);
    }

    /**
     * Performs one stable counting-sort pass on byte {@code byteIdx}. It copies the
     * {@code numRecords} longs starting at index {@code inIndex} into the region starting at
     * index {@code outIndex}, ordered by that byte.
     *
     * @param counts per-bucket histogram for this byte (length 256); overwritten in place with
     *               the output offsets
     */
    private static void sortAtByte(
            LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex,
            long outIndex, boolean desc, boolean signed) {
        assert counts.length == 256;
        long[] offsets = transformCountsToOffsets(
                counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8L, desc, signed);
        Object baseObject = array.getBaseObject();
        long baseOffset = array.getBaseOffset() + inIndex * 8L;
        long maxOffset = baseOffset + numRecords * 8L;
        for (long offset = baseOffset; offset < maxOffset; offset += 8L) {
            long value = Platform.getLong(baseObject, offset);
            int bucket = (int) ((value >>> (byteIdx * 8)) & 0xFFL);
            Platform.putLong(baseObject, offsets[bucket], value);
            // The next record with the same byte value goes right after this one.
            offsets[bucket] += 8L;
        }
    }

    /**
     * Builds a 256-entry byte histogram for each byte position in
     * {@code [startByteIndex, endByteIndex]} over the first {@code numRecords} longs.
     *
     * @return an 8-element array indexed by byte position. An entry is {@code null} when every
     *         record has the same value at that byte (or when the position is outside the
     *         requested range).
     */
    private static long[][] getCounts(
            LongArray array, long numRecords, int startByteIndex, int endByteIndex) {
        long[][] counts = new long[8][];
        // Fast pre-pass: OR/AND all values together. Bits set in OR but clear in AND are the
        // only bits that differ between records; byte positions with none of them are skipped.
        long bitwiseMax = 0L;
        long bitwiseMin = -1L;
        long maxOffset = array.getBaseOffset() + numRecords * 8L;
        Object baseObject = array.getBaseObject();
        for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8L) {
            long value = Platform.getLong(baseObject, offset);
            bitwiseMax |= value;
            bitwiseMin &= value;
        }
        long bitsChanged = bitwiseMin ^ bitwiseMax;
        for (int i = startByteIndex; i <= endByteIndex; ++i) {
            if (((bitsChanged >>> (i * 8)) & 0xFFL) == 0L) {
                continue;
            }
            long[] histogram = new long[256];
            counts[i] = histogram;
            for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8L) {
                int bucket = (int) ((Platform.getLong(baseObject, offset) >>> (i * 8)) & 0xFFL);
                histogram[bucket]++;
            }
        }
        return counts;
    }

    /**
     * Converts a per-bucket histogram into absolute starting addresses for each bucket, in place.
     *
     * <p>Buckets are laid out in increasing order starting from bucket 0, or from bucket 128 when
     * {@code signed} is set so that bytes with the high bit set come first. With {@code desc},
     * that same bucket sequence is laid out from the end of the output region backwards.
     *
     * @param counts         histogram of length 256; overwritten with the bucket offsets
     * @param numRecords     total number of records (sum of {@code counts})
     * @param outputOffset   address of the first output slot
     * @param bytesPerRecord size of one record in bytes
     * @return {@code counts}, now holding offsets
     */
    private static long[] transformCountsToOffsets(
            long[] counts, long numRecords, long outputOffset, long bytesPerRecord,
            boolean desc, boolean signed) {
        assert counts.length == 256;
        int start = signed ? 128 : 0;
        if (desc) {
            long pos = numRecords;
            for (int i = start; i < start + 256; ++i) {
                int bucket = i & 0xFF;
                pos -= counts[bucket];
                counts[bucket] = outputOffset + pos * bytesPerRecord;
            }
        } else {
            long pos = 0L;
            for (int i = start; i < start + 256; ++i) {
                int bucket = i & 0xFF;
                long count = counts[bucket];
                counts[bucket] = outputOffset + pos * bytesPerRecord;
                pos += count;
            }
        }
        return counts;
    }

    /**
     * Sorts an array of (key, prefix) long pairs by the bytes of the prefix in
     * {@code [startByteIndex, endByteIndex]}. The key is the first long of each pair and the
     * prefix is the second; keys move along with their prefixes.
     *
     * <p>The input occupies indices {@code [startIndex, startIndex + 2 * numRecords)}. The next
     * {@code 2 * numRecords} longs are used as scratch space, so the array must hold at least
     * {@code startIndex + 4 * numRecords} longs.
     *
     * @param array          the (key, prefix) pairs followed by scratch space
     * @param startIndex     index of the first key; must be >= 0
     * @param numRecords     number of (key, prefix) pairs
     * @param startByteIndex first (least significant) prefix byte to sort on; must be >= 0
     * @param endByteIndex   last (most significant) prefix byte to sort on; must be <= 7
     * @param desc           whether to produce descending rather than ascending order
     * @param signed         whether the prefix byte at {@code endByteIndex} is treated as signed
     * @return the index in {@code array} at which the sorted pairs begin (either
     *         {@code startIndex} or {@code startIndex + 2 * numRecords})
     */
    public static int sortKeyPrefixArray(
            LongArray array, long startIndex, long numRecords, int startByteIndex,
            int endByteIndex, boolean desc, boolean signed) {
        assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
        assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
        assert endByteIndex > startByteIndex;
        assert startIndex >= 0L : "startIndex (" + startIndex + ") should >= 0";
        // Both the input region and the scratch region start at startIndex, so the bound must
        // include it. Otherwise the scratch writes below can run past the end of the array.
        assert startIndex + numRecords * 4L <= array.size()
                : "startIndex (" + startIndex + ") + numRecords (" + numRecords
                        + ") * 4 should <= array size (" + array.size() + ")";
        long inIndex = startIndex;
        long outIndex = startIndex + numRecords * 2L;
        if (numRecords > 0L) {
            long[][] counts = getKeyPrefixArrayCounts(
                    array, startIndex, numRecords, startByteIndex, endByteIndex);
            for (int i = startByteIndex; i <= endByteIndex; ++i) {
                if (counts[i] == null) {
                    // Every prefix has the same value at this byte, so the pass would be a no-op.
                    continue;
                }
                sortKeyPrefixArrayAtByte(array, numRecords, counts[i], i, inIndex, outIndex,
                        desc, signed && i == endByteIndex);
                // The output region of this pass becomes the input region of the next one.
                long tmp = inIndex;
                inIndex = outIndex;
                outIndex = tmp;
            }
        }
        return Ints.checkedCast(inIndex);
    }

    /**
     * Like {@link #getCounts}, but only reads the prefix (second long) of each 16-byte
     * (key, prefix) pair starting at index {@code startIndex}.
     */
    private static long[][] getKeyPrefixArrayCounts(
            LongArray array, long startIndex, long numRecords, int startByteIndex,
            int endByteIndex) {
        long[][] counts = new long[8][];
        // Fast pre-pass to find which prefix bytes actually differ between records.
        long bitwiseMax = 0L;
        long bitwiseMin = -1L;
        long baseOffset = array.getBaseOffset() + startIndex * 8L;
        long limit = baseOffset + numRecords * 16L;
        Object baseObject = array.getBaseObject();
        for (long offset = baseOffset; offset < limit; offset += 16L) {
            long prefix = Platform.getLong(baseObject, offset + 8L);
            bitwiseMax |= prefix;
            bitwiseMin &= prefix;
        }
        long bitsChanged = bitwiseMin ^ bitwiseMax;
        for (int i = startByteIndex; i <= endByteIndex; ++i) {
            if (((bitsChanged >>> (i * 8)) & 0xFFL) == 0L) {
                continue;
            }
            long[] histogram = new long[256];
            counts[i] = histogram;
            for (long offset = baseOffset; offset < limit; offset += 16L) {
                long prefix = Platform.getLong(baseObject, offset + 8L);
                histogram[(int) ((prefix >>> (i * 8)) & 0xFFL)]++;
            }
        }
        return counts;
    }

    /**
     * Performs one stable counting-sort pass on byte {@code byteIdx} of the prefixes. It copies
     * the {@code numRecords} (key, prefix) pairs starting at index {@code inIndex} into the region
     * starting at index {@code outIndex}.
     *
     * @param counts per-bucket histogram for this byte (length 256); overwritten in place with
     *               the output offsets
     */
    private static void sortKeyPrefixArrayAtByte(
            LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex,
            long outIndex, boolean desc, boolean signed) {
        assert counts.length == 256;
        long[] offsets = transformCountsToOffsets(
                counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16L, desc, signed);
        Object baseObject = array.getBaseObject();
        long baseOffset = array.getBaseOffset() + inIndex * 8L;
        long maxOffset = baseOffset + numRecords * 16L;
        for (long offset = baseOffset; offset < maxOffset; offset += 16L) {
            long key = Platform.getLong(baseObject, offset);
            long prefix = Platform.getLong(baseObject, offset + 8L);
            int bucket = (int) ((prefix >>> (byteIdx * 8)) & 0xFFL);
            long dest = offsets[bucket];
            Platform.putLong(baseObject, dest, key);
            Platform.putLong(baseObject, dest + 8L, prefix);
            // The next pair with the same byte value goes right after this one.
            offsets[bucket] += 16L;
        }
    }
}

