package ghidra.features.bsim.query.file;

import generic.lsh.vector.LSHVector;
import ghidra.features.bsim.query.client.tables.CachedStatement;
import ghidra.features.bsim.query.client.tables.SQLComplexTable;
import ghidra.features.bsim.query.description.VectorResult;
import ghidra.features.bsim.query.elastic.Base64Lite;
import ghidra.features.bsim.query.elastic.Base64VectorFactory;
import java.io.IOException;
import java.io.StringReader;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:ghidra/features/bsim/query/file/H2VectorTable.class */
/**
 * BSim vector table stored in an H2 database table named {@value #TABLE_NAME}.
 * <p>
 * Each row holds an auto-generated id, a hit count, a 64-bit unique hash of the vector
 * ({@link LSHVector#calcUniqueHash()}) and the vector serialized as Base64 text. Every
 * insert/update/delete performed through this class is mirrored into the supplied
 * {@link VectorStore} (presumably an in-memory cache of this table — confirm against
 * VectorStore).
 */
public class H2VectorTable extends SQLComplexTable {
    public static final String TABLE_NAME = "h2_vectable";

    private final Base64VectorFactory vectorFactory;
    private final VectorStore vectorStore;
    private final CachedStatement<PreparedStatement> insert_stmt = new CachedStatement<>();
    private final CachedStatement<PreparedStatement> select_by_rowid_stmt = new CachedStatement<>();
    private final CachedStatement<PreparedStatement> select_id_by_hash_stmt = new CachedStatement<>();
    private final CachedStatement<PreparedStatement> update_by_hash_stmt = new CachedStatement<>();
    private final CachedStatement<PreparedStatement> select_count_by_rowid_stmt = new CachedStatement<>();
    private final CachedStatement<PreparedStatement> update_by_rowid_stmt = new CachedStatement<>();

    /**
     * @param base64VectorFactory factory used to decode stored vectors and to compute
     *                            their self-significance
     * @param vectorStore         store that is kept in sync with every table mutation
     */
    public H2VectorTable(Base64VectorFactory base64VectorFactory, VectorStore vectorStore) {
        super(TABLE_NAME, "id");
        this.vectorFactory = base64VectorFactory;
        this.vectorStore = vectorStore;
    }

    /** Closes every cached prepared statement, then delegates to the superclass. */
    @Override
    public void close() {
        insert_stmt.close();
        select_by_rowid_stmt.close();
        select_id_by_hash_stmt.close();
        update_by_hash_stmt.close();
        select_count_by_rowid_stmt.close();
        update_by_rowid_stmt.close();
        super.close();
    }

    /** Creates the table plus a unique index on {@code vec_hash} (one row per distinct vector). */
    @Override
    public void create(Statement statement) throws SQLException {
        statement.executeUpdate("CREATE TABLE h2_vectable(id SERIAL PRIMARY KEY, count INTEGER, vec_hash BIGINT, vec CLOB)");
        statement.executeUpdate("CREATE UNIQUE INDEX h2_vectable_index ON h2_vectable (vec_hash)");
    }

    /** Invalidates the vector store, drops the hash index, then lets the superclass drop the table. */
    @Override
    public void drop(Statement statement) throws SQLException {
        vectorStore.invalidate();
        statement.executeUpdate("DROP INDEX h2_vectable_index");
        super.drop(statement);
    }

