package ghidra.features.bsim.query.client;

import generic.lsh.vector.LSHVector;
import generic.lsh.vector.WeightedLSHCosineVectorFactory;
import ghidra.features.bsim.query.BSimPostgresDBConnectionManager;
import ghidra.features.bsim.query.BSimServerInfo;
import ghidra.features.bsim.query.FunctionDatabase;
import ghidra.features.bsim.query.LSHException;
import ghidra.features.bsim.query.client.tables.CachedStatement;
import ghidra.features.bsim.query.description.FunctionDescription;
import ghidra.features.bsim.query.description.SignatureRecord;
import ghidra.features.bsim.query.description.VectorResult;
import ghidra.features.bsim.query.protocol.AdjustVectorIndex;
import ghidra.features.bsim.query.protocol.BSimQuery;
import ghidra.features.bsim.query.protocol.PasswordChange;
import ghidra.features.bsim.query.protocol.PrewarmRequest;
import ghidra.features.bsim.query.protocol.QueryNearestVector;
import ghidra.features.bsim.query.protocol.QueryResponseRecord;
import ghidra.features.bsim.query.protocol.ResponseAdjustIndex;
import ghidra.features.bsim.query.protocol.ResponseNearestVector;
import ghidra.features.bsim.query.protocol.ResponsePassword;
import ghidra.features.bsim.query.protocol.SimilarityVectorResult;
import java.io.IOException;
import java.net.URL;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:ghidra/features/bsim/query/client/PostgresFunctionDatabase.class */
public final class PostgresFunctionDatabase extends AbstractSQLFunctionDatabase<WeightedLSHCosineVectorFactory> {
    /** Database layout version for this back end (the constructor passes the same value to the superclass). */
    public static final int LAYOUT_VERSION = 6;
    /** Name of the database that generateRawDatabase connects to in order to issue {@code CREATE DATABASE}. */
    private static final String DEFAULT_DATABASE_NAME = "postgres";
    /** The inherited data source ({@code ds}) narrowed to its PostgreSQL-specific type. */
    private BSimPostgresDBConnectionManager.BSimPostgresDataSource postgresDs;
    /** When true the session is configured with {@code synchronous_commit} OFF, otherwise ON. */
    private boolean asynchronous;
    /** Cached plain statement used for table locking and for the insert_vec / remove_vec calls. */
    private final CachedStatement<Statement> reusableStatement;
    /** Cached prepared statement for looking up a single vectable row by id. */
    private final CachedStatement<PreparedStatement> selectVectorByRowIdStatement;
    /** Cached prepared statement for the nearest-vector similarity query. */
    private final CachedStatement<PreparedStatement> selectNearestVectorStatement;

    /**
     * Creates a PostgreSQL-backed function database for the server identified by {@code url}.
     *
     * @param url location of the PostgreSQL server/database; used to obtain the data source
     * @param z   {@code true} to run sessions with {@code synchronous_commit} OFF,
     *            {@code false} to run them with it ON
     */
    public PostgresFunctionDatabase(URL url, boolean z) {
        super(BSimPostgresDBConnectionManager.getDataSource(url), FunctionDatabase.generateLSHVectorFactory(), LAYOUT_VERSION);
        this.asynchronous = z;
        // The superclass stores the data source generically; keep a PostgreSQL-typed view of it.
        this.postgresDs = (BSimPostgresDBConnectionManager.BSimPostgresDataSource) this.ds;
        this.selectNearestVectorStatement = new CachedStatement<>();
        this.selectVectorByRowIdStatement = new CachedStatement<>();
        this.reusableStatement = new CachedStatement<>();
    }

    /**
     * Closes the three cached statements owned by this class and then delegates to the
     * superclass to finish closing.
     */
    @Override // ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase, ghidra.features.bsim.query.FunctionDatabase, java.lang.AutoCloseable
    public void close() {
        // Statements are released before the superclass close runs.
        reusableStatement.close();
        selectVectorByRowIdStatement.close();
        selectNearestVectorStatement.close();

        super.close();
    }

