package ghidra.feature.vt.api;

import generic.concurrent.ConcurrentQ;
import generic.concurrent.ConcurrentQBuilder;
import generic.concurrent.GThreadPool;
import generic.concurrent.QCallback;
import generic.concurrent.QResult;
import generic.lsh.LSHMemoryModel;
import generic.lsh.vector.LSHVectorFactory;
import generic.lsh.vector.VectorCompare;
import ghidra.feature.vt.api.NeighborGenerator;
import ghidra.feature.vt.api.main.VTAssociation;
import ghidra.feature.vt.api.main.VTAssociationManager;
import ghidra.feature.vt.api.main.VTAssociationStatus;
import ghidra.feature.vt.api.main.VTAssociationType;
import ghidra.feature.vt.api.main.VTMatchSet;
import ghidra.program.model.address.Address;
import ghidra.program.model.listing.Function;
import ghidra.program.model.listing.Program;
import ghidra.util.Msg;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import org.apache.commons.collections4.MultiValuedMap;
import org.apache.commons.collections4.multimap.HashSetValuedHashMap;

/* loaded from: input_file:ghidra/feature/vt/api/BSimProgramCorrelatorMatching.class */
public class BSimProgramCorrelatorMatching {
    /**
     * Candidate pairs implied by already-accepted matches. doMatching repeatedly takes
     * {@code last()} from this set, so the ordering (PotentialPair's natural ordering,
     * presumably ascending score -- confirm against PotentialPair) decides which
     * implication is followed next.
     */
    private SortedSet<PotentialPair> implications = new TreeSet();
    /** Function nodes of the source program. */
    private FunctionNodeContainer sourceNodes;
    /** Function nodes of the destination program. */
    private FunctionNodeContainer destNodes;
    /** Used to turn a vector comparison into a significance (confidence) value. */
    private LSHVectorFactory vectorFactory;
    /** Matches accepted so far; (re)created by doMatching and appended to by acceptMatch. */
    private LinkedList<FunctionPair> matches;
    /** Initial matches built by generateSeeds; consumed and set to null by doMatching. */
    private Set<FunctionPair> seeds;
    /** Pairs found by discoverPotentialMatches; consumed and set to null by chooseSeeds. */
    private List<FunctionPair> discoveredMatches;
    /** Minimum confidence for a seed; chooseSeeds lowers it when no candidate reaches it. */
    private double confThreshold;
    /** Minimum implication score for the gathering loop to continue; also given to every neighbor generator. */
    private double impThreshold;
    /** Minimum similarity for a source/destination pair to be recorded as a potential match. */
    private double potentialSimThreshold;
    /** Memory model handed to the BinningSystem used for the similarity search. */
    private LSHMemoryModel memoryModel;
    /** When true, a NamespaceNeighborhood generator is added to each matching round. */
    private boolean useNamespaceNeighbors;
    /** Orders pairs by descending confidence (arguments are swapped in the comparison). */
    private static final Comparator<FunctionPair> CONF_COMPARATOR = new Comparator<FunctionPair>() { // from class: ghidra.feature.vt.api.BSimProgramCorrelatorMatching.1
        @Override // java.util.Comparator
        public int compare(FunctionPair functionPair, FunctionPair functionPair2) {
            return Double.compare(functionPair2.getConfResult(), functionPair.getConfResult());
        }
    };

    /* loaded from: input_file:ghidra/feature/vt/api/BSimProgramCorrelatorMatching$MatchingCallback.class */
    /**
     * Queue callback that, for one node, collects a pair with every node returned by the
     * source binning lookup whose vector similarity reaches the configured threshold.
     */
    private class MatchingCallback implements QCallback<FunctionNode, List<FunctionPair>> {
        private BinningSystem sourceBinning;
        private double simThreshold;

        /**
         * @param binningSystem binning system that was populated with the source nodes
         * @param d minimum similarity for a pair to be reported
         */
        MatchingCallback(BinningSystem binningSystem, double d) {
            this.simThreshold = d;
            this.sourceBinning = binningSystem;
        }

