package org.broadinstitute.hellbender.tools.walkers.mutect;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Doubles;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.util.Locatable;
import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.GenotypeBuilder;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.VariantContextBuilder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.collections4.ListUtils;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.DefaultRealMatrixChangingVisitor;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.FastMath;
import org.broadinstitute.hellbender.engine.FeatureContext;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.tools.walkers.annotator.VariantAnnotatorEngine;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.AssemblyBasedCallerUtils;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.AssemblyResultSet;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.CalledHaplotypes;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.ReferenceConfidenceUtils;
import org.broadinstitute.hellbender.tools.walkers.mutect.PerAlleleCollection;
import org.broadinstitute.hellbender.utils.IndexRange;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.NaturalLogUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.genotyper.AlleleLikelihoods;
import org.broadinstitute.hellbender.utils.genotyper.AlleleList;
import org.broadinstitute.hellbender.utils.genotyper.LikelihoodMatrix;
import org.broadinstitute.hellbender.utils.genotyper.SampleList;
import org.broadinstitute.hellbender.utils.haplotype.EventMap;
import org.broadinstitute.hellbender.utils.haplotype.Haplotype;
import org.broadinstitute.hellbender.utils.read.Fragment;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;
import org.broadinstitute.hellbender.utils.variant.GATKVariantContextUtils;

/* loaded from: input_file:org/broadinstitute/hellbender/tools/walkers/mutect/SomaticGenotypingEngine.class */
/**
 * Somatic genotyping engine used by Mutect2.
 *
 * <p>Given per-read haplotype likelihoods for an assembly region, it does the following for each event start
 * position inside the active window:
 * <ul>
 *     <li>marginalizes the likelihoods onto the alleles at that position;</li>
 *     <li>scores each alt allele in the combined tumor evidence (somatic log odds);</li>
 *     <li>when normal samples are present, scores each alt allele in the combined normal evidence (diploid hom-ref
 *     vs. het log odds, and somatic-model "artifact" log odds);</li>
 *     <li>emits an annotated {@link VariantContext} for sites with at least one passing alt allele.</li>
 * </ul>
 */
public class SomaticGenotypingEngine {
    private final M2ArgumentCollection MTAC;
    private final Set<String> normalSamples;
    final boolean hasNormal;
    protected VariantAnnotatorEngine annotationEngine;

    // Prior pseudocounts passed to SomaticLikelihoodsEngine: refPseudocount for allele index 0, altPseudocount for the rest.
    private final double refPseudocount = 1.0d;
    private final double altPseudocount;

    /**
     * @param m2ArgumentCollection  Mutect2 arguments (thresholds, resources, likelihood options)
     * @param set                   names of the samples to treat as normals; empty when running tumor-only
     * @param variantAnnotatorEngine engine used to annotate every emitted call
     */
    public SomaticGenotypingEngine(final M2ArgumentCollection m2ArgumentCollection, final Set<String> set,
                                   final VariantAnnotatorEngine variantAnnotatorEngine) {
        this.MTAC = m2ArgumentCollection;
        // minAF == 0 keeps a flat alt pseudocount of 1; otherwise 1 - ln(2)/ln(minAF), which exceeds 1 for minAF in (0, 1).
        this.altPseudocount = m2ArgumentCollection.minAF == 0.0d ? 1.0d
                : 1.0d - (Math.log(2.0d) / Math.log(m2ArgumentCollection.minAF));
        this.normalSamples = set;
        this.hasNormal = !set.isEmpty();
        this.annotationEngine = variantAnnotatorEngine;
    }

