package io.trino.operator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import io.trino.array.LongBigArray;
import io.trino.util.HeapTraversal;
import io.trino.util.LongBigArrayFIFOQueue;
import java.util.Objects;
import java.util.function.LongConsumer;
import javax.annotation.Nullable;
import org.openjdk.jol.info.ClassLayout;

/* loaded from: input_file:io/trino/operator/GroupedTopNRowNumberAccumulator.class */
/**
 * Retains, per group id, at most {@code topN} row ids ordered by a {@link RowIdComparisonStrategy}.
 *
 * <p>Each group owns a binary max heap (ordered by {@code strategy}). The heap nodes live in a
 * shared {@link HeapNodeBuffer} and are linked through explicit child indexes. The root of a
 * group's heap holds the row id that compares greatest among the retained rows. Once the group is
 * full, that root is the row replaced, and reported to {@code rowIdEvictionListener}, when a
 * smaller row arrives.
 */
public class GroupedTopNRowNumberAccumulator {
    private static final long INSTANCE_SIZE = ClassLayout.parseClass(GroupedTopNRowNumberAccumulator.class).instanceSize();
    /** Sentinel meaning "no node": an empty heap root or an absent child. */
    private static final long UNKNOWN_INDEX = -1;

    private final GroupIdToHeapBuffer groupIdToHeapBuffer = new GroupIdToHeapBuffer();
    private final HeapNodeBuffer heapNodeBuffer = new HeapNodeBuffer();
    private final HeapTraversal heapTraversal = new HeapTraversal();
    private final RowIdComparisonStrategy strategy;
    private final int topN;
    private final LongConsumer rowIdEvictionListener;

    /**
     * Per-group bookkeeping: the heap root node index and the heap size for each group id.
     */
    public static class GroupIdToHeapBuffer {
        private static final long INSTANCE_SIZE = ClassLayout.parseClass(GroupIdToHeapBuffer.class).instanceSize();
        // Groups that have never been touched report UNKNOWN_INDEX as their root and 0 as their size.
        private final LongBigArray heapIndexBuffer = new LongBigArray(UNKNOWN_INDEX);
        private final LongBigArray sizeBuffer = new LongBigArray(0);
        private long totalGroups;

        private GroupIdToHeapBuffer() {
        }

        /** Grows both per-group arrays so that {@code groupId} is addressable. */
        public void allocateGroupIfNeeded(long groupId) {
            totalGroups = Math.max(groupId + 1, totalGroups);
            heapIndexBuffer.ensureCapacity(totalGroups);
            sizeBuffer.ensureCapacity(totalGroups);
        }

        /** Returns one past the largest group id seen so far. */
        public long getTotalGroups() {
            return totalGroups;
        }

        public long getHeapRootNodeIndex(long groupId) {
            return heapIndexBuffer.get(groupId);
        }

        public void setHeapRootNodeIndex(long groupId, long heapNodeIndex) {
            heapIndexBuffer.set(groupId, heapNodeIndex);
        }

        public long getHeapSize(long groupId) {
            return sizeBuffer.get(groupId);
        }

        public void setHeapSize(long groupId, long count) {
            sizeBuffer.set(groupId, count);
        }

        public void addHeapSize(long groupId, long delta) {
            sizeBuffer.add(groupId, delta);
        }

        public void incrementHeapSize(long groupId) {
            sizeBuffer.increment(groupId);
        }

        /** Approximate retained size in bytes. */
        public long sizeOf() {
            return INSTANCE_SIZE + heapIndexBuffer.sizeOf() + sizeBuffer.sizeOf();
        }
    }

    /**
     * Flat storage for heap nodes shared by all groups.
     *
     * <p>Each node occupies {@code POSITIONS_PER_ENTRY} consecutive longs: the row id, then the left
     * child node index, then the right child node index. Freed node slots are recycled through a
     * FIFO queue before the buffer is grown.
     */
    public static class HeapNodeBuffer {
        private static final long INSTANCE_SIZE = ClassLayout.parseClass(HeapNodeBuffer.class).instanceSize();
        private static final int POSITIONS_PER_ENTRY = 3;
        private static final int LEFT_CHILD_HEAP_INDEX_OFFSET = 1;
        private static final int RIGHT_CHILD_HEAP_INDEX_OFFSET = 2;
        private final LongBigArray buffer = new LongBigArray();
        private final LongBigArrayFIFOQueue emptySlots = new LongBigArrayFIFOQueue();
        // Number of node slots ever allocated (active + recycled).
        private long capacity;

