package ghidra.features.bsim.query.client;

import db.buffers.LocalBufferFile;
import generic.lsh.vector.LSHVector;
import generic.lsh.vector.LSHVectorFactory;
import ghidra.features.bsim.query.FunctionDatabase;
import ghidra.features.bsim.query.LSHException;
import ghidra.features.bsim.query.client.tables.ExeTable;
import ghidra.features.bsim.query.description.DatabaseInformation;
import ghidra.features.bsim.query.description.DescriptionManager;
import ghidra.features.bsim.query.description.ExecutableRecord;
import ghidra.features.bsim.query.description.FunctionDescription;
import ghidra.features.bsim.query.description.VectorResult;
import ghidra.features.bsim.query.protocol.ExeSpecifier;
import ghidra.features.bsim.query.protocol.QueryExeInfo;
import ghidra.features.bsim.query.protocol.QueryInfo;
import ghidra.features.bsim.query.protocol.QueryName;
import ghidra.features.bsim.query.protocol.QueryNearestVector;
import ghidra.features.bsim.query.protocol.QueryVectorId;
import ghidra.features.bsim.query.protocol.QueryVectorMatch;
import ghidra.features.bsim.query.protocol.ResponseExe;
import ghidra.features.bsim.query.protocol.ResponseInfo;
import ghidra.features.bsim.query.protocol.ResponseName;
import ghidra.features.bsim.query.protocol.ResponseNearestVector;
import ghidra.features.bsim.query.protocol.ResponseVectorId;
import ghidra.features.bsim.query.protocol.ResponseVectorMatch;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.TreeSet;

/**
 * Computes similarity scores between executables stored in a BSim {@link FunctionDatabase}.
 * <p>
 * Feature vectors are grouped into clusters by iterative expansion. Starting from a seed vector id,
 * each vector in the cluster is used as a nearest-neighbor query. Every returned vector that has
 * not already been queried is added to the cluster. Clusters that pass the scorer's preliminary
 * threshold check are resolved back to functions and handed to an {@link ExecutableScorer}.
 * <p>
 * Two modes exist:
 * <ul>
 * <li>every executable added via {@link #addExecutable(String)} or {@link #addAllExecutables(int)}
 * is scored (first constructor), or</li>
 * <li>a single executable, identified by MD5, is scored with an {@link ExecutableScorerSingle}
 * backed by a {@link ScoreCaching} (second constructor).</li>
 * </ul>
 */
public class ExecutableComparison {

    /**
     * Maximum number of functions requested per executable when collecting vector ids.
     * NOTE(review): presumably functions beyond this count are omitted by the query; confirm
     * against QueryName.
     */
    private static final int MAX_FUNCTIONS_PER_EXE = 100000;

    /** Placeholder MD5 for the dummy executable that wraps a nearest-neighbor query vector. */
    private static final String QUERY_PLACEHOLDER_MD5 = "bbbbaaaabbbbaaaabbbbaaaabbbbaaaa";

    /** Connection used for every query issued by this object. */
    private final FunctionDatabase database;
    /** Accumulates the scores; an {@link ExecutableScorerSingle} in single-executable mode. */
    private final ExecutableScorer scorer;
    /** MD5 of the executable being scored in single-executable mode, otherwise {@code null}. */
    private final String singleMd5;
    /** Vector ids not yet placed in any cluster; non-null only while scoring is in progress. */
    private TreeSet<Long> baseIds;
    /** Vector ids already used as a nearest-neighbor query; non-null only while scoring. */
    private TreeSet<Long> queriedIds;
    /** Vector factory obtained from the database connection. */
    private LSHVectorFactory vectorFactory;
    /** Hit-count threshold forwarded to the scorer's cluster checks. */
    private final int hitCountThreshold;
    /** Largest cluster hit count seen during the last {@link #performScoring()}. */
    private int maxHitCount;
    /** Number of clusters rejected during the last {@link #performScoring()}. */
    private int exceedCount;
    /** Monitor for progress and cancellation; never {@code null}. */
    private final TaskMonitor monitor;

    /** Mutable counter tallying how many functions of an executable share a vector id. */
    public static class Count {
        public int value = 0;
    }

