package ghidra.feature.vt.api.correlator.program;

import generic.DominantPair;
import generic.lsh.vector.LSHCosineVectorAccum;
import generic.lsh.vector.VectorCompare;
import ghidra.feature.vt.api.main.VTAssociation;
import ghidra.feature.vt.api.main.VTAssociationStatus;
import ghidra.feature.vt.api.main.VTAssociationType;
import ghidra.feature.vt.api.main.VTMatch;
import ghidra.feature.vt.api.main.VTMatchInfo;
import ghidra.feature.vt.api.main.VTMatchSet;
import ghidra.feature.vt.api.main.VTScore;
import ghidra.feature.vt.api.main.VTSession;
import ghidra.feature.vt.api.util.VTAbstractProgramCorrelator;
import ghidra.framework.options.ToolOptions;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressSetView;
import ghidra.program.model.listing.CodeUnit;
import ghidra.program.model.listing.CodeUnitIterator;
import ghidra.program.model.listing.Data;
import ghidra.program.model.listing.Function;
import ghidra.program.model.listing.FunctionManager;
import ghidra.program.model.listing.Instruction;
import ghidra.program.model.listing.Listing;
import ghidra.program.model.listing.Program;
import ghidra.program.model.symbol.Reference;
import ghidra.program.model.symbol.ReferenceIterator;
import ghidra.program.model.symbol.ReferenceManager;
import ghidra.util.datastruct.Counter;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.collections4.map.LazyMap;

/* loaded from: input_file:ghidra/feature/vt/api/correlator/program/VTAbstractReferenceProgramCorrelator.class */
/**
 * Base class for correlators that pair functions by the previously ACCEPTED matches they
 * reference.
 * <p>
 * Every ACCEPTED association whose type passes {@link #isExpectedRefType(VTAssociationType)}
 * (taken from the newest match set of each other correlator in the session) becomes one shared
 * feature. Each function that references the matched item in the source program, and each
 * function that references it in the destination program, receives that feature with weight
 * {@code sqrt(-ln(referencingFunctions / totalFunctions))}. After that, each vector is padded
 * with features unique to its function, one per counted reference beyond the vector's current
 * entry count. Finally, every destination vector is compared with every source vector via
 * {@link LSHCosineVectorAccum#compare}, and the surviving candidates are added as matches.
 */
public abstract class VTAbstractReferenceProgramCorrelator extends VTAbstractProgramCorrelator {
    /** Maximum recursion depth when following thunks and data references. */
    private static final int MAX_DEPTH = 30;
    /** Maximum number of candidates kept per destination function when refining. */
    private static final int TOP_N = 5;
    /** Candidates scoring more than this below the best candidate are dropped when refining. */
    private static final double DIFFERENTIAL = 0.2d;
    /** Tolerance used to treat two similarity scores as tied. */
    private static final double EQUALS_EPSILON = 1.0E-5d;
    /** Orders matches by descending similarity score. */
    private static final Comparator<VTMatchInfo> SCORE_COMPARATOR =
        (a, b) -> b.getSimilarityScore().compareTo(a.getSimilarityScore());

    private final String correlatorName;
    private final Program sourceProgram;
    private final Program destinationProgram;
    private final Listing sourceListing;
    private final Listing destinationListing;
    // Populated by extractReferenceFeatures; keyed by function entry point.
    private Map<Address, LSHCosineVectorAccum> srcVectorsByAddress;
    private Map<Address, LSHCosineVectorAccum> destVectorsByAddress;

    /**
     * Creates the correlator.
     *
     * @param sourceProgram the source program
     * @param sourceAddressSet the source addresses handed to the base correlator
     * @param destinationProgram the destination program
     * @param destinationAddressSet the destination addresses handed to the base correlator
     * @param correlatorName the name returned by {@link #getName()}; match sets produced by a
     *        correlator with this name are ignored when collecting accepted matches
     * @param options the correlator options
     */
    public VTAbstractReferenceProgramCorrelator(Program sourceProgram,
            AddressSetView sourceAddressSet, Program destinationProgram,
            AddressSetView destinationAddressSet, String correlatorName, ToolOptions options) {
        super(sourceProgram, sourceAddressSet, destinationProgram, destinationAddressSet, options);
        this.correlatorName = correlatorName;
        this.sourceProgram = sourceProgram;
        this.destinationProgram = destinationProgram;
        this.sourceListing = sourceProgram.getListing();
        this.destinationListing = destinationProgram.getListing();
    }