    /**
     * Returns the shared general-purpose statement from the statement cache.
     * The supplied factory creates a statement on the connection returned by
     * {@code initConnection()}; presumably the cache only invokes it when no
     * usable statement is held (TODO confirm against CachedStatement).
     *
     * @return the reusable statement
     * @throws SQLException if the statement cannot be obtained or created
     */
    private Statement getReusableStatement() throws SQLException {
        return this.reusableStatement.prepareIfNeeded(() -> initConnection().createStatement());
    }

    /**
     * Locks {@code exetable}, {@code desctable} and {@code vectable} in
     * SHARE ROW EXCLUSIVE mode using the reusable statement.
     *
     * @throws SQLException if the lock statement fails
     */
    @Override // ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase
    protected void lockTablesForWrite() throws SQLException {
        final String lockSql = "LOCK TABLE exetable, desctable, vectable IN SHARE ROW EXCLUSIVE MODE";
        Statement statement = getReusableStatement();
        statement.execute(lockSql);
    }

    /**
     * Changes the password of a database role and, on success, records the new
     * password on the data source.
     *
     * <p>{@code ALTER ROLE} cannot take bind parameters, so the statement is built
     * as text. The role name is emitted as a quoted identifier with embedded double
     * quotes doubled, and the password as a string literal with embedded single
     * quotes doubled, so neither value can terminate its own quoting.
     *
     * @param connection connection on which the statement is executed
     * @param str        name of the role whose password is changed; must not be null
     *                   (the only caller in this class checks this)
     * @param cArr       the new password characters
     * @throws SQLException if the statement cannot be created or executed
     */
    private void changePassword(Connection connection, String str, char[] cArr) throws SQLException {
        StringBuilder sql = new StringBuilder();
        sql.append("ALTER ROLE \"");
        // Double any embedded double quote so the name stays inside the quoted identifier.
        sql.append(str.replace("\"", "\"\""));
        sql.append("\" WITH PASSWORD '");
        for (char c : cArr) {
            if (c == '\'') {
                // Double embedded single quotes so the literal is not terminated early.
                sql.append('\'');
            }
            sql.append(c);
        }
        sql.append('\'');
        try (Statement statement = connection.createStatement()) {
            statement.executeUpdate(sql.toString());
            // Only reached when the server accepted the change.
            this.postgresDs.setPassword(str, cArr);
        }
    }

    /**
     * Creates the two PL/pgSQL helper functions that maintain {@code vectable}:
     * {@code insert_vec}, which increments the count of an existing vector row or
     * inserts a new row with count 1 and reports the vector's hash, and
     * {@code remove_vec}, which decrements a row's count or deletes the row and
     * returns a result code (-1 when no row was found, 0 when the count was reduced,
     * 1 when the row was deleted).
     *
     * @param statement statement used to execute the two CREATE FUNCTION commands
     * @throws SQLException if either command fails
     */
    private void createVectorFunctions(Statement statement) throws SQLException {
        final String insertVecDdl = "CREATE FUNCTION insert_vec(newvec lshvector,OUT ourhash BIGINT) AS $$ DECLARE  curs1 CURSOR (key BIGINT) FOR SELECT count FROM vectable WHERE id = key FOR UPDATE;  ourcount INTEGER; BEGIN  ourhash := lshvector_hash(newvec);  OPEN curs1( ourhash );  FETCH curs1 INTO ourcount;  IF FOUND THEN    UPDATE vectable SET count = ourcount + 1 WHERE CURRENT OF curs1;  ELSE    INSERT INTO vectable (id,count,vec) VALUES(ourhash,1,newvec);  END IF;  CLOSE curs1; END; $$ LANGUAGE plpgsql;";
        final String removeVecDdl = "CREATE FUNCTION remove_vec(vecid BIGINT,countdiff INTEGER) RETURNS INTEGER AS $$DECLARE  curs1 CURSOR (key BIGINT) FOR SELECT count FROM vectable WHERE id = key FOR UPDATE;  ourcount INTEGER;  rescode INTEGER;BEGIN  rescode = -1;  OPEN curs1( vecid );  FETCH curs1 INTO ourcount;  IF FOUND AND ourcount > countdiff THEN    UPDATE vectable SET count = ourcount - countdiff WHERE CURRENT OF curs1;    rescode = 0;  ELSIF FOUND THEN    DELETE FROM vectable WHERE CURRENT OF curs1;    rescode = 1;  END IF;  CLOSE curs1;  RETURN rescode;END;$$ LANGUAGE plpgsql;";
        statement.executeUpdate(insertVecDdl);
        statement.executeUpdate(removeVecDdl);
    }