    /**
     * Creates a comparison that scores all executables added to it.
     *
     * @param functionDatabase the BSim database to query; it is initialized here
     * @param hitCountThreshold cluster hit-count threshold forwarded to the scorer
     * @param taskMonitor monitor for progress and cancellation, or {@code null} for none
     * @throws LSHException if the database cannot be initialized or its settings fetched
     */
    public ExecutableComparison(FunctionDatabase functionDatabase, int hitCountThreshold,
            TaskMonitor taskMonitor) throws LSHException {
        this.database = functionDatabase;
        this.singleMd5 = null;
        this.hitCountThreshold = hitCountThreshold;
        this.monitor = (taskMonitor != null) ? taskMonitor : TaskMonitor.DUMMY;
        this.baseIds = null;
        this.queriedIds = null;
        DatabaseInformation info = pullConnectionInfo();
        this.scorer = new ExecutableScorer();
        this.scorer.transferSettings(info);
    }

    /**
     * Creates a comparison that scores a single executable, caching scores via {@code scoreCaching}.
     *
     * @param functionDatabase the BSim database to query; it is initialized here
     * @param hitCountThreshold cluster hit-count threshold forwarded to the scorer
     * @param exeMd5 MD5 of the executable to score; it is looked up and added immediately
     * @param scoreCaching score cache handed to the {@link ExecutableScorerSingle}
     * @param taskMonitor monitor for progress and cancellation, or {@code null} for none
     * @throws LSHException if the database cannot be initialized, or the executable is not found
     */
    public ExecutableComparison(FunctionDatabase functionDatabase, int hitCountThreshold,
            String exeMd5, ScoreCaching scoreCaching, TaskMonitor taskMonitor)
            throws LSHException {
        this.database = functionDatabase;
        this.singleMd5 = exeMd5;
        this.hitCountThreshold = hitCountThreshold;
        this.monitor = (taskMonitor != null) ? taskMonitor : TaskMonitor.DUMMY;
        this.baseIds = null;
        this.queriedIds = null;
        DatabaseInformation info = pullConnectionInfo();
        this.scorer = new ExecutableScorerSingle(scoreCaching);
        this.scorer.transferSettings(info);
        addExecutable(this.singleMd5);
    }

    /** @return the largest cluster hit count seen during the last {@link #performScoring()} */
    public int getMaxHitCount() {
        return maxHitCount;
    }

    /** @return the number of clusters rejected during the last {@link #performScoring()} */
    public int getExceedCount() {
        return exceedCount;
    }

    /** @return true if the scorer's similarity threshold is strictly positive */
    public boolean isConfigured() {
        return scorer.simThreshold > 0.0d;
    }

    /**
     * Initializes the database connection, caches its vector factory, and fetches its settings.
     *
     * @return the database information reported by the server
     * @throws LSHException if the connection fails or the info query returns no response
     */
    private DatabaseInformation pullConnectionInfo() throws LSHException {
        if (!database.initialize()) {
            throw new LSHException("Unable to connect to server");
        }
        ResponseInfo response = new QueryInfo().execute(database);
        if (response == null) {
            throw new LSHException(database.getLastError().message);
        }
        vectorFactory = database.getLSHVectorFactory();
        return response.info;
    }

    /** Builds an executable specifier that selects an executable by MD5. */
    private static ExeSpecifier specifierFor(String md5) {
        ExeSpecifier spec = new ExeSpecifier();
        spec.exemd5 = md5;
        return spec;
    }

    /**
     * Looks up the record for a single executable without fetching its functions' details.
     *
     * @param md5 MD5 of the executable
     * @return the matching executable record
     * @throws LSHException if the query fails or does not yield exactly one executable
     */
    private ExecutableRecord lookupExecutable(String md5) throws LSHException {
        QueryName query = new QueryName();
        query.spec = specifierFor(md5);
        query.maxfunc = 1; // only the executable record is needed
        query.fillinCallgraph = false;
        query.fillinCategories = false;
        query.fillinSigs = false;
        ResponseName response = query.execute(database);
        if (response == null) {
            throw new LSHException(database.getLastError().message);
        }
        if (response.manage.numExecutables() != 1) {
            throw new LSHException("Could not find executable");
        }
        return response.manage.getExecutableRecordSet().first();
    }