    /**
     * Calls somatic variants at every event start inside {@code activeRegionWindow}.
     *
     * @param logReadLikelihoods per-read, per-haplotype log likelihoods; must have at least one sample. It is
     *                           normalized in place when a global mismapping rate is set, and annotated with
     *                           regions when {@code withBamOut} is true.
     * @param assemblyResultSet  supplies the padded reference used to build the haplotype event maps
     * @param referenceContext   passed through to the annotation engine
     * @param activeRegionWindow only events starting within [start, end] of this interval are genotyped
     * @param featureContext     source of germline-resource and panel-of-normals records
     * @param givenAlleles       alt alleles consistent with these bypass the tumor/normal emission thresholds
     * @param header             its sequence dictionary bounds the padded interval used to select evidence
     * @param withBamOut         whether to annotate reads for bamout output
     * @param emitRefConf        whether to add a {@code <NON_REF>} allele to calls and likelihoods
     * @return the phased calls, each annotated with the number of calls in the region, plus the haplotypes
     *         supporting them
     * @throws IllegalArgumentException if the likelihoods have no samples, or no tumor (non-normal) samples
     */
    public CalledHaplotypes callMutations(final AlleleLikelihoods<GATKRead, Haplotype> logReadLikelihoods,
                                          final AssemblyResultSet assemblyResultSet,
                                          final ReferenceContext referenceContext,
                                          final SimpleInterval activeRegionWindow,
                                          final FeatureContext featureContext,
                                          final List<VariantContext> givenAlleles,
                                          final SAMFileHeader header,
                                          final boolean withBamOut,
                                          final boolean emitRefConf) {
        Utils.nonNull(logReadLikelihoods);
        Utils.validateArg(logReadLikelihoods.numberOfSamples() > 0, "likelihoods have no samples");
        Utils.nonNull(activeRegionWindow);

        final List<Haplotype> haplotypes = logReadLikelihoods.alleles();

        // Reference positions carrying an event on some haplotype, restricted to the active window.
        final List<Integer> eventStarts = EventMap.buildEventMapsForHaplotypes(haplotypes,
                        assemblyResultSet.getFullReferenceWithPadding(), assemblyResultSet.getPaddedReferenceLoc(),
                        MTAC.assemblerArgs.debugAssembly, MTAC.maxMnpDistance).stream()
                .filter(loc -> activeRegionWindow.getStart() <= loc && loc <= activeRegionWindow.getEnd())
                .collect(Collectors.toList());

        final Set<Haplotype> calledHaplotypes = new HashSet<>();
        final List<VariantContext> returnCalls = new ArrayList<>();

        if (withBamOut) {
            AssemblyBasedCallerUtils.annotateReadLikelihoodsWithRegions(logReadLikelihoods, activeRegionWindow);
        }

        if (MTAC.likelihoodArgs.phredScaledGlobalReadMismappingRate > 0) {
            logReadLikelihoods.normalizeLikelihoods(
                    NaturalLogUtils.qualToLogErrorProb(MTAC.likelihoodArgs.phredScaledGlobalReadMismappingRate));
        }

        // Reads are grouped into fragments by name unless mates are to be treated as independent evidence.
        final AlleleLikelihoods<Fragment, Haplotype> logFragmentLikelihoods = logReadLikelihoods.groupEvidence(
                MTAC.independentMates ? read -> read : GATKRead::getName, Fragment::createAndAvoidFailure);

        for (final int loc : eventStarts) {
            VariantContext mergedVC = AssemblyBasedCallerUtils.makeMergedVariantContext(
                    AssemblyBasedCallerUtils.getVariantContextsFromActiveHaplotypes(loc, haplotypes, false));
            if (mergedVC == null) {
                continue;
            }

            // haplotype likelihoods -> allele likelihoods, keeping only evidence near the variant
            final Map<Allele, List<Haplotype>> alleleMapper = AssemblyBasedCallerUtils.createAlleleMapper(mergedVC, loc, haplotypes);
            final AlleleLikelihoods<Fragment, Allele> logLikelihoods = logFragmentLikelihoods.marginalize(alleleMapper);
            final SimpleInterval variantCallingRelevantOverlap = new SimpleInterval(mergedVC)
                    .expandWithinContig(MTAC.informativeReadOverlapMargin, header.getSequenceDictionary());
            logLikelihoods.retainEvidence(variantCallingRelevantOverlap::overlaps);

            if (emitRefConf) {
                mergedVC = ReferenceConfidenceUtils.addNonRefSymbolicAllele(mergedVC);
                logLikelihoods.addNonReferenceAllele(Allele.NON_REF_ALLELE);
            }

            final List<LikelihoodMatrix<Fragment, Allele>> tumorMatrices = sampleMatrices(logLikelihoods, false);
            Utils.validateArg(!tumorMatrices.isEmpty(), "likelihoods have no tumor samples");
            final AlleleList<Allele> alleleList = tumorMatrices.get(0);
            final PerAlleleCollection<Double> tumorLogOdds = somaticLogOdds(combinedLikelihoodMatrix(tumorMatrices, alleleList));

            final LikelihoodMatrix<Fragment, Allele> logNormalMatrix =
                    combinedLikelihoodMatrix(sampleMatrices(logLikelihoods, true), alleleList);
            final PerAlleleCollection<Double> normalLogOdds = diploidAltLogOdds(logNormalMatrix);
            final PerAlleleCollection<Double> normalArtifactLogOdds = somaticLogOdds(logNormalMatrix);

            final Set<Allele> forcedAlleles = AssemblyBasedCallerUtils.getAllelesConsistentWithGivenAlleles(givenAlleles, mergedVC);
            final List<Allele> tumorAltAlleles = mergedVC.getAlternateAlleles().stream()
                    .filter(allele -> forcedAlleles.contains(allele) || tumorLogOdds.getAlt(allele) > MTAC.getEmissionLogOdds())
                    .collect(Collectors.toList());

            // Emit the site (with ALL tumorAltAlleles) only if at least one of them is forced, there is no normal,
            // germline sites are requested, or the normal's hom-ref-vs-het log odds exceed the normal threshold.
            final long somaticAltCount = tumorAltAlleles.stream()
                    .filter(allele -> forcedAlleles.contains(allele) || !hasNormal || MTAC.genotypeGermlineSites
                            || normalLogOdds.getAlt(allele) > MathUtils.log10ToLog(MTAC.normalLog10Odds))
                    .count();
            if (somaticAltCount == 0) {
                continue;
            }

            final List<Allele> allAllelesToEmit = ListUtils.union(Arrays.asList(mergedVC.getReference()), tumorAltAlleles);

            // note: attributes(...) replaces the attributes copied from mergedVC
            final VariantContextBuilder callVcb = new VariantContextBuilder(mergedVC)
                    .alleles(allAllelesToEmit)
                    .attributes(getNegativeLogPopulationAFAnnotation(featureContext.getValues(MTAC.germlineResource, loc),
                            tumorAltAlleles, MTAC.getDefaultAlleleFrequency()))
                    .attribute(GATKVCFConstants.TUMOR_LOG_10_ODDS_KEY,
                            tumorAltAlleles.stream().mapToDouble(a -> MathUtils.logToLog10(tumorLogOdds.getAlt(a))).toArray());

            if (hasNormal) {
                callVcb.attribute(GATKVCFConstants.NORMAL_ARTIFACT_LOG_10_ODDS_KEY,
                        Arrays.stream(normalArtifactLogOdds.asDoubleArray(tumorAltAlleles)).map(x -> -MathUtils.logToLog10(x)).toArray());
                callVcb.attribute(GATKVCFConstants.NORMAL_LOG_10_ODDS_KEY,
                        Arrays.stream(normalLogOdds.asDoubleArray(tumorAltAlleles)).map(MathUtils::logToLog10).toArray());
            }

            if (!featureContext.getValues(MTAC.pon, mergedVC.getStart()).isEmpty()) {
                callVcb.attribute("PON", true);
            }

            addGenotypes(logLikelihoods, allAllelesToEmit, callVcb);
            final VariantContext call = callVcb.make();
            final VariantContext trimmedCall = GATKVariantContextUtils.trimAlleles(call, true, true);
            final List<Allele> trimmedAlleles = trimmedCall.getAlleles();
            final List<Allele> untrimmedAlleles = call.getAlleles();
            // Pairs each trimmed allele with the untrimmed allele at the same index
            // (relies on trimAlleles preserving allele order).
            final Map<Allele, List<Allele>> trimmedToUntrimmedAlleleMap = IntStream.range(0, trimmedCall.getNAlleles()).boxed()
                    .collect(Collectors.toMap(n -> trimmedAlleles.get(n), n -> Arrays.asList(untrimmedAlleles.get(n))));
            final AlleleLikelihoods<Fragment, Allele> trimmedLikelihoods = logLikelihoods.marginalize(trimmedToUntrimmedAlleleMap);

            // Read-level (not fragment-level) allele likelihoods, restricted the same way; used only for annotation.
            final AlleleLikelihoods<GATKRead, Allele> logReadAlleleLikelihoods = logReadLikelihoods.marginalize(alleleMapper);
            logReadAlleleLikelihoods.retainEvidence(variantCallingRelevantOverlap::overlaps);
            if (emitRefConf) {
                logReadAlleleLikelihoods.addNonReferenceAllele(Allele.NON_REF_ALLELE);
            }

            final VariantContext annotatedCall = annotationEngine.annotateContext(trimmedCall, featureContext, referenceContext,
                    logReadAlleleLikelihoods.marginalize(trimmedToUntrimmedAlleleMap), annotation -> true);
            if (withBamOut) {
                AssemblyBasedCallerUtils.annotateReadLikelihoodsWithSupportedAlleles(trimmedCall, trimmedLikelihoods, Fragment::getReads);
            }

            // alleles absent from the mapper (e.g. <NON_REF>) contribute no haplotypes
            call.getAlleles().stream()
                    .map(alleleMapper::get)
                    .filter(Objects::nonNull)
                    .forEach(calledHaplotypes::addAll);
            returnCalls.add(annotatedCall);
        }

        final List<VariantContext> outputCalls = AssemblyBasedCallerUtils.phaseCalls(returnCalls, calledHaplotypes);
        final int eventCount = outputCalls.size();
        final List<VariantContext> outputCallsWithEventCount = outputCalls.stream()
                .map(vc -> new VariantContextBuilder(vc).attribute(GATKVCFConstants.EVENT_COUNT_IN_HAPLOTYPE_KEY, eventCount).make())
                .collect(Collectors.toList());
        return new CalledHaplotypes(outputCallsWithEventCount, calledHaplotypes);
    }