    @Override
    public String getName() {
        return correlatorName;
    }

    /**
     * Builds the reference feature vectors and then scores destination functions against them.
     */
    @Override
    protected void doCorrelate(VTMatchSet matchSet, TaskMonitor monitor)
            throws CancelledException {
        monitor.setMessage("Finding reference features");
        extractReferenceFeatures(matchSet, monitor);
        monitor.setMessage("Finding destination functions");
        findDestinations(matchSet, monitor);
    }

    /**
     * Compares every destination vector with every source vector and adds the resulting
     * matches to {@code matchSet}.
     */
    private void findDestinations(VTMatchSet matchSet, TaskMonitor monitor)
            throws CancelledException {
        monitor.initialize(destVectorsByAddress.size());
        for (Map.Entry<Address, LSHCosineVectorAccum> destEntry : destVectorsByAddress
                .entrySet()) {
            monitor.checkCancelled();
            monitor.incrementProgress(1L);
            Function destFunction = destinationListing.getFunctionAt(destEntry.getKey());
            if (destFunction == null) {
                continue; // defensive: keys are collected from function entry points
            }
            LSHCosineVectorAccum destVector = destEntry.getValue();

            Map<Address, DominantPair<Double, VectorCompare>> candidates = new HashMap<>();
            for (Map.Entry<Address, LSHCosineVectorAccum> srcEntry : srcVectorsByAddress
                    .entrySet()) {
                VectorCompare comparison = new VectorCompare();
                // Compare once; 'comparison' is handed to compare() and read again in transform().
                double similarity = destVector.compare(srcEntry.getValue(), comparison);
                if (similarity > 0.0d) {
                    candidates.put(srcEntry.getKey(), new DominantPair<>(similarity, comparison));
                }
            }

            for (VTMatchInfo match : transform(matchSet, destFunction, destVector, candidates,
                monitor)) {
                matchSet.addMatch(match);
            }
        }
    }

    /**
     * Turns the scored source candidates for one destination function into match infos.
     * <p>
     * A candidate is kept when its similarity is at least the SIMILARITY_THRESHOLD option and
     * the comparison's dot product is at least the CONFIDENCE_THRESHOLD option (NaN fails both).
     * When the REFINE_RESULTS option is set, the list is passed through {@link #refine(List)}.
     *
     * @return the (possibly refined) list of matches; never null and never containing nulls
     */
    private List<VTMatchInfo> transform(VTMatchSet matchSet, Function destFunction,
            LSHCosineVectorAccum destVector,
            Map<Address, DominantPair<Double, VectorCompare>> candidates, TaskMonitor monitor)
            throws CancelledException {
        boolean refineResults = getOptions()
                .getBoolean(VTAbstractReferenceProgramCorrelatorFactory.REFINE_RESULTS, true);
        double confidenceThreshold = getOptions()
                .getDouble(VTAbstractReferenceProgramCorrelatorFactory.CONFIDENCE_THRESHOLD, 1.0d);
        double similarityThreshold = getOptions()
                .getDouble(VTAbstractReferenceProgramCorrelatorFactory.SIMILARITY_THRESHOLD, 0.5d);

        Address destAddress = destFunction.getEntryPoint();
        int destLength = (int) destFunction.getBody().getNumAddresses();
        List<VTMatchInfo> results = new ArrayList<>();
        for (Map.Entry<Address, DominantPair<Double, VectorCompare>> entry : candidates
                .entrySet()) {
            monitor.checkCancelled();
            DominantPair<Double, VectorCompare> candidate = entry.getValue();
            double similarity = candidate.first;
            if (similarity < similarityThreshold || Double.isNaN(similarity)) {
                continue;
            }
            VectorCompare comparison = candidate.second;
            comparison.fillOut();
            double dotProduct = comparison.dotproduct;
            if (dotProduct < confidenceThreshold || Double.isNaN(dotProduct)) {
                continue;
            }
            Function srcFunction = sourceListing.getFunctionAt(entry.getKey());
            if (srcFunction == null) {
                continue; // defensive: keys are collected from function entry points
            }

            VTMatchInfo match = new VTMatchInfo(matchSet);
            match.setSimilarityScore(new VTScore(similarity));
            // Confidence is reported as ten times the dot product.
            match.setConfidenceScore(new VTScore(dotProduct * 10.0d));
            match.setSourceLength((int) srcFunction.getBody().getNumAddresses());
            match.setDestinationLength(destLength);
            match.setSourceAddress(srcFunction.getEntryPoint());
            match.setDestinationAddress(destAddress);
            match.setTag(null);
            match.setAssociationType(VTAssociationType.FUNCTION);
            results.add(match);
        }
        return refineResults ? refine(results) : results;
    }