        /**
         * Produces the similar-node pairs for one queued node.
         *
         * @return the pairs found (possibly empty), or {@code null} when the node or its
         *         vector is missing
         * @throws Exception if the monitor reports cancellation
         */
        @Override
        public List<FunctionPair> process(FunctionNode functionNode, TaskMonitor taskMonitor) throws Exception {
            taskMonitor.checkCancelled();
            List<FunctionPair> pairs = null;
            if (functionNode != null && functionNode.getVector() != null) {
                pairs = new LinkedList<>();
                findSimilarNodes(pairs, functionNode, taskMonitor);
            }
            // Progress advances once per node whether or not it could be compared.
            taskMonitor.incrementProgress(1L);
            return pairs;
        }

        /**
         * Appends to {@code list} one pair for each looked-up node that is similar enough to
         * {@code functionNode}. The looked-up node is passed as the first FunctionPair argument.
         */
        private void findSimilarNodes(List<FunctionPair> list, FunctionNode functionNode, TaskMonitor taskMonitor) throws CancelledException {
            Set<FunctionNode> binMates = this.sourceBinning.lookup(functionNode);
            // Scratch object filled in by compare() and read back by calculateSignificance().
            VectorCompare comparison = new VectorCompare();
            for (FunctionNode candidate : binMates) {
                taskMonitor.checkCancelled();
                double similarity = candidate.getVector().compare(functionNode.getVector(), comparison);
                if (similarity >= this.simThreshold) {
                    double significance = BSimProgramCorrelatorMatching.this.vectorFactory.calculateSignificance(comparison);
                    list.add(new FunctionPair(candidate, functionNode, similarity, significance));
                }
            }
        }
    }

    /**
     * Creates a matcher over the given source and destination function nodes.
     *
     * @param functionNodeContainer nodes of the source program
     * @param functionNodeContainer2 nodes of the destination program
     * @param lSHVectorFactory factory used to compute match significance
     * @param d seed confidence threshold
     * @param d2 implication score threshold
     * @param d3 similarity threshold for recording potential matches
     * @param z whether namespace neighborhoods are used as an extra neighbor generator
     * @param lSHMemoryModel memory model for the binning system
     */
    public BSimProgramCorrelatorMatching(FunctionNodeContainer functionNodeContainer, FunctionNodeContainer functionNodeContainer2, LSHVectorFactory lSHVectorFactory, double d, double d2, double d3, boolean z, LSHMemoryModel lSHMemoryModel) {
        // Inputs being matched.
        this.sourceNodes = functionNodeContainer;
        this.destNodes = functionNodeContainer2;

        // Scoring infrastructure.
        this.vectorFactory = lSHVectorFactory;
        this.memoryModel = lSHMemoryModel;

        // Tuning knobs.
        this.confThreshold = d;
        this.impThreshold = d2;
        this.potentialSimThreshold = d3;
        this.useNamespaceNeighbors = z;
    }

    /**
     * Records {@code functionPair} as an accepted match: flags both nodes as matched, appends
     * the pair to {@code matches}, and removes both nodes from every associate's association
     * table before clearing their own.
     */
    private void acceptMatch(FunctionPair functionPair) {
        FunctionNode src = functionPair.getSourceNode();
        FunctionNode dst = functionPair.getDestNode();
        src.setAcceptedMatch(true);
        dst.setAcceptedMatch(true);
        this.matches.add(functionPair);

        // Tell every former candidate partner to forget the now-matched nodes.
        for (Iterator<Map.Entry<FunctionNode, FunctionPair>> it = src.getAssociateIterator(); it.hasNext();) {
            FunctionNode partner = it.next().getKey();
            partner.removeAssociate(src);
        }
        for (Iterator<Map.Entry<FunctionNode, FunctionPair>> it = dst.getAssociateIterator(); it.hasNext();) {
            FunctionNode partner = it.next().getKey();
            partner.removeAssociate(dst);
        }

        src.clearAssociates();
        dst.clearAssociates();
    }

