/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.core.utils.paged;

import java.io.Serializable;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.IntStream;
import org.apache.commons.lang3.mutable.MutableLong;
import org.eclipse.collections.api.block.procedure.primitive.LongLongProcedure;
import org.eclipse.collections.api.map.primitive.LongLongMap;
import org.eclipse.collections.api.map.primitive.MutableLongLongMap;
import org.eclipse.collections.impl.SpreadFunctions;
import org.eclipse.collections.impl.collection.mutable.AbstractMultiReaderMutableCollection;
import org.eclipse.collections.impl.factory.primitive.LongLongMaps;
import org.neo4j.gds.core.loading.IdMapAllocator;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.mem.BitUtil;
import org.neo4j.gds.utils.CloseableThreadLocal;

public final class ShardedLongLongMap {
    private final HugeLongArray internalNodeMapping;
    private final LongLongMap[] originalNodeMappingShards;
    private final int shardShift;
    private final int shardMask;
    private final long maxOriginalId;

    public static Builder builder(int concurrency) {
        return new Builder(concurrency);
    }

    public static BatchedBuilder batchedBuilder(int concurrency) {
        return new BatchedBuilder(concurrency);
    }

    private ShardedLongLongMap(HugeLongArray internalNodeMapping, LongLongMap[] originalNodeMappingShards, int shardShift, int shardMask, long maxOriginalId) {
        this.internalNodeMapping = internalNodeMapping;
        this.originalNodeMappingShards = originalNodeMappingShards;
        this.shardShift = shardShift;
        this.shardMask = shardMask;
        this.maxOriginalId = maxOriginalId;
    }

    public long toMappedNodeId(long nodeId) {
        LongLongMap shard = ShardedLongLongMap.findShard(nodeId, this.originalNodeMappingShards, this.shardShift, this.shardMask);
        return shard.getIfAbsent(nodeId, -1L);
    }

    public boolean contains(long originalId) {
        LongLongMap shard = ShardedLongLongMap.findShard(originalId, this.originalNodeMappingShards, this.shardShift, this.shardMask);
        return shard.containsKey(originalId);
    }

    public long toOriginalNodeId(long nodeId) {
        return this.internalNodeMapping.get(nodeId);
    }

    public long maxOriginalId() {
        return this.maxOriginalId;
    }

    public long size() {
        return this.internalNodeMapping.size();
    }

    private static <T> T findShard(long key, T[] shards, int shift, int mask) {
        int idx = ShardedLongLongMap.shardIdx2(key, shift, mask);
        return shards[idx];
    }

    private static int shardIdx(long key, int shift, int mask) {
        return (int)(key % (long)mask);
    }

    private static int shardIdx2(long key, int shift, int mask) {
        long hash = SpreadFunctions.longSpreadOne((long)key);
        return (int)(hash >>> shift);
    }

    private static int numberOfShards(int concurrency) {
        return BitUtil.nextHighestPowerOfTwo((int)(concurrency * 4));
    }

    private static <S extends MapShard> ShardedLongLongMap build(long nodeCount, S[] shards, int shardShift, int shardMask) {
        HugeLongArray internalNodeMapping = HugeLongArray.newArray(nodeCount);
        LongLongMap[] mapShards = new LongLongMap[shards.length];
        long[] maxOriginalIds = new long[shards.length];
        Arrays.parallelSetAll(mapShards, idx -> {
            MutableLong maxOriginalId = new MutableLong(0L);
            MapShard shard = shards[idx];
            MutableLongLongMap mapping = shard.intoMapping();
            mapping.forEachKeyValue((LongLongProcedure & Serializable)(originalId, mappedId) -> {
                if (originalId > maxOriginalId.longValue()) {
                    maxOriginalId.setValue(originalId);
                }
                internalNodeMapping.set(mappedId, originalId);
            });
            maxOriginalIds[idx] = maxOriginalId.longValue();
            return mapping;
        });
        return new ShardedLongLongMap(internalNodeMapping, mapShards, shardShift, shardMask, Arrays.stream(maxOriginalIds).max().orElse(0L));
    }

    private static <S extends MapShard> ShardedLongLongMap build(long nodeCount, S[] shards, int shardShift, int shardMask, long maxOriginalId) {
        HugeLongArray internalNodeMapping = HugeLongArray.newArray(nodeCount);
        LongLongMap[] mapShards = new LongLongMap[shards.length];
        Arrays.parallelSetAll(mapShards, idx -> {
            MapShard shard = shards[idx];
            MutableLongLongMap mapping = shard.intoMapping();
            mapping.forEachKeyValue((LongLongProcedure & Serializable)(originalId, mappedId) -> internalNodeMapping.set(mappedId, originalId));
            return mapping;
        });
        return new ShardedLongLongMap(internalNodeMapping, mapShards, shardShift, shardMask, maxOriginalId);
    }

    public static final class BatchedBuilder {
        private final AtomicLong nodeCount = new AtomicLong();
        private final Shard[] shards;
        private final CloseableThreadLocal<Batch> batches;
        private final int shardShift;
        private final int shardMask;

