package io.trino.operator.join;

import io.airlift.slice.SizeOf;
import io.airlift.units.DataSize;
import io.trino.operator.PagesHashStrategy;
import io.trino.operator.SyntheticAddress;
import io.trino.operator.join.PositionLinks;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.util.HashCollisionsEstimator;
import it.unimi.dsi.fastutil.HashCommon;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import java.util.Arrays;
import java.util.Objects;
import org.openjdk.jol.info.ClassLayout;

/* loaded from: input_file:io/trino/operator/join/PagesHash.class */
/**
 * Open-addressing hash index over build-side join positions.
 *
 * <p>Each entry of {@code addresses} is a synthetic address that is decoded with {@link SyntheticAddress}
 * into a (slice index, position) pair and handed to the {@link PagesHashStrategy}.
 *
 * <p>The {@code key} table maps a hash slot to a position index into {@code addresses}, or {@code -1} for an
 * empty slot. Collisions are resolved by linear probing.
 *
 * <p>{@code positionToHashes} keeps the low byte of every position's hash. Most mismatches are rejected on
 * that byte before the strategy's full equality check is invoked.
 */
public final class PagesHash {
    private static final int INSTANCE_SIZE = ClassLayout.parseClass(PagesHash.class).instanceSize();
    private static final DataSize CACHE_SIZE = DataSize.of(128, DataSize.Unit.KILOBYTE);

    private final LongArrayList addresses;
    private final PagesHashStrategy pagesHashStrategy;
    private final int channelCount;
    // Table size minus one. The table size comes from HashCommon.arraySize (a power of two), so
    // "& mask" wraps a slot index around the table.
    private final int mask;
    // Hash slot -> index into addresses; -1 marks an empty slot.
    private final int[] key;
    // Retained bytes of the arrays/strategy this index owns (excludes INSTANCE_SIZE).
    private final long size;
    // Low byte of each position's hash, indexed by position.
    private final byte[] positionToHashes;
    private final long hashCollisions;
    private final double expectedHashCollisions;

    /**
     * Builds the hash index over every position in {@code addresses}.
     *
     * <p>Positions whose key is null (per {@link PagesHashStrategy#isPositionNull}) are not inserted.
     *
     * @param addresses synthetic addresses of the build-side rows; the hash table stores indexes into this list
     * @param pagesHashStrategy hashes and compares the rows referenced by {@code addresses}
     * @param positionLinks called as {@code link(newPosition, existingPosition)} whenever a position's key equals
     *        one already in the table; the value it returns replaces the table entry for that key
     * @throws NullPointerException if {@code addresses} or {@code pagesHashStrategy} is null
     */
    public PagesHash(LongArrayList addresses, PagesHashStrategy pagesHashStrategy, PositionLinks.FactoryBuilder positionLinks) {
        this.addresses = Objects.requireNonNull(addresses, "addresses is null");
        this.pagesHashStrategy = Objects.requireNonNull(pagesHashStrategy, "pagesHashStrategy is null");
        this.channelCount = pagesHashStrategy.getChannelCount();

        int positionCount = addresses.size();
        // Sized for a 0.75 load factor.
        int hashSize = HashCommon.arraySize(positionCount, 0.75f);
        this.mask = hashSize - 1;
        this.key = new int[hashSize];
        Arrays.fill(key, -1);
        this.positionToHashes = new byte[positionCount];

        // Positions are hashed and inserted in batches so that the scratch buffer of full 64-bit hashes stays
        // small. 128 KB / 32 = 4096 entries, i.e. 32 KB of longs.
        // The "+ 1" keeps the batch size at least 1 for an empty input. When every position fits, it also makes
        // the loop run exactly once.
        int batchSize = Math.min(positionCount + 1, ((int) CACHE_SIZE.toBytes()) / 32);
        long[] batchHashes = new long[batchSize];
        long collisions = 0;

        for (int batchStart = 0; batchStart <= positionCount; batchStart += batchSize) {
            int batchLength = Math.min(batchStart + batchSize, positionCount) - batchStart;
            loadBatchHashes(batchStart, batchLength, batchHashes);
            collisions += insertBatch(batchStart, batchLength, batchHashes, positionLinks);
        }

        this.size = SizeOf.sizeOf(addresses.elements()) + pagesHashStrategy.getSizeInBytes() + SizeOf.sizeOf(key) + SizeOf.sizeOf(positionToHashes);
        this.hashCollisions = collisions;
        this.expectedHashCollisions = HashCollisionsEstimator.estimateNumberOfHashCollisions(positionCount, hashSize);
    }