    // Per-sample likelihood matrices of the normal samples (normal == true) or of the tumor samples (normal == false).
    private <EVIDENCE extends Locatable> List<LikelihoodMatrix<EVIDENCE, Allele>> sampleMatrices(
            final AlleleLikelihoods<EVIDENCE, Allele> likelihoods, final boolean normal) {
        return IntStream.range(0, likelihoods.numberOfSamples())
                .filter(s -> normalSamples.contains(likelihoods.getSample(s)) == normal)
                .mapToObj(likelihoods::sampleMatrix)
                .collect(Collectors.toList());
    }

    // Prior pseudocounts for numberOfAlleles alleles: refPseudocount at index 0, altPseudocount elsewhere.
    // NOTE(review): assumes the ref allele is at index 0 of every matrix scored here — confirm.
    private double[] makePriorPseudocounts(final int numberOfAlleles) {
        return new IndexRange(0, numberOfAlleles).mapToDouble(a -> a == 0 ? refPseudocount : altPseudocount);
    }

    /**
     * For each non-reference allele, computes the somatic log odds: the log evidence of {@code logMatrix} with all
     * alleles minus the log evidence with that allele excluded. Both come from
     * {@link SomaticLikelihoodsEngine#logEvidence}, and a matrix with no evidence counts as 0.
     *
     * @throws IllegalStateException if {@code <NON_REF>} is present but not the last allele
     * @throws IllegalArgumentException if the matrix has no reference allele
     */
    protected <EVIDENCE extends Locatable> PerAlleleCollection<Double> somaticLogOdds(final LikelihoodMatrix<EVIDENCE, Allele> likelihoodMatrix) {
        final int lastAlleleIndex = likelihoodMatrix.alleles().size() - 1;
        if (likelihoodMatrix.alleles().contains(Allele.NON_REF_ALLELE)
                && !likelihoodMatrix.alleles().get(lastAlleleIndex).equals(Allele.NON_REF_ALLELE)) {
            throw new IllegalStateException("<NON_REF> must be last in the allele list.");
        }

        final double logEvidenceWithAllAlleles = somaticLogEvidence(likelihoodMatrix);
        final PerAlleleCollection<Double> result = new PerAlleleCollection<>(PerAlleleCollection.Type.ALT_ONLY);
        final int refIndex = getRefIndex(likelihoodMatrix);
        IntStream.range(0, likelihoodMatrix.numberOfAlleles()).filter(a -> a != refIndex).forEach(a -> {
            final Allele allele = likelihoodMatrix.getAllele(a);
            final LikelihoodMatrix<EVIDENCE, Allele> withoutAllele = SubsettedLikelihoodMatrix.excludingAllele(likelihoodMatrix, allele);
            result.setAlt(allele, logEvidenceWithAllAlleles - somaticLogEvidence(withoutAllele));
        });
        return result;
    }