    /**
     * Inserts a new vector row and records it in the vector store.
     *
     * @param objArr exactly two values: the hit count ({@link Integer}) and the {@link LSHVector}
     * @return the generated row id of the new vector
     * @throws IllegalArgumentException if {@code objArr} is null or does not have length 2
     * @throws SQLException if the insert fails or no generated id is returned
     */
    @Override
    public long insert(Object... objArr) throws SQLException {
        if (objArr == null || objArr.length != 2) {
            throw new IllegalArgumentException("Insert method for H2VectorTable accepts two arguments: count(int) and LSHVector");
        }
        int count = (Integer) objArr[0];
        LSHVector vector = (LSHVector) objArr[1];
        PreparedStatement stmt = insert_stmt.prepareIfNeeded(() -> this.f102db.prepareStatement(
            "INSERT INTO h2_vectable (count,vec_hash,vec) VALUES(?,?,?)", Statement.RETURN_GENERATED_KEYS));

        StringBuilder encoded = new StringBuilder();
        vector.saveBase64(encoded, Base64Lite.encode);
        stmt.setInt(1, count);
        stmt.setLong(2, vector.calcUniqueHash());
        stmt.setString(3, encoded.toString());
        if (stmt.executeUpdate() != 1) {
            throw new SQLException("Insert failed for vector table");
        }

        long id;
        try (ResultSet keys = stmt.getGeneratedKeys()) {
            if (!keys.next()) {
                throw new SQLException("Unable to obtain vector id for insert");
            }
            id = keys.getLong(1);
        }
        // Only mirror into the store once the row (and its id) definitely exists.
        vectorStore.update(new VectorStoreEntry(id, vector, count, vectorFactory.getSelfSignificance(vector)));
        return id;
    }

    /**
     * Reads and decodes every vector in the table.
     *
     * @return map from row id to a freshly built {@link VectorStoreEntry}
     * @throws SQLException on database errors, or wrapping an {@link IOException} raised while
     *                      decoding a stored Base64 vector
     */
    public Map<Long, VectorStoreEntry> readVectors() throws SQLException {
        // One decode buffer is reused for every row.
        char[] buffer = Base64VectorFactory.allocateBuffer();
        Map<Long, VectorStoreEntry> vectors = new HashMap<>();
        try (Statement stmt = this.f102db.createStatement();
                ResultSet rs = stmt.executeQuery("SELECT id,count,vec FROM h2_vectable")) {
            while (rs.next()) {
                long id = rs.getLong(1);
                int count = rs.getInt(2);
                LSHVector vec = vectorFactory.restoreVectorFromBase64(new StringReader(rs.getString(3)), buffer);
                vectors.put(id, new VectorStoreEntry(id, vec, count, vectorFactory.getSelfSignificance(vec)));
            }
        } catch (IOException e) {
            throw new SQLException(e);
        }
        return vectors;
    }

    /**
     * Looks up a vector by row id, consulting the vector store first and the table second.
     *
     * @param vectorId row id of the vector
     * @return the vector, its id and hit count; when served from the store the two
     *         {@code double} constructor arguments are {@code 0.0}
     * @throws SQLException if the id is not present in the table ("Bad vector table rowid"),
     *                      or if the stored vector cannot be decoded (cause preserved)
     */
    public VectorResult queryVectorById(long vectorId) throws SQLException {
        VectorStoreEntry cached = vectorStore.getVectorById(vectorId);
        if (cached != null) {
            return new VectorResult(vectorId, cached.count(), 0.0d, 0.0d, cached.vec());
        }
        PreparedStatement stmt = select_by_rowid_stmt.prepareIfNeeded(
            () -> this.f102db.prepareStatement("SELECT id,count,vec FROM h2_vectable WHERE id = ?"));
        stmt.setLong(1, vectorId);
        try (ResultSet rs = stmt.executeQuery()) {
            if (!rs.next()) {
                throw new SQLException("Bad vector table rowid");
            }
            VectorResult result = new VectorResult();
            result.vectorid = rs.getLong(1);
            result.hitcount = rs.getInt(2);
            result.vec = vectorFactory.restoreVectorFromBase64(
                new StringReader(rs.getString(3)), Base64VectorFactory.allocateBuffer());
            return result;
        } catch (IOException e) {
            // Keep the original message but also retain the cause for diagnostics.
            throw new SQLException(e.getMessage(), e);
        }
    }

    /**
     * Reads the current hit count of a vector row directly from the table.
     *
     * @throws SQLException if no row has the given id
     */
    private int queryVectorCountById(long vectorId) throws SQLException {
        PreparedStatement stmt = select_count_by_rowid_stmt.prepareIfNeeded(
            () -> this.f102db.prepareStatement("SELECT count FROM h2_vectable WHERE id = ?"));
        stmt.setLong(1, vectorId);
        try (ResultSet rs = stmt.executeQuery()) {
            if (!rs.next()) {
                throw new SQLException("Bad vector table rowid");
            }
            return rs.getInt(1);
        }
    }

