package org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang.mutable.MutableInt;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.BaseEdge;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.BaseVertex;
import org.broadinstitute.hellbender.tools.walkers.mutect.Mutect2Engine;
import org.broadinstitute.hellbender.utils.BaseUtils;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

/* loaded from: input_file:org/broadinstitute/hellbender/tools/walkers/haplotypecaller/graphs/AdaptiveChainPruner.class */
/**
 * A {@link ChainPruner} that decides which chains to remove by comparing, at each end of a chain,
 * the multiplicity of the chain's edge against the total multiplicity of the competing edges at
 * that vertex, using {@link Mutect2Engine#logLikelihoodRatio} with an error rate.
 *
 * <p>The error rate is estimated in two passes: a first pass uses the configured initial error
 * probability, and the chains it flags are used to compute an empirical error rate for the second,
 * final pass. Chains that contain a reference edge are never returned for removal.
 *
 * @param <V> vertex type of the graph being pruned
 * @param <E> edge type of the graph being pruned
 */
public class AdaptiveChainPruner<V extends BaseVertex, E extends BaseEdge> extends ChainPruner<V, E> {
    private final double initialErrorProbability;
    private final double logOddsThreshold;
    private final double seedingLogOddsThreshold;
    private final int maxUnprunedVariants;

    /**
     * @param initialErrorProbability error rate used for the first pass; must be positive
     * @param logOddsThreshold        minimum log odds for a chain to be followed from an already-visited vertex
     * @param seedingLogOddsThreshold minimum log odds, required on both ends, for a chain to count towards seeding a vertex
     * @param maxUnprunedVariants     the growth of the set of kept chains stops once the variant count exceeds this value
     */
    public AdaptiveChainPruner(final double initialErrorProbability, final double logOddsThreshold,
                               final double seedingLogOddsThreshold, final int maxUnprunedVariants) {
        ParamUtils.isPositive(initialErrorProbability, "Must have positive error probability");
        this.initialErrorProbability = initialErrorProbability;
        this.logOddsThreshold = logOddsThreshold;
        this.seedingLogOddsThreshold = seedingLogOddsThreshold;
        this.maxUnprunedVariants = maxUnprunedVariants;
    }

    /**
     * Selects the chains to prune.
     *
     * @param chains all chains of the graph; the graph is taken from the first chain
     * @return the non-reference chains judged to be errors, or an empty collection if {@code chains} is empty
     */
    @Override // org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.ChainPruner
    protected Collection<Path<V, E>> chainsToRemove(final List<Path<V, E>> chains) {
        if (chains.isEmpty()) {
            return Collections.emptyList();
        }

        final BaseGraph<V, E> graph = chains.get(0).getGraph();

        // First pass: flag probable errors using the configured prior error probability.
        final Collection<Path<V, E>> probableErrorChains = likelyErrorChains(chains, graph, initialErrorProbability);

        // Re-estimate the error rate from the first pass: multiplicity of the flagged chains' last
        // edges over the summed multiplicity of every edge of every chain.
        final int errorCount = probableErrorChains.stream()
                .mapToInt(chain -> chain.getLastEdge().getMultiplicity())
                .sum();
        final int totalBases = chains.stream()
                .mapToInt(chain -> chain.getEdges().stream().mapToInt(BaseEdge::getMultiplicity).sum())
                .sum();
        // Must be a floating-point division: integer division would truncate the rate to 0.
        // NOTE(review): if every edge multiplicity is zero this evaluates to NaN; confirm that
        // Mutect2Engine.logLikelihoodRatio tolerates it or that the case cannot occur.
        final double errorRate = (double) errorCount / totalBases;

        // Second pass with the empirical error rate; chains containing a reference edge are never removed.
        return likelyErrorChains(chains, graph, errorRate).stream()
                .filter(chain -> chain.getEdges().stream().noneMatch(BaseEdge::isRef))
                .collect(Collectors.toList());
    }