    /**
     * Finds every source/destination pair whose similarity reaches the potential-match
     * threshold and records it both on the two nodes (as associates) and in
     * {@code discoveredMatches}.
     *
     * <p>Source nodes are binned first; destination nodes are then processed through a
     * concurrent queue backed by a private thread pool, using {@link MatchingCallback}.
     * The queue is disposed exactly once, whether or not the wait succeeds.
     *
     * @param taskMonitor monitor used for messages, progress and cancellation
     * @throws Exception if the task is cancelled, the wait is interrupted, or a queued
     *         item failed (rethrown when its result is read)
     */
    public void discoverPotentialMatches(TaskMonitor taskMonitor) throws Exception {
        BinningSystem sourceBins = new BinningSystem(this.memoryModel);
        taskMonitor.setMessage("Binning source functions...");
        taskMonitor.initialize(this.sourceNodes.size());
        sourceBins.add(this.sourceNodes.iterator(), taskMonitor);

        taskMonitor.setMessage("Zealously over-pairing matches...");
        taskMonitor.initialize(this.destNodes.size());
        ConcurrentQ<FunctionNode, List<FunctionPair>> queue =
            new ConcurrentQBuilder<FunctionNode, List<FunctionPair>>()
                .setThreadPool(GThreadPool.getPrivateThreadPool("BSimProgramCorrelatorMatching"))
                .setCollectResults(true)
                .setMonitor(taskMonitor)
                .build(new MatchingCallback(sourceBins, this.potentialSimThreshold));

        Collection<QResult<FunctionNode, List<FunctionPair>>> results;
        try {
            queue.add(this.destNodes.iterator());
            results = queue.waitForResults();
        } finally {
            // Always release the queue, including when queuing or waiting fails.
            queue.dispose();
        }

        this.discoveredMatches = new LinkedList<>();
        for (QResult<FunctionNode, List<FunctionPair>> result : results) {
            taskMonitor.checkCancelled();
            List<FunctionPair> pairs = result.getResult();
            if (pairs == null) {
                // The callback returns null for nodes it could not compare.
                continue;
            }
            for (FunctionPair pair : pairs) {
                taskMonitor.checkCancelled();
                if (pair == null) {
                    continue;
                }
                FunctionNode sourceNode = pair.getSourceNode();
                FunctionNode destNode = pair.getDestNode();
                sourceNode.addAssociate(destNode, pair);
                destNode.addAssociate(sourceNode, pair);
                this.discoveredMatches.add(pair);
            }
        }
    }

    /**
     * Binary-searches a list sorted by descending confidence for the last index whose
     * confidence is not below {@code d}.
     *
     * @return that index; 0 when the list is empty or when even the first element is below
     *         {@code d}, so callers must handle those cases themselves
     */
    private static int findIndexMatchingThreshold(ArrayList<FunctionPair> arrayList, double d) {
        int low = 0;
        int high = arrayList.size() - 1;
        while (low < high) {
            // Round the midpoint up so that "low = mid" always makes progress.
            int mid = (low + high + 1) / 2;
            boolean belowThreshold = arrayList.get(mid).getConfResult() < d;
            if (belowThreshold) {
                high = mid - 1;
            } else {
                low = mid;
            }
        }
        return low;
    }