    /**
     * Invokes the server-side {@code lsh_load()} function and reads its result
     * rows to the end, discarding them.
     *
     * <p>Both the statement and the result set are managed by try-with-resources so
     * they are closed even when iterating the results fails.
     *
     * @param connection connection on which the call is made
     * @throws SQLException if the call or the iteration fails
     */
    private void serverLoadWeights(Connection connection) throws SQLException {
        try (Statement statement = connection.createStatement();
                ResultSet rows = statement.executeQuery("SELECT lsh_load()")) {
            while (rows.next()) {
                // Row contents are not used; the rows are only read through to the end.
            }
        }
    }

    /**
     * Prepares the session for use: calls {@code lsh_load()} on the server, sets
     * {@code synchronous_commit} according to the {@code asynchronous} flag, and then
     * runs the superclass initialization.
     *
     * <p>Declared public here; the decompiler noted the original access was protected.
     *
     * @param configuration configuration passed through to the superclass
     * @throws SQLException if any of the statements or the superclass initialization fails
     */
    @Override // ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase
    public void initializeDatabase(Configuration configuration) throws SQLException {
        Connection connection = initConnection();
        serverLoadWeights(connection);
        try (Statement statement = connection.createStatement()) {
            String commitSql = this.asynchronous
                ? "SET SESSION synchronous_commit TO OFF"
                : "SET SESSION synchronous_commit to ON";
            statement.executeUpdate(commitSql);
        }
        super.initializeDatabase(configuration);
    }

    /**
     * Creates the (empty) target database on the server.
     *
     * <p>A temporary data source is opened against the {@code DEFAULT_DATABASE_NAME}
     * database on the same server and port, {@code CREATE DATABASE} is issued for the
     * target name, and this instance's data source is then initialized from the
     * temporary one. The connection and statement are closed on every path, and the
     * temporary data source is always disposed.
     *
     * @throws SQLException if the connection cannot be obtained or the statement fails
     */
    @Override // ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase
    protected void generateRawDatabase() throws SQLException {
        BSimServerInfo targetInfo = this.postgresDs.getServerInfo();
        BSimServerInfo bootstrapInfo = new BSimServerInfo(BSimServerInfo.DBType.postgres, targetInfo.getServerName(), targetInfo.getPort(), DEFAULT_DATABASE_NAME);
        // NOTE(review): the database name is embedded verbatim inside a quoted identifier;
        // a name containing a double quote would produce a malformed statement.
        String createSql = "CREATE DATABASE \"" + targetInfo.getDBName() + "\"";
        BSimPostgresDBConnectionManager.BSimPostgresDataSource bootstrapDs = BSimPostgresDBConnectionManager.getDataSource(bootstrapInfo);
        try (Connection connection = bootstrapDs.getConnection();
                Statement statement = connection.createStatement()) {
            statement.executeUpdate(createSql);
            this.postgresDs.initializeFrom(bootstrapDs);
        } finally {
            // Runs after the statement and connection have been closed.
            bootstrapDs.dispose();
        }
    }