        private HeapNodeBuffer() {
        }

        /**
         * Allocates a node holding {@code rowId} with no children.
         *
         * @return the index of the new node
         */
        public long allocateNewNode(long rowId) {
            long nodeIndex;
            if (emptySlots.isEmpty()) {
                nodeIndex = capacity;
                capacity++;
                buffer.ensureCapacity(capacity * POSITIONS_PER_ENTRY);
            }
            else {
                nodeIndex = emptySlots.dequeueLong();
            }
            setRowId(nodeIndex, rowId);
            setLeftChildHeapIndex(nodeIndex, UNKNOWN_INDEX);
            setRightChildHeapIndex(nodeIndex, UNKNOWN_INDEX);
            return nodeIndex;
        }

        /** Returns the node slot to the free list; the caller must have unlinked it from its heap. */
        public void deallocate(long nodeIndex) {
            emptySlots.enqueue(nodeIndex);
        }

        public long getActiveNodeCount() {
            return capacity - emptySlots.longSize();
        }

        public long getRowId(long nodeIndex) {
            return buffer.get(nodeIndex * POSITIONS_PER_ENTRY);
        }

        public void setRowId(long nodeIndex, long rowId) {
            buffer.set(nodeIndex * POSITIONS_PER_ENTRY, rowId);
        }

        public long getLeftChildHeapIndex(long nodeIndex) {
            return buffer.get(nodeIndex * POSITIONS_PER_ENTRY + LEFT_CHILD_HEAP_INDEX_OFFSET);
        }

        public void setLeftChildHeapIndex(long nodeIndex, long childNodeIndex) {
            buffer.set(nodeIndex * POSITIONS_PER_ENTRY + LEFT_CHILD_HEAP_INDEX_OFFSET, childNodeIndex);
        }

        public long getRightChildHeapIndex(long nodeIndex) {
            return buffer.get(nodeIndex * POSITIONS_PER_ENTRY + RIGHT_CHILD_HEAP_INDEX_OFFSET);
        }

        public void setRightChildHeapIndex(long nodeIndex, long childNodeIndex) {
            buffer.set(nodeIndex * POSITIONS_PER_ENTRY + RIGHT_CHILD_HEAP_INDEX_OFFSET, childNodeIndex);
        }

        /** Approximate retained size in bytes. */
        public long sizeOf() {
            return INSTANCE_SIZE + buffer.sizeOf() + emptySlots.sizeOf();
        }
    }

    /** Result of a recursive heap check: subtree depth and node count. */
    public static class IntegrityStats {
        private final long maxDepth;
        private final long nodeCount;

        public IntegrityStats(long maxDepth, long nodeCount) {
            this.maxDepth = maxDepth;
            this.nodeCount = nodeCount;
        }

        public long getMaxDepth() {
            return maxDepth;
        }

        public long getNodeCount() {
            return nodeCount;
        }
    }

    /**
     * @param strategy orders row ids; the heap keeps the greatest retained row at its root
     * @param topN maximum number of rows retained per group; must be positive
     * @param rowIdEvictionListener receives each row id pushed out of a full group by {@link #add}
     */
    public GroupedTopNRowNumberAccumulator(RowIdComparisonStrategy strategy, int topN, LongConsumer rowIdEvictionListener) {
        this.strategy = Objects.requireNonNull(strategy, "strategy is null");
        Preconditions.checkArgument(topN > 0, "topN must be greater than zero");
        this.topN = topN;
        this.rowIdEvictionListener = Objects.requireNonNull(rowIdEvictionListener, "rowIdEvictionListener is null");
    }

    /** Approximate retained size in bytes of this accumulator and its buffers. */
    public long sizeOf() {
        return INSTANCE_SIZE + groupIdToHeapBuffer.sizeOf() + heapNodeBuffer.sizeOf() + heapTraversal.sizeOf();
    }