    /**
     * Adds {@code countDiff} to the hit count of the row whose hash matches {@code lSHVector},
     * inserting a new row (with count {@code countDiff}) if no such row exists.
     *
     * @param lSHVector the vector to add or bump
     * @param countDiff positive amount to add to the count
     * @return the row id of the updated or inserted vector
     * @throws IllegalArgumentException if {@code countDiff <= 0}
     * @throws SQLException if more than one row matches the hash, or the row cannot be re-read
     */
    public long updateVector(LSHVector lSHVector, int countDiff) throws SQLException {
        if (countDiff <= 0) {
            throw new IllegalArgumentException("Invalid countDiff: " + countDiff);
        }
        PreparedStatement update = update_by_hash_stmt.prepareIfNeeded(
            () -> this.f102db.prepareStatement("UPDATE h2_vectable SET count = count + ? WHERE vec_hash = ?"));
        long hash = lSHVector.calcUniqueHash();
        update.setInt(1, countDiff);
        update.setLong(2, hash);
        int updated = update.executeUpdate();
        if (updated == 0) {
            return insert(countDiff, lSHVector);
        }
        if (updated > 1) {
            throw new SQLException("Unexpected updated row count: " + updated);
        }

        // Re-read id and post-update count so the store mirrors the table exactly.
        PreparedStatement select = select_id_by_hash_stmt.prepareIfNeeded(
            () -> this.f102db.prepareStatement("SELECT id, count FROM h2_vectable WHERE vec_hash = ?"));
        select.setLong(1, hash);
        long id;
        int newCount;
        try (ResultSet rs = select.executeQuery()) {
            if (!rs.next()) {
                throw new SQLException("Unknown vector hash");
            }
            id = rs.getLong(1);
            newCount = rs.getInt(2);
        }
        vectorStore.update(new VectorStoreEntry(id, lSHVector, newCount, vectorFactory.getSelfSignificance(lSHVector)));
        return id;
    }

    /**
     * Subtracts {@code countDiff} from a vector's hit count, deleting the row once the count
     * reaches zero.
     *
     * @param vectorId  row id of the vector
     * @param countDiff positive amount to subtract
     * @return {@code -1} if no row was changed (unknown id or count smaller than
     *         {@code countDiff}), {@code 0} if the count was decremented and remains positive,
     *         {@code 1} if the row was deleted
     * @throws IllegalArgumentException if {@code countDiff <= 0}
     * @throws SQLException if more than one row was updated or on database errors
     */
    public int deleteVector(long vectorId, int countDiff) throws SQLException {
        if (countDiff <= 0) {
            throw new IllegalArgumentException("Invalid countDiff: " + countDiff);
        }
        PreparedStatement stmt = update_by_rowid_stmt.prepareIfNeeded(
            () -> this.f102db.prepareStatement("UPDATE h2_vectable SET count = count - ? WHERE id = ? AND count >= ?"));
        stmt.setInt(1, countDiff);
        stmt.setLong(2, vectorId);
        // The "count >= ?" guard prevents the count from going negative.
        stmt.setInt(3, countDiff);
        int updated = stmt.executeUpdate();
        if (updated == 0) {
            return -1;
        }
        if (updated > 1) {
            throw new SQLException("Unexpected updated row count: " + updated);
        }
        int remaining = queryVectorCountById(vectorId);
        if (remaining > 0) {
            vectorStore.update(vectorId, remaining);
            return 0;
        }
        delete(vectorId);
        return 1;
    }

    /** Deletes the row via the superclass and removes the id from the vector store. */
    @Override
    public int delete(long vectorId) throws SQLException {
        int deleted = super.delete(vectorId);
        vectorStore.delete(vectorId);
        return deleted;
    }
}