    /**
     * Keeps only the strongest, unambiguous candidates.
     * <p>
     * Sorts by descending similarity and looks at the top {@code TOP_N + 1}. The extra entry
     * exposes a tie at position {@code TOP_N}. At the first tie (scores within
     * {@code EQUALS_EPSILON}), both tied entries and everything after them are dropped. The
     * result is then capped at {@code TOP_N}, and entries scoring more than
     * {@code DIFFERENTIAL} below the best are removed.
     * <p>
     * Sorts {@code matches} in place; the returned list is a view of it.
     */
    private List<VTMatchInfo> refine(List<VTMatchInfo> matches) {
        matches.sort(SCORE_COMPARATOR);
        List<VTMatchInfo> refined = matches.subList(0, Math.min(TOP_N + 1, matches.size()));
        refined = truncateAtFirstTie(refined);
        refined = refined.subList(0, Math.min(TOP_N, refined.size()));
        return truncateBelowDifferential(refined);
    }

    /** Cuts a descending-sorted list just before the first pair of tied neighbours. */
    private static List<VTMatchInfo> truncateAtFirstTie(List<VTMatchInfo> sorted) {
        for (int idx = 1; idx < sorted.size(); idx++) {
            double previous = sorted.get(idx - 1).getSimilarityScore().getScore();
            double current = sorted.get(idx).getSimilarityScore().getScore();
            if (current > previous - EQUALS_EPSILON) {
                return sorted.subList(0, idx - 1);
            }
        }
        return sorted;
    }

    /** Cuts a descending-sorted list at the first entry more than DIFFERENTIAL below the best. */
    private static List<VTMatchInfo> truncateBelowDifferential(List<VTMatchInfo> sorted) {
        if (sorted.isEmpty()) {
            return sorted;
        }
        double best = sorted.get(0).getSimilarityScore().getScore();
        for (int idx = 1; idx < sorted.size(); idx++) {
            if (sorted.get(idx).getSimilarityScore().getScore() < best - DIFFERENTIAL) {
                return sorted.subList(0, idx);
            }
        }
        return sorted;
    }

    /**
     * Collects into {@code result} the non-thunk functions whose instructions reference
     * {@code address}.
     * <p>
     * The search recurses (up to {@code MAX_DEPTH} levels) through:
     * <ul>
     * <li>the thunk addresses reported for a function at {@code address};</li>
     * <li>thunk functions that contain a referencing instruction;</li>
     * <li>data that references {@code address}.</li>
     * </ul>
     * Stack and register addresses have no references searched.
     */
    private void accumulateFunctionReferences(int depth, Set<Function> result, Program program,
            Address address) {
        if (depth >= MAX_DEPTH) {
            return;
        }
        FunctionManager functionManager = program.getFunctionManager();
        Function function = functionManager.getFunctionAt(address);
        if (function != null) {
            Address[] thunkAddresses = function.getFunctionThunkAddresses();
            if (thunkAddresses != null) {
                for (Address thunkAddress : thunkAddresses) {
                    accumulateFunctionReferences(depth + 1, result, program, thunkAddress);
                }
            }
        }
        if (address.isStackAddress() || address.isRegisterAddress()) {
            return;
        }
        Listing listing = program.getListing();
        ReferenceIterator references = program.getReferenceManager().getReferencesTo(address);
        while (references.hasNext()) {
            Address fromAddress = references.next().getFromAddress();
            CodeUnit codeUnit = listing.getCodeUnitAt(fromAddress);
            if (codeUnit instanceof Instruction) {
                Function referencingFunction = functionManager.getFunctionContaining(fromAddress);
                if (referencingFunction == null) {
                    continue;
                }
                if (referencingFunction.isThunk()) {
                    accumulateFunctionReferences(depth + 1, result, program,
                        referencingFunction.getEntryPoint());
                }
                else {
                    result.add(referencingFunction);
                }
            }
            else if (codeUnit instanceof Data) {
                accumulateFunctionReferences(depth + 1, result, program, fromAddress);
            }
        }
    }

