package org.broadinstitute.hellbender.tools.walkers.annotator;

import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.Genotype;
import htsjdk.variant.variantcontext.GenotypeBuilder;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.vcf.VCFCompoundHeaderLine;
import htsjdk.variant.vcf.VCFFormatHeaderLine;
import htsjdk.variant.vcf.VCFHeaderLineCount;
import htsjdk.variant.vcf.VCFHeaderLineType;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import org.apache.commons.math3.linear.DefaultRealMatrixChangingVisitor;
import org.apache.commons.math3.linear.RealMatrix;
import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.tools.walkers.mutect.SomaticLikelihoodsEngine;
import org.broadinstitute.hellbender.tools.walkers.mutect.SubsettedLikelihoodMatrix;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.genotyper.AlleleLikelihoods;
import org.broadinstitute.hellbender.utils.genotyper.LikelihoodMatrix;
import org.broadinstitute.hellbender.utils.help.HelpConstants;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;

/**
 * Per-sample genotype annotation that reports, for every allele of the variant, a Dirichlet posterior
 * pseudo-count ("pseudo-depth", under {@link GATKVCFConstants#PSEUDO_DEPTH_KEY}) and the corresponding
 * normalized allele fraction (under {@code DF}).
 *
 * <p>The posterior is computed by {@link SomaticLikelihoodsEngine#alleleFractionsPosterior} from the sample's
 * allele-by-read likelihood matrix, a flat prior of {@link #prior} pseudo-counts per allele, and optional
 * per-read weights that shrink towards zero for reads that barely distinguish their two most likely alleles.</p>
 *
 * <p>NOTE(review): the prior cache and the shared {@link DecimalFormat} instances are not synchronized; this
 * assumes an instance is not used to annotate from several threads concurrently.</p>
 */
@DocumentedFeature(groupName = HelpConstants.DOC_CAT_ANNOTATORS, groupSummary = HelpConstants.DOC_CAT_ANNOTATORS_SUMMARY, summary = "Per-sample allele depth and allele fraction based on Dirichlet posterior pseudo-counts")
public final class AllelePseudoDepth implements GenotypeAnnotation {

    @Argument(fullName = "dirichlet-prior-pseudo-count", doc = "Pseudo-count used as prior for all alleles. The default is 1.0 resulting in a flat prior")
    @Advanced
    public double prior = 1.0d;

    @Argument(fullName = "dirichlet-keep-prior-in-count", doc = "By default the prior is not included in the output counts, as this makes it easier to interpret this quantity as the number of supporting reads, especially at low-depth sites. When this is toggled the prior is included")
    public boolean keepPriorInCount = false;

    @Argument(fullName = "pseudo-count-weight-decay-rate", doc = "At what rate the weight of a read decreases based on its informativeness; e.g. 1.0 is linear decay (default), 2.0 is for quadratic decay", minValue = 0.0d)
    public double weightDecay = 1.0d;

    /** FORMAT key of the per-allele fraction annotation. */
    private static final String FRACTION_KEY = "DF";

    /**
     * Flat prior pseudo-count vectors cached by number of alleles. The arrays are shared between calls,
     * so callers must never modify them in place.
     */
    private final Int2ObjectMap<double[]> priorPseudoCounts = new Int2ObjectOpenHashMap<>();

    // Locale.ROOT symbols guarantee '.' as the decimal separator whatever the JVM default locale is;
    // a ',' separator would corrupt the comma-separated per-allele VCF values.
    private static final DecimalFormat DEPTH_FORMAT = new DecimalFormat("#.##", DecimalFormatSymbols.getInstance(Locale.ROOT));
    private static final DecimalFormat FRACTION_FORMAT = new DecimalFormat("#.####", DecimalFormatSymbols.getInstance(Locale.ROOT));