    /**
     * Grows a set of "good" chains outward from the maximum-weight chain and from seed vertices,
     * following chains in order of decreasing log odds, and returns every chain that was not reached.
     *
     * @param chains    all chains of the graph; must not be empty
     * @param graph     the graph the chains belong to
     * @param errorRate error rate passed to the log odds computation
     * @return the chains that did not end up in the good set
     */
    private Collection<Path<V, E>> likelyErrorChains(final List<Path<V, E>> chains, final BaseGraph<V, E> graph, final double errorRate) {
        // Pre-compute the (left, right) log odds of each chain.
        final Map<Path<V, E>, Pair<Double, Double>> chainLogOdds = chains.stream()
                .collect(Collectors.toMap(chain -> chain, chain -> chainLogOdds(chain, graph, errorRate)));

        // Index chains by the vertex they are attached to, separately for seeding and for extension.
        final Multimap<V, Path<V, E>> vertexToSeedableChains = ArrayListMultimap.create();
        final Multimap<V, Path<V, E>> vertexToGoodIncomingChains = ArrayListMultimap.create();
        final Multimap<V, Path<V, E>> vertexToGoodOutgoingChains = ArrayListMultimap.create();

        for (final Path<V, E> chain : chains) {
            final Pair<Double, Double> logOdds = chainLogOdds.get(chain);
            final double leftLogOdds = logOdds.getLeft();
            final double rightLogOdds = logOdds.getRight();
            // A chain whose first edge is a reference edge is always eligible for extension.
            final boolean startsWithRefEdge = chain.getEdges().get(0).isRef();

            if (rightLogOdds >= logOddsThreshold || startsWithRefEdge) {
                vertexToGoodIncomingChains.put(chain.getLastVertex(), chain);
            }
            if (leftLogOdds >= logOddsThreshold || startsWithRefEdge) {
                vertexToGoodOutgoingChains.put(chain.getFirstVertex(), chain);
            }
            // Seedable chains must pass the seeding threshold on both ends; they are registered at both end vertices.
            if (rightLogOdds >= seedingLogOddsThreshold && leftLogOdds >= seedingLogOddsThreshold) {
                vertexToSeedableChains.put(chain.getFirstVertex(), chain);
                vertexToSeedableChains.put(chain.getLastVertex(), chain);
            }
        }

        // Queue of (chain, log odds) ordered by decreasing log odds, with the first vertex's sequence
        // as tie breaker. A chain may be queued more than once (once per end).
        final PriorityQueue<Pair<Path<V, E>, Double>> chainsToAdd = new PriorityQueue<>(
                Comparator.comparingDouble((Pair<Path<V, E>, Double> entry) -> -entry.getRight())
                        .thenComparing((Pair<Path<V, E>, Double> entry) -> entry.getLeft().getFirstVertex().getSequence(),
                                BaseUtils.BASES_COMPARATOR));

        // The maximum-weight chain is always polled first thanks to its infinite priority.
        chainsToAdd.add(ImmutablePair.of(getMaxWeightChain(chains), Double.POSITIVE_INFINITY));

        // Seed from every vertex that has more than two seedable chain entries.
        final Set<V> processedVertices = new LinkedHashSet<>();
        for (final V vertex : vertexToSeedableChains.keySet()) {
            if (vertexToSeedableChains.get(vertex).size() > 2) {
                enqueueGoodChainsAt(vertex, vertexToGoodOutgoingChains, vertexToGoodIncomingChains, chainLogOdds, chainsToAdd);
                processedVertices.add(vertex);
            }
        }

        final Set<Path<V, E>> goodChains = new LinkedHashSet<>();
        final Set<V> verticesWithGoodOutgoingChain = new HashSet<>();
        int variantCount = 0;

        while (!chainsToAdd.isEmpty() && variantCount <= maxUnprunedVariants) {
            final Path<V, E> chain = chainsToAdd.poll().getLeft();
            if (!goodChains.add(chain)) {
                continue; // already accepted, e.g. reached earlier from its other end
            }

            // A second good chain leaving the same first vertex is counted as a new variant.
            final boolean newVariant = !verticesWithGoodOutgoingChain.add(chain.getFirstVertex());
            if (newVariant) {
                variantCount++;
                if (variantCount > maxUnprunedVariants) {
                    continue; // limit exceeded: keep the chain but do not extend from it
                }
            }

            for (final V vertex : Arrays.asList(chain.getFirstVertex(), chain.getLastVertex())) {
                if (!processedVertices.contains(vertex)) {
                    enqueueGoodChainsAt(vertex, vertexToGoodOutgoingChains, vertexToGoodIncomingChains, chainLogOdds, chainsToAdd);
                    processedVertices.add(vertex);
                }
            }
        }

        return chains.stream().filter(chain -> !goodChains.contains(chain)).collect(Collectors.toSet());
    }