    /**
     * Picks seed matches from {@code discoveredMatches} and adds them to {@code seeds}.
     *
     * <p>A discovered pair becomes a seed candidate when neither of its nodes takes part in
     * any other discovered pair. Candidates are sorted by descending confidence and every
     * candidate whose confidence reaches {@code confThreshold} is added to {@code seeds}.
     * When no candidate reaches the threshold, the threshold is lowered to the best
     * candidate's confidence and a warning is logged.
     *
     * <p>{@code discoveredMatches} is set to {@code null} once it has been indexed.
     *
     * @param taskMonitor monitor used for messages, progress and cancellation
     * @throws CancelledException if the monitor reports cancellation
     */
    private void chooseSeeds(TaskMonitor taskMonitor) throws CancelledException {
        taskMonitor.setMessage("Generating seeds...");
        ArrayList<FunctionPair> candidates = new ArrayList<>();
        Set<FunctionNode> seededSources = new HashSet<>();
        Set<FunctionNode> seededDests = new HashSet<>();
        MultiValuedMap<FunctionNode, FunctionPair> nextBySource = new HashSetValuedHashMap<>();
        MultiValuedMap<FunctionNode, FunctionPair> nextByDest = new HashSetValuedHashMap<>();
        MultiValuedMap<FunctionNode, FunctionPair> bySource = new HashSetValuedHashMap<>();
        MultiValuedMap<FunctionNode, FunctionPair> byDest = new HashSetValuedHashMap<>();
        for (FunctionPair pair : this.discoveredMatches) {
            bySource.put(pair.getSourceNode(), pair);
            byDest.put(pair.getDestNode(), pair);
        }
        this.discoveredMatches = null;

        int previousCount = bySource.size();
        if (previousCount == 0) {
            return;
        }

        // NOTE(review): previousCount starts out equal to the number of pairs examined in the
        // first pass, so "changed" is false after that pass and the stricter follow-up passes
        // over the conflicting pairs never run. The behavior is preserved as found; confirm
        // the intended termination condition against the upstream source.
        boolean changed = true;
        double ratioThreshold = 0.5d;
        while (changed) {
            taskMonitor.checkCancelled();
            Collection<FunctionPair> pending = bySource.values();
            taskMonitor.initialize(pending.size());
            for (FunctionPair pair : pending) {
                taskMonitor.checkCancelled();
                taskMonitor.incrementProgress(1L);
                FunctionNode sourceNode = pair.getSourceNode();
                FunctionNode destNode = pair.getDestNode();
                if (!hasConflicts(pair, bySource, byDest)) {
                    candidates.add(pair);
                    seededSources.add(sourceNode);
                    seededDests.add(destNode);
                } else if (!seededSources.contains(sourceNode) && !seededDests.contains(destNode)) {
                    // Keep a conflicting pair for the next pass only if both the callee counts
                    // and the lengths of the two functions are sufficiently balanced.
                    double fewerChildren = Math.min(sourceNode.getChildren().size(), destNode.getChildren().size());
                    double moreChildren = Math.max(sourceNode.getChildren().size(), destNode.getChildren().size());
                    double childRatio = moreChildren == 0.0d ? 0.0d : fewerChildren / moreChildren;
                    double lenRatio = sourceNode.getLen() / destNode.getLen();
                    if (Math.min(lenRatio, 1.0d / lenRatio) > ratioThreshold && childRatio > ratioThreshold) {
                        nextBySource.put(sourceNode, pair);
                        nextByDest.put(destNode, pair);
                    }
                }
            }
            bySource = nextBySource;
            byDest = nextByDest;
            changed = previousCount != pending.size();
            previousCount = nextBySource.values().size();
            nextBySource = new HashSetValuedHashMap<>();
            nextByDest = new HashSetValuedHashMap<>();
            // Tighten the balance requirement towards 1.0 on each pass.
            ratioThreshold = (2.0d + ratioThreshold) / 3.0d;
        }

        if (candidates.isEmpty()) {
            return;
        }
        Collections.sort(candidates, CONF_COMPARATOR);
        double bestConfidence = candidates.get(0).getConfResult();
        if (bestConfidence < this.confThreshold) {
            // Report the value the threshold is being reset to, not this object's toString().
            Msg.warn(this, "Initial value of seed confidence too high (" + this.confThreshold + ")...resetting seed confidence to " + bestConfidence);
            this.confThreshold = bestConfidence;
        }
        int lastSeedIndex = findIndexMatchingThreshold(candidates, this.confThreshold);
        for (int i = 0; i <= lastSeedIndex; i++) {
            this.seeds.add(candidates.get(i));
        }
    }

    /**
     * Reports whether {@code functionPair} competes with another pair for either of its nodes.
     *
     * @param multiValuedMap pairs indexed by source node
     * @param multiValuedMap2 pairs indexed by destination node
     * @return {@code true} if more than one pair is registered for the pair's source node or
     *         for its destination node
     */
    private static boolean hasConflicts(FunctionPair functionPair, MultiValuedMap<FunctionNode, FunctionPair> multiValuedMap, MultiValuedMap<FunctionNode, FunctionPair> multiValuedMap2) {
        Collection<FunctionPair> sharingSource = multiValuedMap.get(functionPair.getSourceNode());
        if (sharingSource == null || sharingSource.size() <= 1) {
            // Source side is uncontested; the destination side decides.
            Collection<FunctionPair> sharingDest = multiValuedMap2.get(functionPair.getDestNode());
            return sharingDest != null && sharingDest.size() > 1;
        }
        return true;
    }

    /**
     * Builds the seed set used by the first matching round.
     *
     * @param vTMatchSet match set whose session supplies already-accepted associations
     * @param z when {@code true}, accepted function associations are added as seeds before
     *        seeds are chosen from the discovered matches
     * @param taskMonitor monitor used for messages, progress and cancellation
     * @return {@code true} if at least one seed was produced
     * @throws CancelledException if the monitor reports cancellation
     */
    public boolean generateSeeds(VTMatchSet vTMatchSet, boolean z, TaskMonitor taskMonitor) throws CancelledException {
        this.seeds = new HashSet<>();
        if (z) {
            findAcceptedSeeds(vTMatchSet, taskMonitor);
        }
        chooseSeeds(taskMonitor);
        boolean noSeeds = this.seeds.isEmpty();
        return !noSeeds;
    }