    /**
     * Hashes positions {@code [batchStart, batchStart + batchLength)}.
     *
     * <p>Full hashes go into {@code batchHashes[0..batchLength)}. The low byte of each hash is recorded in
     * {@code positionToHashes}. The hashes are extracted in a pass of their own, before any insertion.
     */
    private void loadBatchHashes(int batchStart, int batchLength, long[] batchHashes) {
        for (int offset = 0; offset < batchLength; offset++) {
            int position = batchStart + offset;
            long hash = readHashPosition(position);
            batchHashes[offset] = hash;
            positionToHashes[position] = (byte) hash;
        }
    }

    /**
     * Inserts the non-null positions of one batch into {@code key}, using the hashes from {@code batchHashes}.
     *
     * @return the number of extra probe steps (hash collisions) taken while inserting this batch
     */
    private long insertBatch(int batchStart, int batchLength, long[] batchHashes, PositionLinks.FactoryBuilder positionLinks) {
        long collisions = 0;
        for (int offset = 0; offset < batchLength; offset++) {
            int position = batchStart + offset;
            if (isPositionNull(position)) {
                continue;
            }

            long hash = batchHashes[offset];
            int slot = getHashPosition(hash, mask);
            // Probe until an empty slot is found, or a slot already holding an equal key.
            while (key[slot] != -1) {
                int existing = key[slot];
                if ((byte) hash == positionToHashes[existing] && positionEqualsPositionIgnoreNulls(existing, position)) {
                    // Equal key: chain the new position onto the existing one.
                    // The slot receives whatever link() returns.
                    position = positionLinks.link(position, existing);
                    break;
                }
                slot = (slot + 1) & mask;
                collisions++;
            }
            key[slot] = position;
        }
        return collisions;
    }

    /** Returns the channel count reported by the strategy at construction time. */
    public final int getChannelCount() {
        return channelCount;
    }

    /** Returns the number of build-side positions (including null-key positions that were not inserted). */
    public int getPositionCount() {
        return addresses.size();
    }

    /**
     * Returns the object's shallow size plus the address elements, the strategy's reported size, the hash table
     * and the hash-byte array.
     */
    public long getInMemorySizeInBytes() {
        return INSTANCE_SIZE + size;
    }

    /** Returns the number of extra linear-probe steps taken while building the table. */
    public long getHashCollisions() {
        return hashCollisions;
    }

    /** Returns the collision count {@link HashCollisionsEstimator} predicts for this position count and table size. */
    public double getExpectedHashCollisions() {
        return expectedHashCollisions;
    }

    /**
     * Looks up row {@code position} of {@code hashChannelsPage}, hashing it with the strategy.
     *
     * @return index into {@code addresses} of a matching table entry, or {@code -1} if none matches
     */
    public int getAddressIndex(int position, Page hashChannelsPage) {
        long rawHash = pagesHashStrategy.hashRow(position, hashChannelsPage);
        return getAddressIndex(position, hashChannelsPage, rawHash);
    }