    // SomaticLikelihoodsEngine log evidence of the matrix under the prior pseudocounts, or 0 if it has no evidence.
    private <EVIDENCE> double somaticLogEvidence(final LikelihoodMatrix<EVIDENCE, Allele> logMatrix) {
        return logMatrix.evidenceCount() == 0 ? 0.0d
                : SomaticLikelihoodsEngine.logEvidence(getAsRealMatrix(logMatrix), makePriorPseudocounts(logMatrix.numberOfAlleles()));
    }

    /**
     * Sets one genotype per sample on {@code variantContextBuilder}, restricted to the alleles in {@code list}.
     * Each genotype carries:
     * <ul>
     *     <li>AD: rounded effective allele counts;</li>
     *     <li>AF: posterior allele fractions under a flat prior, excluding allele index 0.</li>
     * </ul>
     * Normal samples get a diploid hom-ref genotype; all other samples list every allele in {@code list}.
     */
    private <EVIDENCE extends Locatable> void addGenotypes(final AlleleLikelihoods<EVIDENCE, Allele> alleleLikelihoods,
                                                           final List<Allele> list,
                                                           final VariantContextBuilder variantContextBuilder) {
        variantContextBuilder.genotypes(IntStream.range(0, alleleLikelihoods.numberOfSamples()).mapToObj(s -> {
            final String sample = alleleLikelihoods.getSample(s);
            final LikelihoodMatrix<EVIDENCE, Allele> logMatrix = new SubsettedLikelihoodMatrix<>(alleleLikelihoods.sampleMatrix(s), list);
            final double[] alleleCounts = getEffectiveCounts(logMatrix);
            final double[] flatPriorPseudocounts = new IndexRange(0, logMatrix.numberOfAlleles()).mapToDouble(a -> 1.0d);
            final double[] alleleFractions = MathUtils.normalizeSumToOne(logMatrix.evidenceCount() == 0 ? flatPriorPseudocounts
                    : SomaticLikelihoodsEngine.alleleFractionsPosterior(getAsRealMatrix(logMatrix), flatPriorPseudocounts));
            final List<Allele> genotypeAlleles = normalSamples.contains(sample)
                    ? Collections.nCopies(2, logMatrix.getAllele(getRefIndex(logMatrix)))
                    : logMatrix.alleles();
            return new GenotypeBuilder(sample, genotypeAlleles)
                    .AD(Arrays.stream(alleleCounts).mapToInt(x -> (int) FastMath.round(x)).toArray())
                    .attribute("AF", Arrays.copyOfRange(alleleFractions, 1, alleleFractions.length))
                    .make();
        }).collect(Collectors.toList()));
    }