    /**
     * Assembles the neighbor generators for a matching round.
     *
     * <p>Round 0 uses children and parents only. Any other round additionally uses
     * grandchildren, siblings, spouses and grandparents. The namespace neighborhood is
     * appended last in either case when namespace neighbors are enabled.
     *
     * @param i zero-based round number
     * @return the generators, in the order they are to be consulted
     */
    private NeighborGenerator[] buildNeighborGenerators(int i) {
        List<NeighborGenerator> generators = new ArrayList<>();

        // Direct call-graph neighbors are used in every round.
        generators.add(new NeighborGenerator.Children(this.vectorFactory, this.impThreshold));
        generators.add(new NeighborGenerator.Parents(this.vectorFactory, this.impThreshold));

        if (i != 0) {
            // Later rounds widen the neighborhood by one more hop.
            generators.add(new NeighborGenerator.GrandChildren(this.vectorFactory, this.impThreshold));
            generators.add(new NeighborGenerator.Siblings(this.vectorFactory, this.impThreshold));
            generators.add(new NeighborGenerator.Spouses(this.vectorFactory, this.impThreshold));
            generators.add(new NeighborGenerator.GrandParents(this.vectorFactory, this.impThreshold));
        }

        if (this.useNamespaceNeighbors) {
            generators.add(new NamespaceNeighborhood(this.vectorFactory, this.impThreshold, this.sourceNodes, this.destNodes));
        }
        return generators.toArray(new NeighborGenerator[generators.size()]);
    }

    /**
     * Runs the two matching rounds followed by hole patching and returns all accepted matches.
     *
     * <p>Round 1 accepts every seed and collects the best implication of each. Round 2 clears
     * the implications and recomputes them from all matches accepted so far, using the wider
     * set of neighbor generators. After each round the implications are followed greedily,
     * best first. Finally, for accepted pairs whose nodes each have exactly one parent, the
     * two parents are paired if they are unmatched and have no discovered edge.
     *
     * <p>{@code generateSeeds} must have been called first: {@code seeds} is read here and is
     * only assigned there. This method sets {@code seeds} to {@code null} when round 1 ends.
     *
     * @param taskMonitor monitor used for messages, progress and cancellation
     * @return the list of accepted matches (the internal list, not a copy)
     * @throws CancelledException if the monitor reports cancellation
     */
    public List<FunctionPair> doMatching(TaskMonitor taskMonitor) throws CancelledException {
        this.matches = new LinkedList<>();
        for (int round = 0; round < 2; round++) {
            taskMonitor.checkCancelled();
            NeighborGenerator[] generators = buildNeighborGenerators(round);
            if (round == 0) {
                taskMonitor.setMessage("Matching round 1...");
                taskMonitor.initialize(this.seeds.size());
                for (FunctionPair seed : this.seeds) {
                    taskMonitor.checkCancelled();
                    taskMonitor.incrementProgress(1L);
                    acceptMatch(seed);
                    queueImplication(seed, generators);
                }
                this.seeds = null;
            } else {
                this.implications.clear();
                taskMonitor.setMessage("Matching round 2...");
                taskMonitor.initialize(this.matches.size());
                for (FunctionPair match : this.matches) {
                    taskMonitor.checkCancelled();
                    taskMonitor.incrementProgress(1L);
                    queueImplication(match, generators);
                }
            }
            taskMonitor.setMessage("Gathering matches for round " + (round + 1) + "...");
            gatherImpliedMatches(generators, taskMonitor);
        }
        patchHoles(taskMonitor);
        return this.matches;
    }

    /** Adds the best implication of {@code origin}, if there is one, to {@code implications}. */
    private void queueImplication(FunctionPair origin, NeighborGenerator[] generators) {
        PotentialPair implied = analyze(origin, generators);
        if (implied != null) {
            this.implications.add(implied);
        }
    }