    /**
     * Creates the database via the superclass and then adds the PostgreSQL-specific
     * pieces: the {@code lshvector} extension, the {@code vectable} table with its GIN
     * index, the {@code insert_vec}/{@code remove_vec} functions, and the schema
     * privileges for PUBLIC. Finally {@code lsh_load()} is called and
     * {@code synchronous_commit} is set according to the {@code asynchronous} flag.
     *
     * <p>Declared public here; the decompiler noted the original access was protected.
     *
     * @param configuration configuration passed through to the superclass
     * @throws SQLException with the message prefix {@code "Could not create database: "}
     *         if any step fails; the underlying exception is attached as the cause
     */
    @Override // ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase
    public void createDatabase(Configuration configuration) throws SQLException {
        try {
            super.createDatabase(configuration);
            Connection connection = super.initConnection();
            // try-with-resources guarantees the statement is closed when any step below fails.
            try (Statement statement = connection.createStatement()) {
                statement.executeUpdate("CREATE EXTENSION IF NOT EXISTS lshvector");
                statement.executeUpdate("CREATE TABLE vectable(id BIGINT UNIQUE,count INTEGER,vec lshvector)");
                statement.executeUpdate("CREATE INDEX vectable_vec_idx ON vectable USING gin (vec gin_lshvector_ops)");
                createVectorFunctions(statement);
                statement.executeUpdate("REVOKE ALL ON SCHEMA PUBLIC FROM PUBLIC");
                statement.executeUpdate("GRANT USAGE ON SCHEMA PUBLIC TO PUBLIC");
                statement.executeUpdate("GRANT SELECT ON ALL TABLES IN SCHEMA PUBLIC TO PUBLIC");
                statement.executeUpdate("GRANT USAGE ON ALL SEQUENCES IN SCHEMA PUBLIC TO PUBLIC");
                serverLoadWeights(connection);
                if (this.asynchronous) {
                    statement.executeUpdate("SET SESSION synchronous_commit TO OFF");
                } else {
                    statement.executeUpdate("SET SESSION synchronous_commit to ON");
                }
            }
        } catch (SQLException e) {
            // Keep the original exception as the cause so its stack trace is not lost.
            throw new SQLException("Could not create database: " + e.getMessage(), e);
        }
    }

    /**
     * Drops the {@code vectable_vec_idx} index.
     *
     * @param connection connection on which the statement is executed
     * @throws SQLException if the statement cannot be created or executed
     */
    private void dropIndex(Connection connection) throws SQLException {
        try (Statement statement = connection.createStatement()) {
            statement.execute("DROP INDEX vectable_vec_idx");
        }
    }

    /**
     * Rebuilds the {@code vectable_vec_idx} GIN index: calls {@code lsh_reload()},
     * raises {@code maintenance_work_mem} to 2GB, and recreates the index.
     *
     * <p>The statement and the {@code lsh_reload()} result set are both managed by
     * try-with-resources so they are closed on every path, including failure.
     *
     * @param connection connection on which the statements are executed
     * @throws SQLException if any statement fails
     */
    private void rebuildIndex(Connection connection) throws SQLException {
        try (Statement statement = connection.createStatement();
                ResultSet reloadResult = statement.executeQuery("SELECT lsh_reload()")) {
            statement.execute("SET maintenance_work_mem TO '2GB'");
            statement.execute("CREATE INDEX vectable_vec_idx ON vectable USING gin (vec gin_lshvector_ops)");
        }
    }