    /**
     * Offers a row to a group.
     *
     * <p>While the group holds fewer than {@code topN} rows, the row is always retained. Otherwise
     * the row is retained only if it compares strictly less than the current root. In that case the
     * root row id is evicted and passed to {@code rowIdEvictionListener}.
     *
     * @return {@code true} if the row was retained (a row id was allocated for it)
     */
    public boolean add(long groupId, RowReference rowReference) {
        groupIdToHeapBuffer.allocateGroupIfNeeded(groupId);
        long rootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
        if (rootNodeIndex == UNKNOWN_INDEX || calculateRootRowNumber(groupId) < topN) {
            heapInsert(groupId, rowReference.allocateRowId());
            return true;
        }
        // Group is full: reject anything not strictly smaller than the current maximum.
        if (rowReference.compareTo(strategy, heapNodeBuffer.getRowId(rootNodeIndex)) >= 0) {
            return false;
        }
        heapPopAndInsert(groupId, rowReference.allocateRowId(), rowIdEvictionListener);
        return true;
    }

    /**
     * Empties a group's heap into {@code rowIdOutput}.
     *
     * <p>Roots are popped greatest-first and written from the last slot backwards. The output
     * therefore ends up in ascending {@code strategy} order. No eviction notifications are emitted.
     *
     * @return the number of row ids written
     */
    public long drainTo(long groupId, LongBigArray rowIdOutput) {
        long heapSize = groupIdToHeapBuffer.getHeapSize(groupId);
        rowIdOutput.ensureCapacity(heapSize);
        for (long outputIndex = heapSize - 1; outputIndex >= 0; outputIndex--) {
            rowIdOutput.set(outputIndex, peekRootRowId(groupId));
            heapPop(groupId, null);
        }
        return heapSize;
    }

    /** Returns the number of rows currently retained for {@code groupId} (its heap size). */
    private long calculateRootRowNumber(long groupId) {
        return groupIdToHeapBuffer.getHeapSize(groupId);
    }

    private long peekRootRowId(long groupId) {
        long rootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
        Preconditions.checkArgument(rootNodeIndex != UNKNOWN_INDEX, "No root to peek");
        return heapNodeBuffer.getRowId(rootNodeIndex);
    }

    private long getChildIndex(long nodeIndex, HeapTraversal.Child child) {
        return child == HeapTraversal.Child.LEFT
                ? heapNodeBuffer.getLeftChildHeapIndex(nodeIndex)
                : heapNodeBuffer.getRightChildHeapIndex(nodeIndex);
    }

    private void setChildIndex(long nodeIndex, HeapTraversal.Child child, long childNodeIndex) {
        if (child == HeapTraversal.Child.LEFT) {
            heapNodeBuffer.setLeftChildHeapIndex(nodeIndex, childNodeIndex);
        }
        else {
            heapNodeBuffer.setRightChildHeapIndex(nodeIndex, childNodeIndex);
        }
    }

    /**
     * Removes the root of a group's heap.
     *
     * <p>The last-inserted leaf is detached and freed. Its row id then replaces the root and is sifted
     * down. If a listener is given, it receives the removed root row id.
     */
    private void heapPop(long groupId, @Nullable LongConsumer evictionListener) {
        long rootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
        Preconditions.checkArgument(rootNodeIndex != UNKNOWN_INDEX, "Group ID has an empty heap");
        long leafNodeIndex = heapDetachLastInsertionLeaf(groupId);
        long leafRowId = heapNodeBuffer.getRowId(leafNodeIndex);
        heapNodeBuffer.deallocate(leafNodeIndex);
        if (leafNodeIndex != rootNodeIndex) {
            heapPopAndInsert(groupId, leafRowId, evictionListener);
        }
        else if (evictionListener != null) {
            // The root was the only node; it is the popped row.
            evictionListener.accept(leafRowId);
        }
    }