    /**
     * Queues every good outgoing chain of {@code vertex} with its left log odds, then every good
     * incoming chain of {@code vertex} with its right log odds.
     */
    private void enqueueGoodChainsAt(final V vertex,
                                     final Multimap<V, Path<V, E>> vertexToGoodOutgoingChains,
                                     final Multimap<V, Path<V, E>> vertexToGoodIncomingChains,
                                     final Map<Path<V, E>, Pair<Double, Double>> chainLogOdds,
                                     final PriorityQueue<Pair<Path<V, E>, Double>> chainsToAdd) {
        for (final Path<V, E> chain : vertexToGoodOutgoingChains.get(vertex)) {
            chainsToAdd.add(ImmutablePair.of(chain, chainLogOdds.get(chain).getLeft()));
        }
        for (final Path<V, E> chain : vertexToGoodIncomingChains.get(vertex)) {
            chainsToAdd.add(ImmutablePair.of(chain, chainLogOdds.get(chain).getRight()));
        }
    }

    /**
     * Returns the chain whose largest edge multiplicity is greatest, breaking ties by chain length
     * and then by the sequence of the chain's first vertex.
     *
     * @param chains candidate chains; must not be empty
     * @throws IllegalArgumentException if {@code chains} is empty
     */
    private Path<V, E> getMaxWeightChain(final Collection<Path<V, E>> chains) {
        return chains.stream()
                .max(Comparator.comparingInt((Path<V, E> chain) -> chain.getEdges().stream().mapToInt(BaseEdge::getMultiplicity).max().orElse(0))
                        .thenComparingInt(Path::length)
                        .thenComparing((Path<V, E> chain) -> chain.getFirstVertex().getSequence(), BaseUtils.BASES_COMPARATOR))
                .orElseThrow(() -> new IllegalArgumentException("Cannot pick a max weight chain from an empty collection"));
    }

    /**
     * Computes the log odds of a chain at each of its ends.
     *
     * <p>Left: the chain's first edge against the other edges leaving its first vertex, or 0 if that
     * vertex is a source of the graph. Right: the chain's last edge against the other edges entering
     * its last vertex, or 0 if that vertex is a sink of the graph.
     *
     * @return pair of (left log odds, right log odds)
     */
    private Pair<Double, Double> chainLogOdds(final Path<V, E> chain, final BaseGraph<V, E> graph, final double errorRate) {
        final int leftTotalMultiplicity = MathUtils.sumIntFunction(graph.outgoingEdgesOf(chain.getFirstVertex()), E::getMultiplicity);
        final int rightTotalMultiplicity = MathUtils.sumIntFunction(graph.incomingEdgesOf(chain.getLastVertex()), E::getMultiplicity);

        final int leftMultiplicity = chain.getEdges().get(0).getMultiplicity();
        final int rightMultiplicity = chain.getLastEdge().getMultiplicity();

        final double leftLogOdds = graph.isSource(chain.getFirstVertex())
                ? 0.0
                : Mutect2Engine.logLikelihoodRatio(leftTotalMultiplicity - leftMultiplicity, leftMultiplicity, errorRate);
        final double rightLogOdds = graph.isSink(chain.getLastVertex())
                ? 0.0
                : Mutect2Engine.logLikelihoodRatio(rightTotalMultiplicity - rightMultiplicity, rightMultiplicity, errorRate);

        return ImmutablePair.of(leftLogOdds, rightLogOdds);
    }
}