    /**
     * Looks up row {@code position} of {@code hashChannelsPage} using a precomputed {@code rawHash}.
     *
     * @return index into {@code addresses} of a matching table entry, or {@code -1} once an empty slot is reached
     */
    public int getAddressIndex(int position, Page hashChannelsPage, long rawHash) {
        byte hashByte = (byte) rawHash;
        for (int slot = getHashPosition(rawHash, mask); key[slot] != -1; slot = (slot + 1) & mask) {
            int candidate = key[slot];
            if (positionEqualsCurrentRowIgnoreNulls(candidate, hashByte, position, hashChannelsPage)) {
                return candidate;
            }
        }
        return -1;
    }

    /**
     * Appends the build-side row at {@code addresses[position]} to {@code pageBuilder} via the strategy.
     *
     * @throws ArithmeticException if {@code position} does not fit in an {@code int}
     */
    public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset) {
        long pageAddress = addresses.getLong(Math.toIntExact(position));
        int blockIndex = SyntheticAddress.decodeSliceIndex(pageAddress);
        int blockPosition = SyntheticAddress.decodePosition(pageAddress);
        pagesHashStrategy.appendTo(blockIndex, blockPosition, pageBuilder, outputChannelOffset);
    }

    /** Returns whether the strategy reports a null key for the row at {@code addresses[position]}. */
    private boolean isPositionNull(int position) {
        long pageAddress = addresses.getLong(position);
        int blockIndex = SyntheticAddress.decodeSliceIndex(pageAddress);
        int blockPosition = SyntheticAddress.decodePosition(pageAddress);
        return pagesHashStrategy.isPositionNull(blockIndex, blockPosition);
    }

    /** Returns the strategy's 64-bit hash of the row at {@code addresses[position]}. */
    private long readHashPosition(int position) {
        long pageAddress = addresses.getLong(position);
        int blockIndex = SyntheticAddress.decodeSliceIndex(pageAddress);
        int blockPosition = SyntheticAddress.decodePosition(pageAddress);
        return pagesHashStrategy.hashPosition(blockIndex, blockPosition);
    }

    /**
     * Compares the build row at {@code addresses[leftPosition]} with a probe row.
     *
     * <p>The stored hash byte is checked first; the strategy is consulted only when it equals {@code rawHashByte}.
     */
    private boolean positionEqualsCurrentRowIgnoreNulls(int leftPosition, byte rawHashByte, int rightPosition, Page rightPage) {
        if (positionToHashes[leftPosition] != rawHashByte) {
            return false;
        }
        long pageAddress = addresses.getLong(leftPosition);
        int blockIndex = SyntheticAddress.decodeSliceIndex(pageAddress);
        int blockPosition = SyntheticAddress.decodePosition(pageAddress);
        return pagesHashStrategy.positionEqualsRowIgnoreNulls(blockIndex, blockPosition, rightPosition, rightPage);
    }

    /** Compares two build rows, both given as indexes into {@code addresses}, using the strategy. */
    private boolean positionEqualsPositionIgnoreNulls(int leftPosition, int rightPosition) {
        long leftAddress = addresses.getLong(leftPosition);
        int leftBlockIndex = SyntheticAddress.decodeSliceIndex(leftAddress);
        int leftBlockPosition = SyntheticAddress.decodePosition(leftAddress);
        long rightAddress = addresses.getLong(rightPosition);
        return pagesHashStrategy.positionEqualsPositionIgnoreNulls(
                leftBlockIndex,
                leftBlockPosition,
                SyntheticAddress.decodeSliceIndex(rightAddress),
                SyntheticAddress.decodePosition(rightAddress));
    }

    /**
     * Maps a raw hash to a table slot.
     *
     * <p>The bits are scrambled with the MurmurHash3 64-bit finalizer ({@code fmix64}) before masking, so that
     * low bits of a weak raw hash do not cluster in the table.
     */
    private static int getHashPosition(long rawHash, long mask) {
        long mixed = rawHash;
        mixed ^= mixed >>> 33;
        mixed *= 0xff51afd7ed558ccdL;
        mixed ^= mixed >>> 33;
        mixed *= 0xc4ceb9fe1a85ec53L;
        mixed ^= mixed >>> 33;
        return (int) (mixed & mask);
    }
}