    /**
     * Runs {@code pg_prewarm} over the main vector index, the secondary id index
     * and/or the vector table itself.
     *
     * <p>Each mode argument selects what happens to its relation: {@code 0} skips it,
     * {@code 1} prewarms it with the {@code 'read'} mode, and any other value prewarms
     * it with pg_prewarm's default mode. The {@code pg_prewarm} extension is created
     * (if absent) before the calls and dropped afterwards.
     *
     * <p>Every result set is read inside its own try-with-resources, so it is closed
     * only after it has been read through and is still closed when reading fails.
     *
     * @param connection connection on which the statements are executed
     * @param i          mode for {@code vectable_vec_idx} (main index)
     * @param i2         mode for {@code vectable_id_key} (secondary index)
     * @param i3         mode for {@code vectable} (the table)
     * @return the value reported by pg_prewarm for the main index, or {@code -1} if the
     *         main index was skipped or the query returned no row
     * @throws SQLException if any statement fails
     */
    private int preWarm(Connection connection, int i, int i2, int i3) throws SQLException {
        try (Statement statement = connection.createStatement()) {
            int blockCount = -1;
            statement.execute("CREATE EXTENSION IF NOT EXISTS pg_prewarm");
            if (i != 0) {
                try (ResultSet rows = statement.executeQuery(prewarmSql("vectable_vec_idx", i))) {
                    if (rows.next()) {
                        // Only the main index's count is reported back to the caller.
                        blockCount = rows.getInt(1);
                        while (rows.next()) {
                            // Remaining rows are read and discarded.
                        }
                    }
                }
            }
            if (i2 != 0) {
                prewarmAndDiscard(statement, prewarmSql("vectable_id_key", i2));
            }
            if (i3 != 0) {
                prewarmAndDiscard(statement, prewarmSql("vectable", i3));
            }
            statement.execute("DROP EXTENSION pg_prewarm");
            return blockCount;
        }
    }

    /** Builds the pg_prewarm query for a relation; mode 1 selects 'read', anything else the default mode. */
    private static String prewarmSql(String relation, int mode) {
        return "SELECT pg_prewarm('" + relation + (mode == 1 ? "','read')" : "')");
    }

    /** Executes a prewarm query and reads its rows to the end without using them. */
    private static void prewarmAndDiscard(Statement statement, String sql) throws SQLException {
        try (ResultSet rows = statement.executeQuery(sql)) {
            while (rows.next()) {
                // Row contents are not used.
            }
        }
    }

    /**
     * Stores the record's LSH vector through the server-side {@code insert_vec}
     * function and returns the id it reports.
     *
     * @param signatureRecord record whose vector is stored
     * @return the vector id read from the first column of the result
     * @throws SQLException if the query fails or returns no row
     */
    @Override // ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase
    protected long storeSignatureRecord(SignatureRecord signatureRecord) throws SQLException {
        Statement statement = getReusableStatement();
        String insertSql = "SELECT insert_vec( '" + signatureRecord.getLSHVector().saveSQL() + "')";
        try (ResultSet rows = statement.executeQuery(insertSql)) {
            if (rows.next()) {
                return rows.getLong(1);
            }
            throw new SQLException("Did not get vector id after insertion");
        }
    }

    /**
     * Reduces the reference count of a vector through the server-side
     * {@code remove_vec} function and returns that function's result code.
     *
     * @param j id of the vector row
     * @param i amount by which the row's count is reduced
     * @return the integer read from the first column of the result
     * @throws SQLException if the query fails or returns no row
     */
    @Override // ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase
    protected int deleteVectors(long j, int i) throws SQLException {
        Statement statement = getReusableStatement();
        String removeSql = "SELECT remove_vec( " + j + "," + i + ")";
        try (ResultSet rows = statement.executeQuery(removeSql)) {
            if (rows.next()) {
                return rows.getInt(1);
            }
            throw new SQLException("Did not get result code after deletion");
        }
    }