    /**
     * Unlinks the node at the last insertion position of a group's heap and decrements the heap size.
     *
     * <p>The position is located with {@link HeapTraversal} using the current heap size as the
     * target. This presumably addresses nodes by 1-based level-order position; TODO confirm against
     * {@code HeapTraversal}.
     *
     * @return the index of the detached node; it is not deallocated here
     */
    private long heapDetachLastInsertionLeaf(long groupId) {
        long parentNodeIndex = UNKNOWN_INDEX;
        HeapTraversal.Child childPosition = null;
        long currentNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
        heapTraversal.resetWithPathTo(groupIdToHeapBuffer.getHeapSize(groupId));
        while (!heapTraversal.isTarget()) {
            parentNodeIndex = currentNodeIndex;
            childPosition = heapTraversal.nextChild();
            currentNodeIndex = getChildIndex(currentNodeIndex, childPosition);
            Verify.verify(currentNodeIndex != UNKNOWN_INDEX, "Target node must exist");
        }
        if (parentNodeIndex == UNKNOWN_INDEX) {
            // The last leaf is the root itself, so the heap becomes empty.
            groupIdToHeapBuffer.setHeapRootNodeIndex(groupId, UNKNOWN_INDEX);
            groupIdToHeapBuffer.setHeapSize(groupId, 0);
        }
        else {
            setChildIndex(parentNodeIndex, childPosition, UNKNOWN_INDEX);
            groupIdToHeapBuffer.addHeapSize(groupId, -1);
        }
        return currentNodeIndex;
    }

    /**
     * Inserts a row id into a group's heap.
     *
     * <p>The method walks from the root toward the next free position (target {@code heapSize + 1}).
     * At each node it keeps the greater of the node's row and the carried row, and carries the
     * smaller one down. The carried row then becomes a new leaf. This preserves the max-heap
     * ordering in a single top-down pass.
     */
    private void heapInsert(long groupId, long newRowId) {
        long rootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
        if (rootNodeIndex == UNKNOWN_INDEX) {
            groupIdToHeapBuffer.setHeapRootNodeIndex(groupId, heapNodeBuffer.allocateNewNode(newRowId));
            groupIdToHeapBuffer.setHeapSize(groupId, 1);
            return;
        }
        long carriedRowId = newRowId;
        long parentNodeIndex = UNKNOWN_INDEX;
        HeapTraversal.Child childPosition = null;
        long currentNodeIndex = rootNodeIndex;
        heapTraversal.resetWithPathTo(groupIdToHeapBuffer.getHeapSize(groupId) + 1);
        while (!heapTraversal.isTarget()) {
            long currentRowId = heapNodeBuffer.getRowId(currentNodeIndex);
            if (strategy.compare(carriedRowId, currentRowId) > 0) {
                // Carried row is larger: it takes this node and the displaced row continues down.
                heapNodeBuffer.setRowId(currentNodeIndex, carriedRowId);
                carriedRowId = currentRowId;
            }
            parentNodeIndex = currentNodeIndex;
            childPosition = heapTraversal.nextChild();
            currentNodeIndex = getChildIndex(currentNodeIndex, childPosition);
        }
        Verify.verify(parentNodeIndex != UNKNOWN_INDEX && childPosition != null, "heap must have at least one node before starting traversal");
        Verify.verify(currentNodeIndex == UNKNOWN_INDEX, "New child shouldn't exist yet");
        setChildIndex(parentNodeIndex, childPosition, heapNodeBuffer.allocateNewNode(carriedRowId));
        groupIdToHeapBuffer.incrementHeapSize(groupId);
    }