    // Per-allele sum over evidence of each evidence column's likelihoods normalized to linear space; all zeros if empty.
    private static <EVIDENCE> double[] getEffectiveCounts(final LikelihoodMatrix<EVIDENCE, Allele> likelihoodMatrix) {
        if (likelihoodMatrix.evidenceCount() == 0) {
            return new double[likelihoodMatrix.numberOfAlleles()];
        }
        final RealMatrix logLikelihoods = getAsRealMatrix(likelihoodMatrix);
        return MathUtils.sumArrayFunction(0, logLikelihoods.getColumnDimension(),
                e -> NaturalLogUtils.normalizeFromLogToLinearSpace(logLikelihoods.getColumn(e)));
    }

    /**
     * For each alt allele, computes the hom-ref log likelihood minus the het log likelihood of the evidence.
     * <ul>
     *     <li>Hom ref: the sum over evidence of the ref-allele log likelihood.</li>
     *     <li>Het: the sum over evidence of log((L_ref + L_alt) / 2).</li>
     * </ul>
     */
    private <EVIDENCE extends Locatable> PerAlleleCollection<Double> diploidAltLogOdds(final LikelihoodMatrix<EVIDENCE, Allele> likelihoodMatrix) {
        final int refIndex = getRefIndex(likelihoodMatrix);
        final int evidenceCount = likelihoodMatrix.evidenceCount();
        final double homRefLogLikelihood = new IndexRange(0, evidenceCount).sum(e -> likelihoodMatrix.get(refIndex, e));

        final PerAlleleCollection<Double> result = new PerAlleleCollection<>(PerAlleleCollection.Type.ALT_ONLY);
        IntStream.range(0, likelihoodMatrix.numberOfAlleles()).filter(a -> a != refIndex).forEach(altIndex -> {
            // row = allele, column = evidence (the decompiled form read the diagonal get(i, i) here)
            final double hetLogLikelihood = new IndexRange(0, evidenceCount).sum(e ->
                    NaturalLogUtils.logSumExp(likelihoodMatrix.get(refIndex, e), likelihoodMatrix.get(altIndex, e))
                            + NaturalLogUtils.LOG_ONE_HALF);
            result.setAlt(likelihoodMatrix.getAllele(altIndex), homRefLogLikelihood - hetLogLikelihood);
        });
        return result;
    }