    /**
     * Queries {@code vectable} for vectors similar to {@code lSHVector} and appends
     * one {@link VectorResult} per matching row to {@code list}.
     *
     * <p>The whole iteration, including the {@code next()} calls, runs inside
     * try-with-resources so the result set is closed on every path. A result is added
     * to {@code list} only once it is fully populated, so a failure while restoring a
     * vector does not leave a half-filled entry behind. The triggering
     * {@link IOException} is kept as the cause of the thrown {@link SQLException}.
     *
     * @param list      receives the matching vectors, in the order returned by the query
     * @param lSHVector the vector to compare against
     * @param d         similarity threshold; only rows with similarity strictly greater are returned
     * @param d2        significance threshold; only rows with significance strictly greater are returned
     * @param i         maximum number of rows to return (SQL LIMIT)
     * @return the sum of the hit counts of all rows appended to {@code list}
     * @throws SQLException if the query fails or a stored vector cannot be restored
     */
    @Override // ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase
    protected int queryNearestVector(List<VectorResult> list, LSHVector lSHVector, double d, double d2, int i) throws SQLException {
        PreparedStatement query = this.selectNearestVectorStatement.prepareIfNeeded(() -> {
            return initConnection().prepareStatement("WITH const(cvec) AS (VALUES( lshvector_in( CAST( ? AS cstring) ) ) ), comp AS ( SELECT id,count,cvec,vec,lshvector_compare(cvec,vec) AS cfunc FROM const,vectable        WHERE cvec % vec) SELECT id,count,(comp.cfunc).sim,(comp.cfunc).sig,vec FROM comp WHERE (comp.cfunc).sim > ? AND (comp.cfunc).sig > ? ORDER BY (comp.cfunc).sim DESC LIMIT ?");
        });
        query.setString(1, lSHVector.saveSQL());
        query.setDouble(2, d);
        query.setDouble(3, d2);
        query.setInt(4, i);
        int totalHits = 0;
        try (ResultSet rows = query.executeQuery()) {
            while (rows.next()) {
                VectorResult hit = new VectorResult();
                hit.vectorid = rows.getLong(1);
                hit.hitcount = rows.getInt(2);
                hit.sim = rows.getDouble(3);
                hit.signif = rows.getDouble(4);
                try {
                    hit.vec = ((WeightedLSHCosineVectorFactory) this.vectorFactory).restoreVectorFromSql(rows.getString(5));
                } catch (IOException e) {
                    // Surface as SQLException (the declared type) without losing the cause.
                    throw new SQLException(e.getMessage(), e);
                }
                list.add(hit);
                totalHits += hit.hitcount;
            }
        }
        return totalHits;
    }

    /**
     * Answers a nearest-vector query: for every function in the query's description
     * manager that has a signature whose self-significance is at least the query's
     * significance threshold, looks up similar vectors and records the matches in the
     * query's response.
     *
     * <p>The response counters are reset first. {@code totalvec} counts the vectors
     * that were looked up, {@code totalmatch} accumulates the total count of each
     * non-empty result, and {@code uniquematch} counts results whose total count is
     * exactly 1.
     *
     * @param queryNearestVector the query; its {@code nearresponse} is filled in
     * @throws SQLException if any per-vector lookup fails
     */
    @Override // ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase
    protected void queryNearestVector(QueryNearestVector queryNearestVector) throws SQLException {
        ResponseNearestVector response = queryNearestVector.nearresponse;
        response.totalvec = 0;
        response.totalmatch = 0;
        response.uniquematch = 0;
        // A vectormax of 0 is replaced by a fixed row limit of 2,000,000.
        int rowLimit = queryNearestVector.vectormax == 0 ? 2000000 : queryNearestVector.vectormax;
        Iterator<FunctionDescription> functions = queryNearestVector.manage.listAllFunctions();
        while (functions.hasNext()) {
            FunctionDescription function = functions.next();
            SignatureRecord signature = function.getSignatureRecord();
            if (signature == null) {
                continue;
            }
            LSHVector vector = signature.getLSHVector();
            double selfSignificance = ((WeightedLSHCosineVectorFactory) this.vectorFactory).getSelfSignificance(vector);
            // Written as a negated >= so a NaN significance is skipped, exactly as before.
            if (!(selfSignificance >= queryNearestVector.signifthresh)) {
                continue;
            }
            response.totalvec++;
            List<VectorResult> matches = new ArrayList<>();
            queryNearestVector(matches, vector, queryNearestVector.thresh, queryNearestVector.signifthresh, rowLimit);
            if (matches.isEmpty()) {
                continue;
            }
            SimilarityVectorResult similarity = new SimilarityVectorResult(function);
            similarity.addNotes(matches);
            response.totalmatch += similarity.getTotalCount();
            if (similarity.getTotalCount() == 1) {
                response.uniquematch++;
            }
            response.result.add(similarity);
        }
    }