        BatchedBuilder(int concurrency) {
            int numberOfShards = ShardedLongLongMap.numberOfShards(concurrency);
            this.shardShift = 64 - Integer.numberOfTrailingZeros(numberOfShards);
            this.shardMask = numberOfShards - 1;
            this.shards = (Shard[])IntStream.range(0, numberOfShards).mapToObj(__ -> new Shard()).toArray(Shard[]::new);
            this.batches = CloseableThreadLocal.withInitial(() -> new Batch(this.shards, this.shardShift, this.shardMask));
        }

        public Batch prepareBatch(int nodeCount) {
            long startId = this.nodeCount.getAndAdd(nodeCount);
            Batch batch = this.batches.get();
            batch.initBatch(startId, nodeCount);
            return batch;
        }

        public ShardedLongLongMap build() {
            this.batches.close();
            return ShardedLongLongMap.build((long)this.nodeCount.get(), (MapShard[])this.shards, (int)this.shardShift, (int)this.shardMask);
        }

        public ShardedLongLongMap build(long maxOriginalId) {
            this.batches.close();
            return ShardedLongLongMap.build((long)this.nodeCount.get(), (MapShard[])this.shards, (int)this.shardShift, (int)this.shardMask, (long)maxOriginalId);
        }

        private static final class Shard
        extends MapShard {
            private Shard() {
            }

            void addNode(long nodeId, long mappedId) {
                this.assertIsUnderLock();
                this.mapping.put(nodeId, mappedId);
            }
        }

        public static final class Batch
        implements IdMapAllocator {
            private final Shard[] shards;
            private final int shardShift;
            private final int shardMask;
            private long startId;
            private int length;

            private Batch(Shard[] shards, int shardShift, int shardMask) {
                this.shards = shards;
                this.shardShift = shardShift;
                this.shardMask = shardMask;
            }

            @Override
            public long startId() {
                return this.startId;
            }

            @Override
            public int allocatedSize() {
                return this.length;
            }

            @Override
            public void insert(long[] nodeIds) {
                int length = this.allocatedSize();
                for (int i = 0; i < length; ++i) {
                    this.addNode(nodeIds[i]);
                }
            }

            public long addNode(long nodeId) {
                long mappedId = this.startId++;
                Shard shard = ShardedLongLongMap.findShard(nodeId, this.shards, this.shardShift, this.shardMask);
                try (AbstractMultiReaderMutableCollection.LockWrapper ignoredLock = shard.acquireLock();){
                    shard.addNode(nodeId, mappedId);
                }
                return mappedId;
            }

            void initBatch(long startId, int length) {
                this.startId = startId;
                this.length = length;
            }
        }
    }

    public static final class Builder {
        private final AtomicLong nodeCount = new AtomicLong();
        private final Shard[] shards;
        private final int shardShift;
        private final int shardMask;

        Builder(int concurrency) {
            int numberOfShards = ShardedLongLongMap.numberOfShards(concurrency);
            this.shardShift = 64 - Integer.numberOfTrailingZeros(numberOfShards);
            this.shardMask = numberOfShards - 1;
            this.shards = (Shard[])IntStream.range(0, numberOfShards).mapToObj(__ -> new Shard(this.nodeCount)).toArray(Shard[]::new);
        }

        public long addNode(long nodeId) {
            Shard shard = ShardedLongLongMap.findShard(nodeId, this.shards, this.shardShift, this.shardMask);
            try (AbstractMultiReaderMutableCollection.LockWrapper ignoredLock = shard.acquireLock();){
                long l = shard.addNode(nodeId);
                return l;
            }
        }

        public long toMappedNodeId(long nodeId) {
            Shard shard = ShardedLongLongMap.findShard(nodeId, this.shards, this.shardShift, this.shardMask);
            return shard.toMappedNodeId(nodeId);
        }

        public ShardedLongLongMap build() {
            return ShardedLongLongMap.build((long)this.nodeCount.get(), (MapShard[])this.shards, (int)this.shardShift, (int)this.shardMask);
        }

        public ShardedLongLongMap build(long maxOriginalId) {
            return ShardedLongLongMap.build((long)this.nodeCount.get(), (MapShard[])this.shards, (int)this.shardShift, (int)this.shardMask, (long)maxOriginalId);
        }

        private static final class Shard
        extends MapShard {
            private final AtomicLong nextId;

            private Shard(AtomicLong nextId) {
                this.nextId = nextId;
            }

            long toMappedNodeId(long nodeId) {
                return this.mapping.getIfAbsent(nodeId, -1L);
            }

            long addNode(long nodeId) {
                this.assertIsUnderLock();
                long internalId = this.nextId.getAndIncrement();
                this.mapping.put(nodeId, internalId);
                return internalId;
            }
        }
    }

    static abstract class MapShard {
        private final ReentrantLock lock;
        private final AbstractMultiReaderMutableCollection.LockWrapper lockWrapper;
        final MutableLongLongMap mapping = LongLongMaps.mutable.empty();

        MapShard() {
            this.lock = new ReentrantLock();
            this.lockWrapper = new AbstractMultiReaderMutableCollection.LockWrapper((Lock)this.lock);
        }

        final AbstractMultiReaderMutableCollection.LockWrapper acquireLock() {
            this.lock.lock();
            return this.lockWrapper;
        }

        void assertIsUnderLock() {
            assert (this.lock.isHeldByCurrentThread()) : "addNode must only be called while holding the lock";
        }

        MutableLongLongMap intoMapping() {
            return this.mapping;
        }
    }
}