    /**
     * Collects the vector ids of every function in one executable.
     *
     * @param spec selects the executable
     * @param counts if {@code null}, ids are added to {@link #baseIds}; otherwise each id is
     *            tallied in this map (number of functions sharing the vector)
     * @throws LSHException if the query returns no response
     */
    private void pullVectorsForExe(ExeSpecifier spec, Map<Long, Count> counts)
            throws LSHException {
        QueryName query = new QueryName();
        query.spec = spec;
        query.maxfunc = MAX_FUNCTIONS_PER_EXE;
        query.fillinCallgraph = false;
        query.fillinCategories = false;
        query.fillinSigs = false;
        ResponseName response = query.execute(database);
        if (response == null) {
            throw new LSHException(database.getLastError().message);
        }
        Iterator<FunctionDescription> functions = response.manage.listAllFunctions();
        while (functions.hasNext()) {
            Long vectorId = Long.valueOf(functions.next().getVectorId());
            if (counts == null) {
                baseIds.add(vectorId);
            } else {
                counts.computeIfAbsent(vectorId, k -> new Count()).value++;
            }
        }
    }

    /**
     * Resets the working sets and fills {@link #baseIds} with the vector ids that seed clusters.
     * <p>
     * In single-executable mode only that executable's vectors are used. Otherwise the vectors of
     * every executable known to the scorer are used.
     *
     * @throws LSHException if a database query fails
     * @throws CancelledException if the monitor is cancelled
     */
    private void pullVectorsForScoringSet() throws LSHException, CancelledException {
        baseIds = new TreeSet<>();
        queriedIds = new TreeSet<>();
        if (scorer instanceof ExecutableScorerSingle) {
            pullVectorsForExe(specifierFor(singleMd5), null);
            return;
        }
        monitor.setMessage("Accumulating vector ids");
        TreeSet<ExecutableRecord> exeSet = scorer.executableSet.getExecutableRecordSet();
        monitor.initialize(exeSet.size());
        for (ExecutableRecord exeRec : exeSet) {
            pullVectorsForExe(specifierFor(exeRec.getMd5()), null);
            monitor.checkCancelled();
            monitor.incrementProgress(1);
        }
    }

    /**
     * Builds a nearest-neighbor query for one vector.
     *
     * @param vector the vector to search around
     * @param simThreshold similarity threshold for returned neighbors
     * @return the query, carrying the scorer's executable-set settings
     * @throws LSHException if the query's description manager rejects the placeholder records
     */
    private QueryNearestVector buildVectorQuery(LSHVector vector, double simThreshold)
            throws LSHException {
        QueryNearestVector query = new QueryNearestVector();
        // The query API takes function descriptions, so the bare vector is attached to a
        // placeholder executable/function pair.
        ExecutableRecord placeholderExe = query.manage.newExecutableRecord(QUERY_PLACEHOLDER_MD5,
            null, null, null, null, null, null, null);
        FunctionDescription placeholderFunc = query.manage.newFunctionDescription(
            LocalBufferFile.PRESAVE_FILE_PREFIX, 4096L, placeholderExe);
        query.manage.attachSignature(placeholderFunc, query.manage.newSignature(vector, 1));
        query.manage.transferSettings(scorer.executableSet);
        query.thresh = simThreshold;
        return query;
    }

    /**
     * Fetches the full vector for a vector id.
     *
     * @param vectorId the id to look up
     * @return the single matching vector result
     * @throws LSHException if the lookup fails or does not yield exactly one result
     */
    private VectorResult buildSeedVector(Long vectorId) throws LSHException {
        QueryVectorId query = new QueryVectorId();
        query.vectorIds.add(vectorId);
        ResponseVectorId response = query.execute(database);
        if (response == null || response.vectorResults.size() != 1) {
            throw new LSHException("Could not locate vector by id");
        }
        return response.vectorResults.get(0);
    }

    /**
     * Takes the lowest-id entry off the cluster frontier, runs a nearest-neighbor query on it, and
     * adds every neighbor that has not yet been queried back onto the frontier.
     *
     * @param frontier vector ids awaiting a query, mapped to their vector ({@code null} until fetched)
     * @param simThreshold similarity threshold for the nearest-neighbor query
     * @return the vector that was queried
     * @throws LSHException if a database query fails
     */
    private VectorResult queryVectorForCluster(TreeMap<Long, VectorResult> frontier,
            double simThreshold) throws LSHException {
        Map.Entry<Long, VectorResult> next = frontier.pollFirstEntry();
        Long vectorId = next.getKey();
        baseIds.remove(vectorId);
        queriedIds.add(vectorId);
        VectorResult vecResult = next.getValue();
        if (vecResult == null) {
            vecResult = buildSeedVector(vectorId);
        }
        ResponseNearestVector response =
            buildVectorQuery(vecResult.vec, simThreshold).execute(database);
        if (response == null || response.result.size() != 1) {
            throw new LSHException("Could not perform query on vector");
        }
        for (VectorResult neighbor : response.result.get(0)) {
            Long neighborId = Long.valueOf(neighbor.vectorid);
            if (!queriedIds.contains(neighborId)) {
                frontier.put(neighborId, neighbor);
            }
        }
        return vecResult;
    }