    /**
     * Loads a single {@code vectable} row by id.
     *
     * <p>The result set is managed by try-with-resources, and an {@link IOException}
     * raised while restoring the vector is kept as the cause of the thrown
     * {@link SQLException}.
     *
     * @param j the row id to look up
     * @return the row's id, count and restored vector
     * @throws SQLException with message {@code "Bad vectable rowid"} if no row has that
     *         id, or if the query fails or the stored vector cannot be restored
     */
    @Override // ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase
    protected VectorResult queryVectorId(long j) throws SQLException {
        PreparedStatement lookup = this.selectVectorByRowIdStatement.prepareIfNeeded(() -> {
            return initConnection().prepareStatement("SELECT id,count,vec FROM vectable WHERE id = ?");
        });
        lookup.setLong(1, j);
        try (ResultSet rows = lookup.executeQuery()) {
            if (!rows.next()) {
                throw new SQLException("Bad vectable rowid");
            }
            VectorResult result = new VectorResult();
            result.vectorid = rows.getLong(1);
            result.hitcount = rows.getInt(2);
            try {
                result.vec = ((WeightedLSHCosineVectorFactory) this.vectorFactory).restoreVectorFromSql(rows.getString(3));
            } catch (IOException e) {
                // Surface as SQLException (the declared type) without losing the cause.
                throw new SQLException(e.getMessage(), e);
            }
            return result;
        }
    }

    /**
     * Returns the user name reported by the PostgreSQL data source.
     *
     * @return the data source's user name
     */
    @Override // ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase, ghidra.features.bsim.query.FunctionDatabase
    public String getUserName() {
        final String userName = postgresDs.getUserName();
        return userName;
    }

    /**
     * Sets the preferred user name on the data source. This is only permitted while
     * the data source's status is not {@code Ready}.
     *
     * @param str the preferred user name
     * @throws IllegalStateException if the data source status is already {@code Ready}
     */
    @Override // ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase, ghidra.features.bsim.query.FunctionDatabase
    public void setUserName(String str) {
        boolean alreadyConnected = postgresDs.getStatus() == FunctionDatabase.Status.Ready;
        if (alreadyConnected) {
            throw new IllegalStateException("Connection has already been established");
        }
        postgresDs.setPreferredUserName(str);
    }

    /**
     * Dispatches a query. Prewarm, password-change and adjust-vector-index requests
     * are handled by this class and answered with the query's own response object;
     * every other query type is forwarded to the superclass.
     *
     * @param bSimQuery  the query to execute
     * @param connection connection on which the query is run
     * @return the response for the query
     * @throws SQLException if a database operation fails
     * @throws LSHException if a password-change request is missing its user name or password
     * @throws FunctionDatabase.DatabaseNonFatalException declared for the superclass dispatch
     */
    @Override // ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase
    public QueryResponseRecord doQuery(BSimQuery<?> bSimQuery, Connection connection) throws SQLException, LSHException, FunctionDatabase.DatabaseNonFatalException {
        if (bSimQuery instanceof PrewarmRequest) {
            fdbPrewarm((PrewarmRequest) bSimQuery, connection);
            return bSimQuery.getResponse();
        }
        if (bSimQuery instanceof PasswordChange) {
            fdbPasswordChange((PasswordChange) bSimQuery, connection);
            return bSimQuery.getResponse();
        }
        if (bSimQuery instanceof AdjustVectorIndex) {
            fdbAdjustVectorIndex((AdjustVectorIndex) bSimQuery, connection);
            return bSimQuery.getResponse();
        }
        // Not a PostgreSQL-specific request: let the generic SQL implementation handle it.
        return super.doQuery(bSimQuery, connection);
    }