    public static final VCFFormatHeaderLine DEPTH_HEADER_LINE = new VCFFormatHeaderLine(GATKVCFConstants.PSEUDO_DEPTH_KEY, VCFHeaderLineCount.R, VCFHeaderLineType.Float, "Allele depth based on Dirichlet posterior pseudo-counts");
    public static final VCFFormatHeaderLine FRACTION_HEADER_LINE = new VCFFormatHeaderLine(FRACTION_KEY, VCFHeaderLineCount.R, VCFHeaderLineType.Float, "Allele Fraction based on Dirichlet posterior pseudo-counts");
    private static final List<String> KEYS = Arrays.asList(GATKVCFConstants.PSEUDO_DEPTH_KEY, FRACTION_KEY);
    private static final List<VCFCompoundHeaderLine> HEADER_LINES = Arrays.asList(DEPTH_HEADER_LINE, FRACTION_HEADER_LINE);

    /** @return the FORMAT header lines for the pseudo-depth and fraction annotations. */
    @Override
    public List<VCFCompoundHeaderLine> getDescriptions() {
        return HEADER_LINES;
    }

    /** @return the FORMAT keys written by this annotation (pseudo-depth key and {@code DF}). */
    @Override
    public List<String> getKeyNames() {
        return KEYS;
    }

    /**
     * Adds the per-allele pseudo-depth and allele-fraction attributes to {@code genotypeBuilder}.
     *
     * <p>Nothing is added when the likelihoods are {@code null}, when the variant has fewer than two alleles,
     * or when the genotype's sample is not present in the likelihoods. A sample without reads gets the prior
     * as its posterior.</p>
     *
     * @param referenceContext  unused
     * @param variantContext    variant whose alleles define the annotated values (one value per allele)
     * @param genotype          genotype whose sample is annotated
     * @param genotypeBuilder   receives the formatted attributes
     * @param alleleLikelihoods read-by-allele likelihoods; may be {@code null}
     */
    @Override
    public void annotate(final ReferenceContext referenceContext, final VariantContext variantContext, final Genotype genotype, final GenotypeBuilder genotypeBuilder, final AlleleLikelihoods<GATKRead, Allele> alleleLikelihoods) {
        if (alleleLikelihoods == null) {
            return;
        }
        final List<Allele> alleles = variantContext.getAlleles();
        if (alleles.size() <= 1) {
            return;
        }
        final int sampleIndex = alleleLikelihoods.indexOfSample(genotype.getSampleName());
        if (sampleIndex < 0) {
            // the genotype's sample has no likelihoods to derive pseudo-counts from
            return;
        }
        final double[] priorCounts = composePriorPseudoCounts(alleles.size());
        final LikelihoodMatrix<GATKRead, Allele> sampleMatrix = alleleLikelihoods.sampleMatrix(sampleIndex);
        final LikelihoodMatrix<GATKRead, Allele> matrix = alleles.size() == alleleLikelihoods.numberOfAlleles()
                ? sampleMatrix
                : new SubsettedLikelihoodMatrix<>(sampleMatrix, alleles);

        final double[] posteriorCounts;
        if (sampleMatrix.evidence().isEmpty()) {
            // clone: priorCounts is the cached, shared prior and must not escape into mutable results
            posteriorCounts = priorCounts.clone();
        } else {
            final RealMatrix logLikelihoods = composeInputLikelihoodMatrix(alleleLikelihoods, matrix);
            posteriorCounts = SomaticLikelihoodsEngine.alleleFractionsPosterior(logLikelihoods, priorCounts, calculateWeights(logLikelihoods));
        }

        // fractions always include the prior; depths optionally exclude it
        final double[] fractions = MathUtils.normalizeSumToOne(posteriorCounts);
        final double[] depths = new double[posteriorCounts.length];
        for (int i = 0; i < depths.length; i++) {
            // the posterior is prior + observed pseudo-counts, so the difference is non-negative up to
            // floating-point error; clamping avoids emitting "-0"
            depths[i] = keepPriorInCount ? posteriorCounts[i] : Math.max(0.0, posteriorCounts[i] - priorCounts[i]);
        }

        genotypeBuilder.attribute(GATKVCFConstants.PSEUDO_DEPTH_KEY, formatValues(depths, DEPTH_FORMAT));
        genotypeBuilder.attribute(FRACTION_KEY, formatValues(fractions, FRACTION_FORMAT));
    }

