package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;

/* loaded from: input_file:org/apache/lucene/util/hnsw/HnswGraphBuilder.class */
public final class HnswGraphBuilder<T> {
    /** Default value for the {@code maxConn} argument of {@code create}. */
    public static final int DEFAULT_MAX_CONN = 16;

    /** Default value for the {@code beamWidth} argument of {@code create}. */
    public static final int DEFAULT_BEAM_WIDTH = 100;

    private static final long DEFAULT_RAND_SEED = 42;

    /** {@link InfoStream} component name under which build progress is reported. */
    public static final String HNSW_COMPONENT = "HNSW";

    /**
     * Publicly settable seed, initialised to {@code 42}. Not read anywhere in this class: the
     * builder's own random source is seeded from the {@code seed} argument of {@code create}.
     */
    public static long randSeed = DEFAULT_RAND_SEED;

    /** Maximum connections per node on levels above 0; level 0 allows {@code 2 * M}. */
    private final int M;
    /** Number of candidates gathered per level when linking a new node. */
    private final int beamWidth;
    /** Level-generation normalisation factor: {@code 1 / ln(M)}, or 1 when {@code M == 1}. */
    private final double ml;
    /**
     * Reusable buffer for search candidates, kept in ascending score order (worst first) and
     * consumed from the end by {@link #selectAndLinkDiverse}.
     */
    private final NeighborArray scratch;
    private final VectorSimilarityFunction similarityFunction;
    private final VectorEncoding vectorEncoding;
    /** Source of the vector for the node being inserted/checked. */
    private final RandomAccessVectorValues<T> vectors;
    private final SplittableRandom random;
    private final HnswGraphSearcher<T> graphSearcher;
    /** The graph under construction. */
    final OnHeapHnswGraph hnsw;
    private InfoStream infoStream = InfoStream.getDefault();
    /**
     * Independent copy of {@link #vectors}, used for the second operand of similarity comparisons
     * so that reading a neighbour's vector does not disturb the candidate vector already fetched.
     */
    private final RandomAccessVectorValues<T> vectorsCopy;
    /** New ordinals already copied in by {@link #initializeFromGraph}; {@link #build} skips them. */
    private final Set<Integer> initializedNodes;

    /**
     * Creates a builder with an empty graph.
     *
     * @param vectors source of the vectors to index; the builder keeps this reference and a copy
     * @param vectorEncoding encoding of the vectors (selects the {@code float[]}/{@code byte[]} path)
     * @param similarityFunction similarity used to score vector pairs
     * @param maxConn maximum connections per node above level 0 (level 0 allows twice this)
     * @param beamWidth number of candidates explored per level during insertion
     * @param seed seed for the random level assignment
     * @throws IllegalArgumentException if {@code maxConn} or {@code beamWidth} is not positive
     * @throws NullPointerException if {@code vectorEncoding} or {@code similarityFunction} is null
     * @throws IOException if copying {@code vectors} fails
     */
    public static <T> HnswGraphBuilder<T> create(
            RandomAccessVectorValues<T> vectors,
            VectorEncoding vectorEncoding,
            VectorSimilarityFunction similarityFunction,
            int maxConn,
            int beamWidth,
            long seed)
            throws IOException {
        return new HnswGraphBuilder<>(
                vectors, vectorEncoding, similarityFunction, maxConn, beamWidth, seed);
    }