    /**
     * Replaces the root row id with {@code newRowId} and sifts it down.
     *
     * <p>At each level the new row is compared against the larger child. If a listener is given, it
     * receives the row id that was at the root before the call.
     *
     * @throws IllegalStateException if the group's heap is empty
     */
    private void heapPopAndInsert(long groupId, long newRowId, @Nullable LongConsumer evictionListener) {
        long rootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
        Preconditions.checkState(rootNodeIndex != UNKNOWN_INDEX, "popAndInsert() requires at least a root node");
        long poppedRowId = heapNodeBuffer.getRowId(rootNodeIndex);
        long currentNodeIndex = rootNodeIndex;
        while (true) {
            long largerChildNodeIndex = heapNodeBuffer.getLeftChildHeapIndex(currentNodeIndex);
            if (largerChildNodeIndex == UNKNOWN_INDEX) {
                // Left is always filled before right, so no left child means a leaf.
                break;
            }
            long largerChildRowId = heapNodeBuffer.getRowId(largerChildNodeIndex);
            long rightChildNodeIndex = heapNodeBuffer.getRightChildHeapIndex(currentNodeIndex);
            if (rightChildNodeIndex != UNKNOWN_INDEX) {
                long rightChildRowId = heapNodeBuffer.getRowId(rightChildNodeIndex);
                if (strategy.compare(rightChildRowId, largerChildRowId) > 0) {
                    largerChildNodeIndex = rightChildNodeIndex;
                    largerChildRowId = rightChildRowId;
                }
            }
            if (strategy.compare(newRowId, largerChildRowId) >= 0) {
                break;
            }
            // Promote the larger child and continue sifting into its subtree.
            heapNodeBuffer.setRowId(currentNodeIndex, largerChildRowId);
            currentNodeIndex = largerChildNodeIndex;
        }
        heapNodeBuffer.setRowId(currentNodeIndex, newRowId);
        if (evictionListener != null) {
            evictionListener.accept(poppedRowId);
        }
    }

    /**
     * Checks every group's heap for structural consistency and node-accounting leaks.
     *
     * @throws com.google.common.base.VerifyException if any invariant is violated
     */
    @VisibleForTesting
    void verifyIntegrity() {
        long totalHeapNodes = 0;
        for (long groupId = 0; groupId < groupIdToHeapBuffer.getTotalGroups(); groupId++) {
            long heapSize = groupIdToHeapBuffer.getHeapSize(groupId);
            long rootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
            // calculateRootRowNumber is keyed by group id, not by heap node index.
            Verify.verify(rootNodeIndex == UNKNOWN_INDEX || calculateRootRowNumber(groupId) <= topN, "Max heap has more values than needed");
            IntegrityStats integrityStats = verifyHeapIntegrity(rootNodeIndex);
            Verify.verify(integrityStats.getNodeCount() == heapSize, "Recorded heap size does not match actual heap size");
            totalHeapNodes += integrityStats.getNodeCount();
        }
        Verify.verify(totalHeapNodes == heapNodeBuffer.getActiveNodeCount(), "Failed to deallocate some unused nodes");
    }

    /**
     * Recursively verifies the subtree rooted at {@code nodeIndex}.
     *
     * <p>The checks are: max-heap ordering, left-before-right filling, and a child depth difference of
     * at most one.
     *
     * @return the subtree's depth and node count ({@code 0, 0} for {@code UNKNOWN_INDEX})
     */
    private IntegrityStats verifyHeapIntegrity(long nodeIndex) {
        if (nodeIndex == UNKNOWN_INDEX) {
            return new IntegrityStats(0, 0);
        }
        long rowId = heapNodeBuffer.getRowId(nodeIndex);
        long leftChildNodeIndex = heapNodeBuffer.getLeftChildHeapIndex(nodeIndex);
        long rightChildNodeIndex = heapNodeBuffer.getRightChildHeapIndex(nodeIndex);
        if (leftChildNodeIndex != UNKNOWN_INDEX) {
            Verify.verify(strategy.compare(rowId, heapNodeBuffer.getRowId(leftChildNodeIndex)) >= 0, "Max heap invariant violated");
        }
        if (rightChildNodeIndex != UNKNOWN_INDEX) {
            Verify.verify(leftChildNodeIndex != UNKNOWN_INDEX, "Left should always be inserted before right");
            Verify.verify(strategy.compare(rowId, heapNodeBuffer.getRowId(rightChildNodeIndex)) >= 0, "Max heap invariant violated");
        }
        IntegrityStats leftStats = verifyHeapIntegrity(leftChildNodeIndex);
        IntegrityStats rightStats = verifyHeapIntegrity(rightChildNodeIndex);
        Verify.verify(Math.abs(leftStats.getMaxDepth() - rightStats.getMaxDepth()) <= 1, "Heap not balanced");
        return new IntegrityStats(
                Math.max(leftStats.getMaxDepth(), rightStats.getMaxDepth()) + 1,
                leftStats.getNodeCount() + rightStats.getNodeCount() + 1);
    }
}