    /**
     * Handles an adjust-vector-index request by either rebuilding or dropping the
     * vector index, depending on the request's {@code doRebuild} flag.
     *
     * <p>The response's {@code success} flag is cleared up front and only set once the
     * operation has returned normally.
     *
     * @param adjustVectorIndex the request; its {@code adjustresponse} is updated
     * @param connection        connection on which the index operation is run
     * @throws SQLException if the index operation fails
     */
    private void fdbAdjustVectorIndex(AdjustVectorIndex adjustVectorIndex, Connection connection) throws SQLException {
        ResponseAdjustIndex response = adjustVectorIndex.adjustresponse;
        response.success = false;
        if (!adjustVectorIndex.doRebuild) {
            dropIndex(connection);
        } else {
            rebuildIndex(connection);
        }
        response.success = true;
    }

    /**
     * Handles a prewarm request by running {@code preWarm} with the request's three
     * configuration values and storing the returned count in the response.
     *
     * @param prewarmRequest the request; its {@code prewarmresponse.blockCount} is set
     * @param connection     connection on which the prewarm statements are run
     * @throws SQLException if a prewarm statement fails
     */
    private void fdbPrewarm(PrewarmRequest prewarmRequest, Connection connection) throws SQLException {
        int blockCount = preWarm(
            connection,
            prewarmRequest.mainIndexConfig,
            prewarmRequest.secondaryIndexConfig,
            prewarmRequest.vectorTableConfig);
        prewarmRequest.prewarmresponse.blockCount = blockCount;
    }

    /**
     * Handles a password-change request.
     *
     * <p>The request is rejected outright if it has no user name or an empty/missing
     * password. Otherwise the response is optimistically marked successful and the
     * change is attempted; a database failure is not propagated but reported through
     * the response's {@code changeSuccessful} and {@code errorMessage} fields.
     *
     * @param passwordChange the request; its {@code passwordResponse} is updated
     * @param connection     connection on which the change is executed
     * @throws LSHException if the user name is null or the new password is null or empty
     */
    private void fdbPasswordChange(PasswordChange passwordChange, Connection connection) throws LSHException {
        ResponsePassword response = passwordChange.passwordResponse;
        String userName = passwordChange.username;
        if (userName == null) {
            throw new LSHException("Missing username for password change");
        }
        char[] newPassword = passwordChange.newPassword;
        boolean passwordMissing = newPassword == null || newPassword.length == 0;
        if (passwordMissing) {
            throw new LSHException("No password provided");
        }
        response.changeSuccessful = true;
        response.errorMessage = null;
        try {
            changePassword(connection, userName, newPassword);
        } catch (SQLException failure) {
            // Report the failure to the client instead of aborting the query.
            response.changeSuccessful = false;
            response.errorMessage = failure.getMessage();
        }
    }

    /**
     * Formats a parenthesized bitwise-AND SQL expression from two operand expressions.
     *
     * @param str  left operand, inserted verbatim
     * @param str2 right operand, inserted verbatim
     * @return the text {@code (str & str2)}
     */
    @Override // ghidra.features.bsim.query.SQLFunctionDatabase
    public String formatBitAndSQL(String str, String str2) {
        StringBuilder expression = new StringBuilder();
        expression.append("(").append(str).append(" & ").append(str2).append(")");
        return expression.toString();
    }

    static {
        Logger.getLogger("org.postgresql.Driver").setLevel(Level.FINEST);
    }
}