    /**
     * Decides whether ACCEPTED associations of the given type contribute shared features.
     *
     * @param associationType the type of an association from another correlator's match set
     * @return true if associations of this type should be used
     */
    protected abstract boolean isExpectedRefType(VTAssociationType associationType);

    /**
     * Decides whether a reference made from a function's body is counted when padding that
     * function's vector with unmatched-reference features.
     *
     * @param reference a reference from a code unit inside a function body
     * @return true if the reference should be counted
     */
    protected abstract boolean isExpectedRefType(Reference reference);

    /**
     * Builds {@link #srcVectorsByAddress} and {@link #destVectorsByAddress}. Each one gets a
     * shared, weighted feature per usable ACCEPTED match, followed by padding features for
     * unmatched references.
     */
    private void extractReferenceFeatures(VTMatchSet matchSet, TaskMonitor monitor)
            throws CancelledException {
        srcVectorsByAddress = LazyMap.lazyMap(new HashMap<Address, LSHCosineVectorAccum>(),
            address -> new LSHCosineVectorAccum());
        destVectorsByAddress = LazyMap.lazyMap(new HashMap<Address, LSHCosineVectorAccum>(),
            address -> new LSHCosineVectorAccum());

        int totalFunctionCount = sourceProgram.getFunctionManager().getFunctionCount() +
            destinationProgram.getFunctionManager().getFunctionCount();

        Counter matchCount = new Counter();
        Collection<VTMatchSet> matchSets = getMatchSets(matchSet.getSession(), matchCount);
        monitor.initialize(matchCount.count());
        Map<VTMatch, Set<Function>> sourceRefs = new HashMap<>();
        Map<VTMatch, Set<Function>> destRefs = new HashMap<>();
        for (VTMatchSet otherSet : matchSets) {
            for (VTMatch match : otherSet.getMatches()) {
                monitor.checkCancelled();
                monitor.incrementProgress(1L);
                accumulateMatchFunctionReferences(sourceRefs, destRefs, match);
            }
        }

        monitor.setMessage("Adding ACCEPTED matches to feature vectors.");
        monitor.initialize(sourceRefs.size());
        int featureId = 1;
        for (Map.Entry<VTMatch, Set<Function>> entry : sourceRefs.entrySet()) {
            monitor.checkCancelled();
            monitor.incrementProgress(1L);
            Set<Function> srcFunctions = entry.getValue();
            if (srcFunctions.isEmpty()) {
                continue;
            }
            Set<Function> destFunctions = destRefs.get(entry.getKey());
            // Without the cast this is integer division. That yields 0, and the weight
            // becomes +Infinity.
            double weight = Math.sqrt(-Math.log(
                (double) (srcFunctions.size() + destFunctions.size()) / totalFunctionCount));
            for (Function srcFunction : srcFunctions) {
                srcVectorsByAddress.get(srcFunction.getEntryPoint()).addHash(featureId, weight);
            }
            for (Function destFunction : destFunctions) {
                destVectorsByAddress.get(destFunction.getEntryPoint()).addHash(featureId, weight);
            }
            featureId++;
        }
        updateSourceAndDestinationVectors(featureId, monitor);
    }