    /** Formats {@code values} with {@code format} and joins them with commas, as required for Number=R fields. */
    private static String formatValues(final double[] values, final DecimalFormat format) {
        return Arrays.stream(values).mapToObj(format::format).collect(Collectors.joining(","));
    }

    /**
     * Returns the sample's likelihoods as an allele-by-read matrix of natural-log likelihoods.
     *
     * <p>If the likelihoods are already natural-log, the matrix is returned as is. Otherwise they are treated as
     * log10 values: each entry is floored at {@code -0.1 * mappingQuality} of its read (column), then converted
     * to natural log. This happens on a copy, so the source likelihoods are left untouched.</p>
     *
     * @param alleleLikelihoods source of the log-base flag
     * @param likelihoodMatrix  the sample's (possibly allele-subsetted) likelihood matrix
     * @return natural-log likelihoods with alleles as rows and reads as columns
     */
    private static RealMatrix composeInputLikelihoodMatrix(final AlleleLikelihoods<GATKRead, Allele> alleleLikelihoods, final LikelihoodMatrix<GATKRead, Allele> likelihoodMatrix) {
        if (alleleLikelihoods.isNaturalLog()) {
            return likelihoodMatrix.asRealMatrix();
        }
        final List<? extends GATKRead> reads = likelihoodMatrix.evidence();
        final RealMatrix result = likelihoodMatrix.asRealMatrix().copy();
        result.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
            @Override
            public double visit(final int alleleIndex, final int readIndex, final double log10Likelihood) {
                // rows are alleles and columns are reads, so the read is looked up by the column index
                final double mappingQualityCap = -0.1d * reads.get(readIndex).getMappingQuality();
                return Math.max(log10Likelihood, mappingQualityCap) * MathUtils.LOG_10;
            }
        });
        return result;
    }

    /**
     * Computes one weight per read (column) from how strongly the read separates its best allele from the
     * second best: {@code (1 - exp(second - best))^weightDecay}.
     *
     * @param logLikelihoods natural-log likelihoods, alleles as rows and reads as columns
     * @return per-read weights, or {@code null} when {@link #weightDecay} is 0 (no weighting)
     * @throws IllegalArgumentException if {@link #weightDecay} is negative
     */
    private double[] calculateWeights(final RealMatrix logLikelihoods) {
        if (weightDecay == 0.0d) {
            return null;
        }
        if (weightDecay < 0.0d) {
            throw new IllegalArgumentException("the weight decay must be 0 or greater");
        }
        final double[] weights = new double[logLikelihoods.getColumnDimension()];
        for (int read = 0; read < weights.length; read++) {
            double best = logLikelihoods.getEntry(0, read);
            double secondBest = Double.NEGATIVE_INFINITY;
            for (int allele = 1; allele < logLikelihoods.getRowDimension(); allele++) {
                final double value = logLikelihoods.getEntry(allele, read);
                if (value > best) {
                    secondBest = best;
                    best = value;
                } else if (value > secondBest) {
                    secondBest = value;
                }
            }
            // entries are natural logs (see composeInputLikelihoodMatrix), so the likelihood ratio is exp(diff);
            // a read with no finite likelihood carries no information (and would otherwise yield NaN)
            final double weight = best == Double.NEGATIVE_INFINITY ? 0.0d : 1.0d - Math.exp(secondBest - best);
            weights[read] = weightDecay == 1.0d ? weight : Math.pow(weight, weightDecay);
        }
        return weights;
    }

    /**
     * Returns a flat prior of {@link #prior} pseudo-counts for {@code numberOfAlleles} alleles, cached per
     * allele count. The returned array is shared; do not modify it.
     */
    private double[] composePriorPseudoCounts(final int numberOfAlleles) {
        double[] counts = priorPseudoCounts.get(numberOfAlleles);
        if (counts == null) {
            counts = new double[numberOfAlleles];
            Arrays.fill(counts, prior);
            priorPseudoCounts.put(numberOfAlleles, counts);
        }
        return counts;
    }
}