    /**
     * Grows one cluster starting from the lowest remaining id in {@link #baseIds}.
     * <p>
     * The caller must ensure {@link #baseIds} is non-empty.
     *
     * @param cluster receives every clustered vector whose self-significance exceeds
     *            {@code sigThreshold}
     * @param simThreshold similarity threshold for the nearest-neighbor queries
     * @param sigThreshold vectors at or below this self-significance are dropped
     * @return the sum of {@code hitcount} over the vectors added to {@code cluster}
     * @throws LSHException if a database query fails
     */
    private int buildCluster(List<VectorResult> cluster, double simThreshold, double sigThreshold)
            throws LSHException {
        TreeMap<Long, VectorResult> frontier = new TreeMap<>();
        frontier.put(baseIds.pollFirst(), null);
        int hitCount = 0;
        while (!frontier.isEmpty()) {
            VectorResult vecResult = queryVectorForCluster(frontier, simThreshold);
            if (sigThreshold < vectorFactory.getSelfSignificance(vecResult.vec)) {
                cluster.add(vecResult);
                hitCount += vecResult.hitcount;
            }
        }
        return hitCount;
    }

    /**
     * Resolves each vector in a cluster to the functions carrying it.
     * <p>
     * The results are labeled and filtered by the scorer.
     *
     * @param cluster the clustered vectors
     * @return one description manager per vector, in the same order as {@code cluster}
     * @throws LSHException if a query returns no response
     */
    private List<DescriptionManager> vectorToFunctions(List<VectorResult> cluster)
            throws LSHException {
        List<DescriptionManager> result = new ArrayList<>(cluster.size());
        for (VectorResult vecResult : cluster) {
            QueryVectorMatch query = new QueryVectorMatch();
            query.fillinCategories = false;
            // Allow a little headroom above the recorded hit count
            query.max = vecResult.hitcount + 10;
            query.vectorIds.add(Long.valueOf(vecResult.vectorid));
            ResponseVectorMatch response = query.execute(database);
            if (response == null) {
                throw new LSHException(database.getLastError().message);
            }
            scorer.labelAndFilter(response.manage);
            result.add(response.manage);
        }
        return result;
    }

    /** @return the scorer accumulating results for this comparison */
    public ExecutableScorer getScorer() {
        return scorer;
    }

    /**
     * Adds one executable, looked up by MD5, to the scoring set.
     *
     * @param md5 MD5 of the executable
     * @throws LSHException if the executable cannot be found
     */
    public void addExecutable(String md5) throws LSHException {
        scorer.addExecutable(lookupExecutable(md5));
    }

    /**
     * Adds executables from the database, ordered by MD5, to the scoring set.
     *
     * @param limit passed as the first argument of {@link QueryExeInfo}; presumably the maximum
     *            number of executables to fetch (confirm)
     * @throws LSHException if the query returns no response
     */
    public void addAllExecutables(int limit) throws LSHException {
        QueryExeInfo query = new QueryExeInfo(limit, null, null, null, null,
            ExeTable.ExeTableOrderColumn.MD5, false);
        ResponseExe response = query.execute(database);
        if (response == null) {
            throw new LSHException(database.getLastError().message);
        }
        for (ExecutableRecord exeRec : response.records) {
            scorer.addExecutable(exeRec);
        }
    }

    /**
     * Clusters every seed vector and feeds each cluster to the scorer.
     * <p>
     * {@link #getMaxHitCount()} and {@link #getExceedCount()} reflect this run afterward. The
     * internal working sets are always released, even if scoring fails or is cancelled.
     *
     * @throws LSHException if no thresholds are set or a database query fails
     * @throws CancelledException if the monitor is cancelled
     */
    public void performScoring() throws LSHException, CancelledException {
        maxHitCount = 0;
        exceedCount = 0;
        scorer.populateExecutableIndex();
        if (singleMd5 != null) {
            scorer.setSingleExecutable(singleMd5);
        }
        try {
            pullVectorsForScoringSet();
            scorer.initializeScores();
            monitor.setMessage("Processing similar functions");
            int totalIds = baseIds.size();
            monitor.initialize(totalIds);
            if (scorer.simThreshold < 0.0d) {
                throw new LSHException("No thresholds have been established");
            }
            while (!baseIds.isEmpty()) {
                // Checked every iteration, so long runs of rejected clusters stay cancellable
                monitor.checkCancelled();
                scoreNextCluster();
                monitor.setProgress(totalIds - baseIds.size());
            }
        } finally {
            // These sets can be large; they are rebuilt on the next call
            baseIds = null;
            queriedIds = null;
        }
    }