    /**
     * Follows implications best-first: accepts the implied edge when it exists, then queues
     * new implications from the accepted edge and from the implication's origin. The first
     * implication is taken without a score check; after that, gathering continues only while
     * the best remaining score reaches {@code impThreshold}.
     */
    private void gatherImpliedMatches(NeighborGenerator[] generators, TaskMonitor taskMonitor) throws CancelledException {
        int peak = this.implications.size();
        taskMonitor.initialize(peak + 1);
        do {
            taskMonitor.checkCancelled();
            int remaining = this.implications.size();
            if (remaining > peak) {
                // The work list grew; widen the progress range accordingly.
                peak = remaining;
                taskMonitor.setMaximum(peak + 1);
            }
            taskMonitor.setProgress((peak - remaining) + 1);
            if (remaining == 0) {
                break;
            }
            PotentialPair best = this.implications.last();
            this.implications.remove(best);
            FunctionPair edge = best.getSource().findEdge(best.getDestination());
            if (edge != null) {
                acceptMatch(edge);
                queueImplication(edge, generators);
            }
            queueImplication(best.getOrigin(), generators);
            // Guard against an exhausted set: last() throws NoSuchElementException when empty.
        } while (!this.implications.isEmpty() && this.implications.last().getScore() >= this.impThreshold);
    }

    /**
     * Pairs the sole parents of already-matched functions when those parents are unmatched and
     * share no discovered edge. Parents lacking a feature vector are skipped, because the
     * similarity cannot be computed for them.
     */
    private void patchHoles(TaskMonitor taskMonitor) throws CancelledException {
        // Iterate over a snapshot: acceptMatch appends to this.matches.
        List<FunctionPair> snapshot = new LinkedList<>(this.matches);
        VectorCompare comparison = new VectorCompare();
        taskMonitor.setMessage("Patching holes...");
        taskMonitor.initialize(this.matches.size());
        for (FunctionPair match : snapshot) {
            taskMonitor.checkCancelled();
            taskMonitor.incrementProgress(1L);
            if (match.getSourceNode().getParents().size() != 1 || match.getDestNode().getParents().size() != 1) {
                continue;
            }
            FunctionNode sourceParent = match.getSourceNode().getParents().iterator().next();
            FunctionNode destParent = match.getDestNode().getParents().iterator().next();
            if (sourceParent.findEdge(destParent) != null || sourceParent.isAcceptedMatch() || destParent.isAcceptedMatch()) {
                continue;
            }
            if (sourceParent.getVector() == null || destParent.getVector() == null) {
                continue;
            }
            double similarity = sourceParent.getVector().compare(destParent.getVector(), comparison);
            acceptMatch(new FunctionPair(sourceParent, destParent, similarity, this.vectorFactory.calculateSignificance(comparison)));
        }
    }

    /**
     * Adds to {@code seeds} the discovered edge of every accepted function association in the
     * session that owns {@code vTMatchSet}.
     *
     * @param vTMatchSet match set whose session's association manager is consulted
     * @param taskMonitor monitor used for messages, progress and cancellation
     * @throws CancelledException if the monitor reports cancellation
     */
    private void findAcceptedSeeds(VTMatchSet vTMatchSet, TaskMonitor taskMonitor) throws CancelledException {
        taskMonitor.setMessage("Using accepted matches as seeds...");
        VTAssociationManager manager = vTMatchSet.getSession().getAssociationManager();
        taskMonitor.initialize(manager.getAssociationCount());
        List<VTAssociation> allAssociations = manager.getAssociations();
        Program sourceProgram = this.sourceNodes.getProgram();
        Program destProgram = this.destNodes.getProgram();
        for (VTAssociation association : allAssociations) {
            taskMonitor.checkCancelled();
            boolean acceptedFunction = association.getType().equals(VTAssociationType.FUNCTION)
                && association.getStatus() == VTAssociationStatus.ACCEPTED;
            if (acceptedFunction) {
                FunctionPair edge = findSeedEdge(association, sourceProgram, destProgram);
                if (edge != null) {
                    this.seeds.add(edge);
                }
            }
            taskMonitor.incrementProgress(1L);
        }
    }

    /**
     * Resolves an association to the discovered edge between its two function nodes.
     *
     * @return the edge, or {@code null} when either address has no function, either function
     *         has no node, or the nodes share no edge
     */
    private FunctionPair findSeedEdge(VTAssociation association, Program sourceProgram, Program destProgram) {
        Address sourceAddress = association.getSourceAddress();
        Function sourceFunction = sourceProgram.getListing().getFunctionAt(sourceAddress);
        Address destAddress = association.getDestinationAddress();
        Function destFunction = destProgram.getListing().getFunctionAt(destAddress);
        if (sourceFunction == null || destFunction == null) {
            return null;
        }
        FunctionNode sourceNode = this.sourceNodes.get(sourceAddress);
        if (sourceNode == null) {
            return null;
        }
        FunctionNode destNode = this.destNodes.get(destAddress);
        if (destNode == null) {
            return null;
        }
        return sourceNode.findEdge(destNode);
    }