    // Index of the first reference allele; throws IllegalArgumentException if there is none.
    private <EVIDENCE> int getRefIndex(final LikelihoodMatrix<EVIDENCE, Allele> likelihoodMatrix) {
        final OptionalInt refIndex = IntStream.range(0, likelihoodMatrix.numberOfAlleles())
                .filter(a -> likelihoodMatrix.getAllele(a).isReference())
                .findFirst();
        Utils.validateArg(refIndex.isPresent(), "No ref allele found in likelihoods");
        return refIndex.getAsInt();
    }

    /**
     * Copies a likelihood matrix into a {@link RealMatrix} with one row per allele and one column per piece of evidence.
     */
    public static <EVIDENCE> RealMatrix getAsRealMatrix(final LikelihoodMatrix<EVIDENCE, Allele> likelihoodMatrix) {
        final Array2DRowRealMatrix result = new Array2DRowRealMatrix(likelihoodMatrix.numberOfAlleles(), likelihoodMatrix.evidenceCount());
        result.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
            @Override
            public double visit(final int row, final int column, final double value) {
                return likelihoodMatrix.get(row, column);
            }
        });
        return result;
    }

    /**
     * Concatenates the evidence of several matrices (in list order) into one "COMBINED" sample matrix over
     * {@code alleleList}, copying every likelihood across.
     * NOTE(review): assumes every input matrix orders its alleles like {@code alleleList} — confirm at call sites.
     */
    private static <EVIDENCE extends Locatable> LikelihoodMatrix<EVIDENCE, Allele> combinedLikelihoodMatrix(
            final List<LikelihoodMatrix<EVIDENCE, Allele>> list, final AlleleList<Allele> alleleList) {
        final List<EVIDENCE> allEvidence = list.stream()
                .flatMap(m -> m.evidence().stream())
                .collect(Collectors.toList());
        final LikelihoodMatrix<EVIDENCE, Allele> result = new AlleleLikelihoods<>(SampleList.singletonSampleList("COMBINED"),
                alleleList, ImmutableMap.of("COMBINED", allEvidence)).sampleMatrix(0);

        final int numberOfAlleles = result.numberOfAlleles();
        int combinedIndex = 0;
        for (final LikelihoodMatrix<EVIDENCE, Allele> matrix : list) {
            final int evidenceCount = matrix.evidenceCount();
            for (int e = 0; e < evidenceCount; e++, combinedIndex++) {
                for (int a = 0; a < numberOfAlleles; a++) {
                    result.set(a, combinedIndex, matrix.get(a, e));
                }
            }
        }
        return result;
    }

    // Evaluates the supplier only when normal samples are present.
    private <E> Optional<E> getForNormal(final Supplier<E> supplier) {
        return hasNormal ? Optional.of(supplier.get()) : Optional.empty();
    }

    // POPULATION_AF attribute: -log10 of each alt allele's germline AF; only the first germline record at the site is used.
    private static Map<String, Object> getNegativeLogPopulationAFAnnotation(final List<VariantContext> list,
                                                                            final List<Allele> list2,
                                                                            final double d) {
        final Optional<VariantContext> germlineVC = list.isEmpty() ? Optional.empty() : Optional.of(list.get(0));
        final double[] alleleFrequencies = getGermlineAltAlleleFrequencies(list2, germlineVC, d);
        return ImmutableMap.of(GATKVCFConstants.POPULATION_AF_KEY, MathUtils.applyToArray(alleleFrequencies, af -> -Math.log10(af)));
    }

    /**
     * For each allele in {@code list}, returns the "AF" value of the matching (by bases) alt allele in the germline
     * record. Returns {@code d} when there is no record or no matching alt allele.
     */
    @VisibleForTesting
    static double[] getGermlineAltAlleleFrequencies(final List<Allele> list, final Optional<VariantContext> optional, final double d) {
        if (!optional.isPresent()) {
            return Doubles.toArray(Collections.nCopies(list.size(), d));
        }
        final VariantContext germlineVC = optional.get();
        final List<Double> germlineAltAFs = Mutect2Engine.getAttributeAsDoubleList(germlineVC, "AF", d);
        return list.stream().mapToDouble(allele -> {
            final OptionalInt germlineAltIndex = IntStream.range(0, germlineVC.getNAlleles() - 1)
                    .filter(i -> germlineVC.getAlternateAllele(i).basesMatch(allele))
                    .findAny();
            return germlineAltIndex.isPresent() ? germlineAltAFs.get(germlineAltIndex.getAsInt()) : d;
        }).toArray();
    }
}