    /**
     * Builds the next cluster and, if it is non-empty, runs it through the scorer's checks.
     * <p>
     * This updates {@link #maxHitCount} and {@link #exceedCount}.
     *
     * @throws LSHException if a database query fails
     * @throws CancelledException declared conservatively in case the scorer is cancellable
     */
    private void scoreNextCluster() throws LSHException, CancelledException {
        List<VectorResult> cluster = new ArrayList<>();
        int hitCount = buildCluster(cluster, scorer.simThreshold, scorer.sigThreshold);
        if (hitCount == 0) {
            return;
        }
        maxHitCount = Math.max(maxHitCount, hitCount);
        if (!scorer.checkPreliminaryPairThreshold(hitCount, hitCountThreshold)) {
            exceedCount++;
            return;
        }
        List<DescriptionManager> functions = vectorToFunctions(cluster);
        if (!scorer.scoreCluster(vectorFactory, functions, cluster, hitCount, hitCountThreshold)) {
            exceedCount++;
        }
    }

    /**
     * Resets the scorer's storage with new thresholds.
     *
     * @param simThreshold new similarity threshold
     * @param sigThreshold new significance threshold
     * @throws LSHException if the scorer cannot reset its storage
     */
    public void resetThresholds(double simThreshold, double sigThreshold) throws LSHException {
        scorer.resetStorage(simThreshold, sigThreshold);
    }

    /**
     * In single-executable mode, computes and commits self-scores for the executables returned by
     * {@code prefetchSelfScores}, presumably those without a cached self-score. The single
     * executable itself is skipped. This does nothing in all-executables mode.
     *
     * @throws LSHException if a database query fails
     * @throws CancelledException if the monitor is cancelled
     */
    public void fillinSelfScores() throws LSHException, CancelledException {
        if (!(scorer instanceof ExecutableScorerSingle)) {
            return;
        }
        ExecutableScorerSingle singleScorer = (ExecutableScorerSingle) scorer;
        ArrayList<ExecutableRecord> missing = new ArrayList<>();
        singleScorer.prefetchSelfScores(missing);
        int size = missing.size();
        if (size == 0) {
            return;
        }
        if (size == 1 && missing.get(0).getMd5().equals(singleMd5)) {
            return;
        }
        double sigThreshold = singleScorer.getSigThreshold();
        monitor.setMessage("Generating self-significance scores");
        monitor.initialize(size);
        for (ExecutableRecord exeRec : missing) {
            monitor.checkCancelled();
            if (!exeRec.getMd5().equals(singleMd5)) {
                double selfScore = computeSelfScore(exeRec, sigThreshold);
                singleScorer.commitSelfScore(exeRec.getMd5(), (float) selfScore);
            }
            // Count skipped records too, so progress reaches the initialized total
            monitor.incrementProgress(1);
        }
    }

    /**
     * Sums the self-significance of an executable's vectors, weighted by how many of its functions
     * share each vector. Only vectors at or above {@code sigThreshold} are counted.
     * <p>
     * NOTE(review): {@link #buildCluster} keeps only vectors strictly above the significance
     * threshold, while this uses {@code >=}. Confirm which boundary is intended.
     *
     * @param exeRec the executable to score
     * @param sigThreshold minimum self-significance for a vector to contribute
     * @return the accumulated self-score
     * @throws LSHException if a database query fails
     */
    private double computeSelfScore(ExecutableRecord exeRec, double sigThreshold)
            throws LSHException {
        TreeMap<Long, Count> vectorCounts = new TreeMap<>();
        pullVectorsForExe(specifierFor(exeRec.getMd5()), vectorCounts);
        double score = 0.0d;
        for (Map.Entry<Long, Count> entry : vectorCounts.entrySet()) {
            double selfSig = vectorFactory.getSelfSignificance(buildSeedVector(entry.getKey()).vec);
            if (selfSig >= sigThreshold) {
                score += selfSig * entry.getValue().value;
            }
        }
        return score;
    }
}