    /**
     * Finds the strongest implied pairing in the neighborhoods of an accepted pair.
     *
     * <p>Each generator yields a source-side and a destination-side neighbor set. The best
     * neighbor pairing is computed in both directions (the reverse result is swapped back
     * after it is computed), and the highest strictly positive score wins.
     *
     * @param functionPair the pair whose neighborhoods are examined; becomes the winner's origin
     * @param neighborGeneratorArr generators to consult, in order
     * @return the best implied pair with its origin set, or {@code null} if no pairing scored
     *         above zero
     */
    private PotentialPair analyze(FunctionPair functionPair, NeighborGenerator[] neighborGeneratorArr) {
        FunctionNode src = functionPair.getSourceNode();
        FunctionNode dst = functionPair.getDestNode();
        double confidence = functionPair.getConfResult();

        PotentialPair winner = null;
        double winningScore = 0.0d;
        for (NeighborGenerator generator : neighborGeneratorArr) {
            NeighborGenerator.NeighborhoodPair hood = generator.generate(src, dst);

            PotentialPair forward = calculateBestNeighbor(hood.srcNeighbors, hood.destNeighbors, confidence);
            double forwardScore = forward.getScore();
            if (forwardScore > winningScore) {
                winningScore = forwardScore;
                winner = forward;
            }

            PotentialPair backward = calculateBestNeighbor(hood.destNeighbors, hood.srcNeighbors, confidence);
            backward.swap();
            double backwardScore = backward.getScore();
            if (backwardScore > winningScore) {
                winningScore = backwardScore;
                winner = backward;
            }
        }

        if (winner == null) {
            return null;
        }
        winner.setOrigin(functionPair);
        return winner;
    }