    /**
     * Creates a builder whose graph is pre-populated from {@code initializerGraph}. Node ordinals
     * are translated through {@code oldToNewOrdinalMap}. Every copied node is skipped by a later
     * {@link #build} call.
     *
     * @param initializerGraph existing graph whose nodes and edges are copied
     * @param oldToNewOrdinalMap maps each ordinal of {@code initializerGraph} to an ordinal of
     *     {@code vectors}; must contain every node and neighbour of the initializer graph
     * @see #create(RandomAccessVectorValues, VectorEncoding, VectorSimilarityFunction, int, int, long)
     */
    public static <T> HnswGraphBuilder<T> create(
            RandomAccessVectorValues<T> vectors,
            VectorEncoding vectorEncoding,
            VectorSimilarityFunction similarityFunction,
            int maxConn,
            int beamWidth,
            long seed,
            HnswGraph initializerGraph,
            Map<Integer, Integer> oldToNewOrdinalMap)
            throws IOException {
        HnswGraphBuilder<T> builder =
                new HnswGraphBuilder<>(
                        vectors, vectorEncoding, similarityFunction, maxConn, beamWidth, seed);
        builder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap);
        return builder;
    }

    private HnswGraphBuilder(
            RandomAccessVectorValues<T> vectors,
            VectorEncoding vectorEncoding,
            VectorSimilarityFunction similarityFunction,
            int maxConn,
            int beamWidth,
            long seed)
            throws IOException {
        this.vectors = vectors;
        this.vectorsCopy = vectors.copy();
        this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
        this.similarityFunction = Objects.requireNonNull(similarityFunction);
        if (maxConn <= 0) {
            throw new IllegalArgumentException("maxConn must be positive");
        }
        if (beamWidth <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.M = maxConn;
        this.beamWidth = beamWidth;
        // ln(1) == 0, so maxConn == 1 would divide by zero; use 1.0 instead.
        this.ml = maxConn == 1 ? 1.0d : 1.0d / Math.log(1.0d * maxConn);
        this.random = new SplittableRandom(seed);
        this.hnsw = new OnHeapHnswGraph(maxConn);
        this.graphSearcher =
                new HnswGraphSearcher<>(
                        vectorEncoding,
                        similarityFunction,
                        new NeighborQueue(beamWidth, true),
                        new FixedBitSet(this.vectors.size()));
        this.scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1), false);
        this.initializedNodes = new HashSet<>();
    }

    /**
     * Inserts every vector of {@code vectorsToAdd} that was not already copied from an initializer
     * graph, and returns the resulting graph.
     *
     * @param vectorsToAdd vectors to insert; must be a different instance from the one passed to
     *     {@code create}, because the builder also reads from its own source while inserting
     * @return the graph under construction (the same instance as {@link #getGraph()})
     * @throws IllegalArgumentException if {@code vectorsToAdd} is the builder's own vector source
     */
    public OnHeapHnswGraph build(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
        if (vectorsToAdd == this.vectors) {
            throw new IllegalArgumentException("Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
        }
        if (infoStream.isEnabled(HNSW_COMPONENT)) {
            infoStream.message(HNSW_COMPONENT, "build graph from " + vectorsToAdd.size() + " vectors");
        }
        addVectors(vectorsToAdd);
        return hnsw;
    }

    /**
     * Copies nodes and edges of {@code initializerGraph} into the (still empty) graph, translating
     * ordinals through {@code oldToNewOrdinalMap}. Edge scores are recomputed with this builder's
     * similarity function. Every copied node is recorded in {@link #initializedNodes} so
     * {@link #addVectors} does not insert it again.
     *
     * <p>NOTE(review): the decompiler could not recover this method. The neighbour walk
     * (terminated by {@code Integer.MAX_VALUE}), the ordinal remapping, the encoding-specific
     * scoring against {@code vectorsCopy} and {@code addOutOfOrder} follow the recovered
     * fragments. The level/node iteration via {@code getNodesOnLevel} is inferred; verify it
     * against the original bytecode.
     *
     * @throws NullPointerException if an ordinal is missing from {@code oldToNewOrdinalMap}
     */
    private void initializeFromGraph(
            HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException {
        // Only called from create(), immediately after construction.
        assert hnsw.size() == 0;
        for (int level = 0; level < initializerGraph.numLevels(); level++) {
            HnswGraph.NodesIterator nodesOnLevel = initializerGraph.getNodesOnLevel(level);
            while (nodesOnLevel.hasNext()) {
                int oldOrd = nodesOnLevel.nextInt();
                int newOrd = oldToNewOrdinalMap.get(oldOrd);
                hnsw.addNode(level, newOrd);
                if (level == 0) {
                    // Every node lives on level 0, so this records each copied node exactly once.
                    initializedNodes.add(newOrd);
                }

                // Candidate vector comes from `vectors`; neighbours are read from `vectorsCopy`
                // so this value is not overwritten inside the loop below.
                T nodeVector = vectors.vectorValue(newOrd);
                NeighborArray newNeighbors = hnsw.getNeighbors(level, newOrd);
                initializerGraph.seek(level, oldOrd);
                // Integer.MAX_VALUE marks the end of the neighbour list.
                for (int oldNeighbor = initializerGraph.nextNeighbor();
                        oldNeighbor != Integer.MAX_VALUE;
                        oldNeighbor = initializerGraph.nextNeighbor()) {
                    int newNeighbor = oldToNewOrdinalMap.get(oldNeighbor);
                    float score;
                    switch (vectorEncoding) {
                        case BYTE:
                            score = similarityFunction.compare(
                                    (byte[]) nodeVector, (byte[]) vectorsCopy.vectorValue(newNeighbor));
                            break;
                        case FLOAT32:
                        default:
                            score = similarityFunction.compare(
                                    (float[]) nodeVector, (float[]) vectorsCopy.vectorValue(newNeighbor));
                            break;
                    }
                    newNeighbors.addOutOfOrder(newNeighbor, score);
                }
            }
        }
    }

    /**
     * Inserts each ordinal {@code 0 .. size-1} of {@code vectorsToAdd} that was not copied from an
     * initializer graph. When the HNSW info stream is enabled, progress is logged every 10,000
     * inserted ordinals.
     */
    private void addVectors(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
        final long start = System.nanoTime();
        long lastReport = start;
        for (int node = 0; node < vectorsToAdd.size(); node++) {
            if (initializedNodes.contains(node)) {
                continue;
            }
            addGraphNode(node, vectorsToAdd);
            if (node % 10000 == 0 && infoStream.isEnabled(HNSW_COMPONENT)) {
                lastReport = printGraphBuildStatus(node, start, lastReport);
            }
        }
    }

    /** Sets the stream that receives build progress messages. */
    public void setInfoStream(InfoStream infoStream) {
        this.infoStream = infoStream;
    }

    /** Returns the graph under construction. */
    public OnHeapHnswGraph getGraph() {
        return hnsw;
    }

    /**
     * Inserts {@code node} with vector {@code value} into the graph. The node is added to every
     * level from 0 up to a randomly drawn level. On each of those levels it is linked to a
     * diversified subset of the nearest candidates found there.
     *
     * @param node ordinal of the new node
     * @param value the node's vector ({@code float[]} or {@code byte[]} depending on encoding)
     */
    public void addGraphNode(int node, T value) throws IOException {
        final int nodeLevel = getRandomGraphLevel(ml, random);
        final int curMaxLevel = hnsw.numLevels() - 1;

        // First node: it becomes the entry point and has no neighbours to link.
        if (hnsw.entryNode() == -1) {
            for (int level = nodeLevel; level >= 0; level--) {
                hnsw.addNode(level, node);
            }
            return;
        }

        int[] entryPoints = {hnsw.entryNode()};

        // Levels above the current top did not exist before; the node is alone there.
        for (int level = nodeLevel; level > curMaxLevel; level--) {
            hnsw.addNode(level, node);
        }

        // Above the node's own level, descend greedily keeping only the single best candidate.
        for (int level = curMaxLevel; level > nodeLevel; level--) {
            NeighborQueue candidates =
                    graphSearcher.searchLevel(value, 1, level, entryPoints, vectors, hnsw);
            entryPoints = new int[] {candidates.pop()};
        }

        // On each level the node occupies, gather beamWidth candidates and link diversely.
        for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) {
            NeighborQueue candidates =
                    graphSearcher.searchLevel(value, beamWidth, level, entryPoints, vectors, hnsw);
            entryPoints = candidates.nodes();
            hnsw.addNode(level, node);
            addDiverseNeighbors(level, node, candidates);
        }
    }

    /**
     * Inserts {@code node}, reading its vector from {@code values}.
     *
     * @param node ordinal of the new node, also used as its index into {@code values}
     */
    public void addGraphNode(int node, RandomAccessVectorValues<T> values) throws IOException {
        // The decompiled form cast the vector to (int), which neither compiled nor passed the vector.
        addGraphNode(node, values.vectorValue(node));
    }

    /**
     * Logs "built N in X/Y ms", where X is the time since the last report and Y the time since
     * {@code start}.
     *
     * @return the current {@link System#nanoTime()}, to be passed as the next {@code lastReport}
     */
    private long printGraphBuildStatus(int node, long start, long lastReport) {
        long now = System.nanoTime();
        infoStream.message(
                HNSW_COMPONENT,
                String.format(
                        Locale.ROOT,
                        "built %d in %d/%d ms",
                        node,
                        TimeUnit.NANOSECONDS.toMillis(now - lastReport),
                        TimeUnit.NANOSECONDS.toMillis(now - start)));
        return now;
    }

    /**
     * Links the freshly added {@code node} on {@code level}:
     * <ol>
     *   <li>pick its neighbours from {@code candidates} with the diversity heuristic;</li>
     *   <li>add a back-link from each chosen neighbour;</li>
     *   <li>if a neighbour's list now exceeds the level's cap, evict its worst non-diverse entry.</li>
     * </ol>
     */
    private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
            throws IOException {
        NeighborArray neighbors = hnsw.getNeighbors(level, node);
        assert neighbors.size() == 0; // the node was just added on this level
        popToScratch(candidates);
        int maxConnOnLevel = level == 0 ? M * 2 : M;
        selectAndLinkDiverse(neighbors, scratch, maxConnOnLevel);

        int size = neighbors.size();
        for (int i = 0; i < size; i++) {
            NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, neighbors.node[i]);
            nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]);
            if (nbrsOfNbr.size() > maxConnOnLevel) {
                nbrsOfNbr.removeIndex(findWorstNonDiverse(nbrsOfNbr));
            }
        }
    }

    /**
     * Walks {@code candidates} from the last (best) entry backwards and appends each one to
     * {@code neighbors} that passes {@link #diversityCheck}. Stops once {@code maxConnOnLevel}
     * neighbours are selected.
     */
    private void selectAndLinkDiverse(
            NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel)
            throws IOException {
        for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) {
            int cNode = candidates.node[i];
            float cScore = candidates.score[i];
            assert cNode < hnsw.size();
            if (diversityCheck(cNode, cScore, neighbors)) {
                neighbors.addInOrder(cNode, cScore);
            }
        }
    }

    /** Drains {@code candidates} into {@link #scratch}, pairing each node with its own score. */
    private void popToScratch(NeighborQueue candidates) {
        scratch.clear();
        int candidateCount = candidates.size();
        for (int i = 0; i < candidateCount; i++) {
            // Read the score BEFORE popping: pop() removes the top, after which topScore() would
            // describe the next candidate (or a stale slot once the queue is empty).
            float topScore = candidates.topScore();
            scratch.addInOrder(candidates.pop(), topScore);
        }
    }

    /** Returns true if {@code candidate} may join {@code neighbors} given its {@code score}. */
    private boolean diversityCheck(int candidate, float score, NeighborArray neighbors)
            throws IOException {
        return isDiverse(candidate, neighbors, score);
    }

    /** Dispatches {@link #isDiverse} on the vector encoding (non-BYTE encodings use float[]). */
    private boolean isDiverse(int candidate, NeighborArray neighbors, float score)
            throws IOException {
        switch (vectorEncoding) {
            case BYTE:
                return isDiverse((byte[]) vectors.vectorValue(candidate), neighbors, score);
            case FLOAT32:
            default:
                return isDiverse((float[]) vectors.vectorValue(candidate), neighbors, score);
        }
    }

    /**
     * Returns false if {@code candidate} is at least as similar to any existing neighbour as
     * {@code score} (its similarity to the base node); true otherwise.
     */
    private boolean isDiverse(float[] candidate, NeighborArray neighbors, float score)
            throws IOException {
        for (int i = 0; i < neighbors.size(); i++) {
            float similarity =
                    similarityFunction.compare(candidate, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
            if (similarity >= score) {
                return false;
            }
        }
        return true;
    }

    /** Byte-vector counterpart of {@link #isDiverse(float[], NeighborArray, float)}. */
    private boolean isDiverse(byte[] candidate, NeighborArray neighbors, float score)
            throws IOException {
        for (int i = 0; i < neighbors.size(); i++) {
            float similarity =
                    similarityFunction.compare(candidate, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
            if (similarity >= score) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the index of the entry to evict from an over-full {@code neighbors} list.
     *
     * <p>The list is sorted first. {@code sort()} returns the indexes that still need a diversity
     * check, or null if none do. Scanning from the worst entry, the first one that violates
     * diversity is chosen. If none does, the worst-scoring entry (last index) is chosen.
     */
    private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
        int[] uncheckedIndexes = neighbors.sort();
        if (uncheckedIndexes == null) {
            return neighbors.size() - 1;
        }
        int uncheckedCursor = uncheckedIndexes.length - 1;
        for (int i = neighbors.size() - 1; i > 0 && uncheckedCursor >= 0; i--) {
            if (isWorstNonDiverse(i, neighbors, uncheckedIndexes, uncheckedCursor)) {
                return i;
            }
            if (i == uncheckedIndexes[uncheckedCursor]) {
                uncheckedCursor--;
            }
        }
        return neighbors.size() - 1;
    }

    /** Dispatches {@link #isWorstNonDiverse} on the vector encoding (non-BYTE use float[]). */
    private boolean isWorstNonDiverse(
            int candidateIndex, NeighborArray neighbors, int[] uncheckedIndexes, int uncheckedCursor)
            throws IOException {
        int candidateNode = neighbors.node[candidateIndex];
        switch (vectorEncoding) {
            case BYTE:
                return isWorstNonDiverse(
                        candidateIndex,
                        (byte[]) vectors.vectorValue(candidateNode),
                        neighbors,
                        uncheckedIndexes,
                        uncheckedCursor);
            case FLOAT32:
            default:
                return isWorstNonDiverse(
                        candidateIndex,
                        (float[]) vectors.vectorValue(candidateNode),
                        neighbors,
                        uncheckedIndexes,
                        uncheckedCursor);
        }
    }

    /**
     * Returns true if the neighbour at {@code candidateIndex} is at least as similar to another
     * neighbour as to the base node.
     *
     * <p>If the candidate itself is unchecked, it is compared with every better-ranked neighbour.
     * Otherwise it is compared only with the unchecked neighbours
     * {@code uncheckedIndexes[0..uncheckedCursor]}.
     */
    private boolean isWorstNonDiverse(
            int candidateIndex,
            float[] candidateVector,
            NeighborArray neighbors,
            int[] uncheckedIndexes,
            int uncheckedCursor)
            throws IOException {
        float minAcceptedSimilarity = neighbors.score[candidateIndex];
        if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
            for (int i = candidateIndex - 1; i >= 0; i--) {
                float similarity = similarityFunction.compare(
                        candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
                if (similarity >= minAcceptedSimilarity) {
                    return true;
                }
            }
            return false;
        }
        assert candidateIndex > uncheckedIndexes[uncheckedCursor];
        for (int i = uncheckedCursor; i >= 0; i--) {
            float similarity = similarityFunction.compare(
                    candidateVector,
                    (float[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]]));
            if (similarity >= minAcceptedSimilarity) {
                return true;
            }
        }
        return false;
    }

    /** Byte-vector counterpart of the float[] {@code isWorstNonDiverse}. */
    private boolean isWorstNonDiverse(
            int candidateIndex,
            byte[] candidateVector,
            NeighborArray neighbors,
            int[] uncheckedIndexes,
            int uncheckedCursor)
            throws IOException {
        float minAcceptedSimilarity = neighbors.score[candidateIndex];
        if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
            for (int i = candidateIndex - 1; i >= 0; i--) {
                float similarity = similarityFunction.compare(
                        candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
                if (similarity >= minAcceptedSimilarity) {
                    return true;
                }
            }
            return false;
        }
        assert candidateIndex > uncheckedIndexes[uncheckedCursor];
        for (int i = uncheckedCursor; i >= 0; i--) {
            float similarity = similarityFunction.compare(
                    candidateVector,
                    (byte[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]]));
            if (similarity >= minAcceptedSimilarity) {
                return true;
            }
        }
        return false;
    }

    /**
     * Draws a level as {@code floor(-ln(U) * ml)}, where {@code U} is uniform on (0, 1).
     * Zero draws are rejected because {@code ln(0)} is undefined.
     */
    private static int getRandomGraphLevel(double ml, SplittableRandom random) {
        double randDouble;
        do {
            randDouble = random.nextDouble();
        } while (randDouble == 0.0d);
        return (int) (-Math.log(randDouble) * ml);
    }
}