    /**
     * Returns the newest match set (highest ID) of every correlator other than this one.
     *
     * @param session the session whose match sets are scanned
     * @param counter incremented by the total match count of the returned sets
     */
    private Collection<VTMatchSet> getMatchSets(VTSession session, Counter counter) {
        Map<String, VTMatchSet> newestByCorrelator = new HashMap<>();
        for (VTMatchSet candidate : session.getMatchSets()) {
            String name = candidate.getProgramCorrelatorInfo().getName();
            if (name.equals(correlatorName)) {
                continue;
            }
            VTMatchSet existing = newestByCorrelator.get(name);
            if (existing == null || candidate.getID() >= existing.getID()) {
                newestByCorrelator.put(name, candidate);
            }
        }
        // Count only the sets that survive, not ones replaced by a newer set.
        for (VTMatchSet selected : newestByCorrelator.values()) {
            counter.add(selected.getMatchCount());
        }
        return newestByCorrelator.values();
    }

    /**
     * For an ACCEPTED match of an expected type, records the functions referencing its source
     * item and its destination item. Nothing is recorded unless both sets are non-empty.
     */
    private void accumulateMatchFunctionReferences(Map<VTMatch, Set<Function>> sourceRefs,
            Map<VTMatch, Set<Function>> destRefs, VTMatch match) {
        VTAssociation association = match.getAssociation();
        if (!isExpectedRefType(association.getType()) ||
            association.getStatus() != VTAssociationStatus.ACCEPTED) {
            return;
        }
        Set<Function> sourceFunctions = new HashSet<>();
        accumulateFunctionReferences(0, sourceFunctions, sourceProgram,
            association.getSourceAddress());
        if (sourceFunctions.isEmpty()) {
            return;
        }
        Set<Function> destFunctions = new HashSet<>();
        accumulateFunctionReferences(0, destFunctions, destinationProgram,
            association.getDestinationAddress());
        if (destFunctions.isEmpty()) {
            return;
        }
        sourceRefs.put(match, sourceFunctions);
        destRefs.put(match, destFunctions);
    }

    /**
     * Pads every source vector and then every destination vector with features numbered
     * consecutively from {@code firstFeatureId}. Each padding feature has weight
     * {@code sqrt(-ln(0.5))}.
     */
    private void updateSourceAndDestinationVectors(int firstFeatureId, TaskMonitor monitor)
            throws CancelledException {
        monitor.setMessage("Adding unmatched references to feature vectors.");
        double weight = Math.sqrt(-Math.log(0.5d));
        int nextFeatureId =
            padWithUnmatchedReferences(sourceProgram, srcVectorsByAddress, firstFeatureId,
                weight, monitor);
        padWithUnmatchedReferences(destinationProgram, destVectorsByAddress, nextFeatureId,
            weight, monitor);
    }

    /**
     * Adds one fresh feature per counted reference that exceeds each vector's current entry
     * count.
     *
     * @return the next unused feature id
     */
    private int padWithUnmatchedReferences(Program program,
            Map<Address, LSHCosineVectorAccum> vectors, int firstFeatureId, double weight,
            TaskMonitor monitor) throws CancelledException {
        int featureId = firstFeatureId;
        for (Map.Entry<Address, LSHCosineVectorAccum> entry : vectors.entrySet()) {
            monitor.checkCancelled();
            LSHCosineVectorAccum vector = entry.getValue();
            int missing = countFunctionRefs(program, entry.getKey()) - vector.numEntries();
            for (int k = 0; k < missing; k++) {
                vector.addHash(featureId++, weight);
            }
        }
        return featureId;
    }

    /**
     * Counts references made from the body of the function at {@code entryPoint} that satisfy
     * {@link #isExpectedRefType(Reference)}.
     *
     * @return the count, or 0 if no function starts at {@code entryPoint}
     */
    private int countFunctionRefs(Program program, Address entryPoint) {
        Function function = program.getFunctionManager().getFunctionAt(entryPoint);
        if (function == null) {
            return 0;
        }
        CodeUnitIterator codeUnits = program.getListing().getCodeUnits(function.getBody(), true);
        int count = 0;
        while (codeUnits.hasNext()) {
            for (Reference reference : codeUnits.next().getReferencesFrom()) {
                if (isExpectedRefType(reference)) {
                    count++;
                }
            }
        }
        return count;
    }
}