    /**
     * Returns the first pair in {@code arrayList[i..i2]} (both inclusive) that shares neither
     * its source node nor its destination node, by identity, with any other pair in the range.
     *
     * @return that pair, or {@code null} if every pair in the range conflicts with another
     */
    private static PotentialPair unconflictedPair(ArrayList<PotentialPair> arrayList, int i, int i2) {
        for (int candidateIdx = i; candidateIdx <= i2; candidateIdx++) {
            PotentialPair candidate = arrayList.get(candidateIdx);
            FunctionNode candidateSource = candidate.getSource();
            FunctionNode candidateDest = candidate.getDestination();

            boolean clashes = false;
            for (int otherIdx = i; otherIdx <= i2 && !clashes; otherIdx++) {
                if (otherIdx == candidateIdx) {
                    continue;
                }
                PotentialPair other = arrayList.get(otherIdx);
                FunctionNode otherSource = other.getSource();
                FunctionNode otherDest = other.getDestination();
                clashes = candidateSource == otherSource || candidateDest == otherDest;
            }

            if (!clashes) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Scales a confidence value by how well two functions agree in length, callee count and
     * caller count.
     *
     * <p>Each factor is a balance {@code min(r, 1/r)} of a ratio {@code r}, so it is 1 when the
     * two quantities are equal and falls towards 0 as they diverge. The result is
     * {@code 0.25 * d * lengthBalance * (1 + childBalance) * (1 + parentBalance)}.
     *
     * <p>The callee and caller ratios are computed in floating point. Dividing the two
     * {@code int} counts directly truncates the ratio to a whole number before the reciprocal
     * is taken, which made the score depend on which node was passed first.
     *
     * @param d the raw confidence
     * @param functionNode the first function
     * @param functionNode2 the second function
     * @return the adjusted confidence
     */
    private static double adjustConfidenceScore(double d, FunctionNode functionNode, FunctionNode functionNode2) {
        double childBalance = countBalance(functionNode.getChildren().size(), functionNode2.getChildren().size());
        double parentBalance = countBalance(functionNode.getParents().size(), functionNode2.getParents().size());
        // NOTE(review): assumes getLen() returns a floating-point value; if it is integral this
        // ratio truncates in the same way the counts used to -- confirm against FunctionNode.
        double lenRatio = functionNode.getLen() / functionNode2.getLen();
        return 0.25d * d * Math.min(lenRatio, 1.0d / lenRatio) * (1.0d + childBalance) * (1.0d + parentBalance);
    }

    /**
     * Returns {@code min(r, 1/r)} for {@code r = count / otherCount}, or 0 when either count
     * is 0.
     */
    private static double countBalance(int count, int otherCount) {
        double ratio = otherCount == 0 ? 0.0d : (double) count / otherCount;
        // A zero ratio gives 1/ratio == +Infinity, so the minimum stays 0.
        return Math.min(ratio, 1.0d / ratio);
    }

    /**
     * Sorts {@code arrayList} in place and, working down from the end of the sorted list one
     * group of equal-scoring pairs at a time, returns the first pair that does not conflict
     * with another pair in its own group.
     *
     * @return the chosen pair, or {@code PotentialPair.EMPTY_PAIR} if every group is fully
     *         conflicted or the list is empty
     */
    private static PotentialPair findFirstUnconflictedPair(ArrayList<PotentialPair> arrayList) {
        Collections.sort(arrayList);
        int groupEnd = arrayList.size() - 1;
        while (groupEnd >= 0) {
            double groupScore = arrayList.get(groupEnd).getScore();
            // Extend the group backwards over every pair scoring at least as high.
            int groupStart = groupEnd;
            while (groupStart > 0 && arrayList.get(groupStart - 1).getScore() >= groupScore) {
                groupStart--;
            }
            PotentialPair chosen = unconflictedPair(arrayList, groupStart, groupEnd);
            if (chosen != null) {
                return chosen;
            }
            groupEnd = groupStart - 1;
        }
        return PotentialPair.EMPTY_PAIR;
    }

    /**
     * Finds the most strongly implied pairing between two neighbor sets.
     *
     * <p>For every not-yet-matched node in {@code set}, its associates that lie in
     * {@code set2} are scored with {@code adjustConfidenceScore}. The best associate is
     * proposed with score {@code set2.size() * (bestRawConfidence + d) * bestAdjusted /
     * sumOfAdjusted}. If exactly one proposal holds the top score it is returned; ties are
     * resolved by {@code findFirstUnconflictedPair}.
     *
     * @param set nodes to find partners for
     * @param set2 nodes that are allowed as partners
     * @param d confidence of the accepted pair whose neighborhoods these are
     * @return the best proposal, or {@code PotentialPair.EMPTY_PAIR} when there is none with
     *         a non-zero score
     */
    private PotentialPair calculateBestNeighbor(Set<FunctionNode> set, Set<FunctionNode> set2, double d) {
        ArrayList<PotentialPair> proposals = new ArrayList<>();
        PotentialPair leader = PotentialPair.EMPTY_PAIR;
        int leaderCount = 0;
        for (FunctionNode node : set) {
            if (node.isAcceptedMatch()) {
                continue;
            }
            double bestAdjusted = 0.0d;
            double adjustedSum = 0.0d;
            double bestRawConfidence = 0.0d;
            FunctionNode bestPartner = null;
            Iterator<Map.Entry<FunctionNode, FunctionPair>> associates = node.getAssociateIterator();
            while (associates.hasNext()) {
                Map.Entry<FunctionNode, FunctionPair> entry = associates.next();
                FunctionNode partner = entry.getKey();
                double rawConfidence = entry.getValue().getConfResult();
                if (!set2.contains(partner)) {
                    continue;
                }
                double adjusted = adjustConfidenceScore(rawConfidence, node, partner);
                adjustedSum += adjusted;
                // ">=" means a later associate wins when adjusted scores tie.
                if (adjusted >= bestAdjusted) {
                    bestAdjusted = adjusted;
                    bestPartner = partner;
                    bestRawConfidence = rawConfidence;
                }
            }
            if (adjustedSum > 0.0d) {
                double score = ((set2.size() * (bestRawConfidence + d)) * bestAdjusted) / adjustedSum;
                PotentialPair proposal = new PotentialPair(node, bestPartner, score);
                proposals.add(proposal);
                if (score > leader.getScore()) {
                    leader = proposal;
                    leaderCount = 1;
                } else if (score == leader.getScore()) {
                    leaderCount++;
                }
            }
        }

        if (leaderCount == 0 || leader.getScore() == 0.0d) {
            return PotentialPair.EMPTY_PAIR;
        }
        if (leaderCount == 1) {
            return leader;
        }
        return findFirstUnconflictedPair(proposals);
    }
}
