package org.broadinstitute.hellbender.tools.walkers.mutect.filtering;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.variant.variantcontext.VariantContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.math3.util.CombinatoricsUtils;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.tools.walkers.validation.basicshortmutpileup.BetaBinomialDistribution;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.OptimizationUtils;
import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;

/**
 * Mutect2 filter for alleles whose alt-supporting reads look concentrated on a single strand.
 *
 * <p>Each variant is scored under three hypotheses: a strand artifact on the forward strand, a strand
 * artifact on the reverse strand, or no strand artifact. The alt count on an artifact strand is modeled
 * by a beta-binomial with learned shape ({@code alphaStrand}, {@code betaStrand}). The alt count on the
 * opposite strand is modeled by a beta-binomial whose beta depends on indel size. The artifact prior and
 * the artifact-strand shape are re-estimated from accumulated per-variant results in
 * {@link #learnParameters()}.
 */
public class StrandArtifactFilter extends Mutect2VariantFilter {
    // Beta-binomial shape for the alt count on the non-artifact strand under an artifact hypothesis.
    // Beta grows with indel size, pulling the expected alt fraction on that strand toward zero.
    private static final double ALPHA_SEQ = 1.0;
    private static final double BETA_SEQ_SNV = 1000.0;
    private static final double BETA_SEQ_SHORT_INDEL = 5000.0;
    private static final double BETA_SEQ_LONG_INDEL = 50000.0;

    // Indels of at least this size use BETA_SEQ_LONG_INDEL; sizes 1 .. LONG_INDEL_SIZE-1 use BETA_SEQ_SHORT_INDEL.
    private static final int LONG_INDEL_SIZE = 3;

    // Indels longer than this are never considered strand artifacts (probability 0).
    private static final int LONGEST_STRAND_ARTIFACT_INDEL_SIZE = 4;

    private static final double INITIAL_STRAND_ARTIFACT_PRIOR = 0.001;

    // Pseudocounts added when re-estimating strandArtifactPrior in learnParameters().
    private static final double ARTIFACT_PSEUDOCOUNT = 1.0;
    private static final double NON_ARTIFACT_PSEUDOCOUNT = 1000.0;

    // Starting shape of the artifact-strand beta-binomial; also used as pseudocounts when re-learning it.
    private static final double INITIAL_ALPHA_STRAND = 1.0;
    private static final double INITIAL_BETA_STRAND = 20.0;

    // Only variants whose artifact probability exceeds this threshold contribute to re-learning the
    // artifact-strand shape.
    private static final double POTENTIAL_ARTIFACT_THRESHOLD = 0.1;

    private double alphaStrand = INITIAL_ALPHA_STRAND;
    private double betaStrand = INITIAL_BETA_STRAND;
    private double strandArtifactPrior = INITIAL_STRAND_ARTIFACT_PRIOR;

    // Per-variant results collected by accumulateDataForLearning() and consumed by learnParameters().
    private final List<EStep> eSteps = new ArrayList<>();

    /**
     * Result of the expectation step for one variant.
     *
     * <p>Holds the posterior responsibilities of the forward- and reverse-strand artifact hypotheses,
     * together with the strand-specific total and alt read counts they were computed from.
     */
    public static final class EStep {
        private final double forwardArtifactResponsibility;
        private final double reverseArtifactResponsibility;
        private final int forwardCount;
        private final int reverseCount;
        private final int forwardAltCount;
        private final int reverseAltCount;

        public EStep(final double forwardArtifactResponsibility, final double reverseArtifactResponsibility,
                     final int forwardCount, final int reverseCount,
                     final int forwardAltCount, final int reverseAltCount) {
            this.forwardArtifactResponsibility = forwardArtifactResponsibility;
            this.reverseArtifactResponsibility = reverseArtifactResponsibility;
            this.forwardCount = forwardCount;
            this.reverseCount = reverseCount;
            this.forwardAltCount = forwardAltCount;
            this.reverseAltCount = reverseAltCount;
        }

        public double getForwardArtifactResponsibility() {
            return forwardArtifactResponsibility;
        }

        public double getReverseArtifactResponsibility() {
            return reverseArtifactResponsibility;
        }

        /** Returns the total probability of a strand artifact on either strand. */
        public double getArtifactProbability() {
            return getForwardArtifactResponsibility() + getReverseArtifactResponsibility();
        }

        public int getForwardCount() {
            return forwardCount;
        }

        public int getReverseCount() {
            return reverseCount;
        }

        public int getForwardAltCount() {
            return forwardAltCount;
        }

        public int getReverseAltCount() {
            return reverseAltCount;
        }
    }

    @Override
    public ErrorType errorType() {
        return ErrorType.ARTIFACT;
    }

    /** Returns the probability that the variant is a strand artifact on either strand. */
    @Override
    public double calculateErrorProbability(final VariantContext vc, final Mutect2FilteringEngine filteringEngine,
                                            final ReferenceContext referenceContext) {
        return calculateArtifactProbabilities(vc, filteringEngine).getArtifactProbability();
    }

    /**
     * Computes the per-strand artifact responsibilities for a variant, using the first alt allele.
     *
     * <p>Variants with no alt reads, or whose indel length exceeds
     * {@link #LONGEST_STRAND_ARTIFACT_INDEL_SIZE}, get zero artifact responsibility.
     *
     * @param vc the variant to score
     * @param filteringEngine engine supplying strand counts summed over samples
     * @return the E-step result, carrying both responsibilities and the strand counts used
     */
    public EStep calculateArtifactProbabilities(final VariantContext vc, final Mutect2FilteringEngine filteringEngine) {
        // As consumed here: [2]/[3] are forward/reverse alt counts; [0]/[1] are added to them to form the
        // forward/reverse totals. NOTE(review): meaning of the boolean flags is defined by the engine — confirm.
        final int[] strandCounts = filteringEngine.sumStrandCountsOverSamples(vc, true, false);
        final int forwardAltCount = strandCounts[2];
        final int reverseAltCount = strandCounts[3];
        final int forwardCount = strandCounts[0] + forwardAltCount;
        final int reverseCount = strandCounts[1] + reverseAltCount;

        // Length difference between ref and first alt allele; 0 for SNVs (and equal-length substitutions).
        final int indelSize = Math.abs(vc.getReference().length() - vc.getAlternateAllele(0).length());

        if (forwardAltCount + reverseAltCount == 0 || indelSize > LONGEST_STRAND_ARTIFACT_INDEL_SIZE) {
            return new EStep(0.0, 0.0, forwardCount, reverseCount, forwardAltCount, reverseAltCount);
        }
        return strandArtifactProbability(strandArtifactPrior, forwardCount, reverseCount,
                forwardAltCount, reverseAltCount, indelSize);
    }

    /** Records this variant's E-step result if it carries every annotation in {@link #requiredAnnotations()}. */
    @Override
    public void accumulateDataForLearning(final VariantContext vc, final ErrorProbabilities errorProbabilities,
                                          final Mutect2FilteringEngine filteringEngine) {
        if (requiredAnnotations().stream().allMatch(vc::hasAttribute)) {
            eSteps.add(calculateArtifactProbabilities(vc, filteringEngine));
        }
    }

    @Override
    protected void clearAccumulatedData() {
        eSteps.clear();
    }

    /**
     * M-step: re-estimates the strand-artifact prior and the artifact-strand beta-binomial shape from the
     * accumulated E-step results, then discards them.
     */
    @Override
    protected void learnParameters() {
        final List<EStep> potentialArtifacts = eSteps.stream()
                .filter(eStep -> eStep.getArtifactProbability() > POTENTIAL_ARTIFACT_THRESHOLD)
                .collect(Collectors.toList());

        // Prior: expected artifact count over expected total count, with pseudocounts. Artifact mass comes
        // only from potential artifacts; non-artifact mass comes from every accumulated variant.
        final double totalArtifacts = potentialArtifacts.stream().mapToDouble(EStep::getArtifactProbability).sum();
        final double totalNonArtifacts = eSteps.stream().mapToDouble(eStep -> 1.0 - eStep.getArtifactProbability()).sum();
        strandArtifactPrior = (totalArtifacts + ARTIFACT_PSEUDOCOUNT)
                / (totalArtifacts + ARTIFACT_PSEUDOCOUNT + totalNonArtifacts + NON_ARTIFACT_PSEUDOCOUNT);

        // Responsibility-weighted alt count and depth on the artifact strand.
        final double artifactAltCount = potentialArtifacts.stream()
                .mapToDouble(e -> e.forwardArtifactResponsibility * e.forwardAltCount
                        + e.reverseArtifactResponsibility * e.reverseAltCount)
                .sum();
        final double artifactDepth = potentialArtifacts.stream()
                .mapToDouble(e -> e.forwardArtifactResponsibility * e.forwardCount
                        + e.reverseArtifactResponsibility * e.reverseCount)
                .sum();

        // Mean alt fraction on the artifact strand, smoothed by the initial shape.
        final double mean = (artifactAltCount + INITIAL_ALPHA_STRAND)
                / (artifactDepth + INITIAL_ALPHA_STRAND + INITIAL_BETA_STRAND);

        // Fit alpha with beta tied to it so that alpha / (alpha + beta) == mean. Only the
        // concentration is optimized.
        alphaStrand = OptimizationUtils.max(alpha -> {
            final double beta = (1.0 / mean - 1.0) * alpha;
            return potentialArtifacts.stream()
                    .mapToDouble(e -> e.getForwardArtifactResponsibility()
                            * artifactStrandLogLikelihood(e.forwardCount, e.forwardAltCount, alpha, beta)
                            + e.getReverseArtifactResponsibility()
                            * artifactStrandLogLikelihood(e.reverseCount, e.reverseAltCount, alpha, beta))
                    .sum();
            // NOTE(review): numeric arguments passed through unchanged; presumably bounds [0.01, 100],
            // initial guess, tolerances, and max evaluations — confirm against OptimizationUtils.max.
        }, 0.01, 100.0, INITIAL_ALPHA_STRAND, 0.01, 0.01, 100).getPoint();

        betaStrand = (1.0 / mean - 1.0) * alphaStrand;

        eSteps.clear();
    }

    /**
     * Computes normalized responsibilities of the forward-artifact, reverse-artifact, and no-artifact
     * hypotheses for one variant.
     *
     * @param prior prior probability of a strand artifact (split evenly between the two strands)
     * @param forwardCount total forward-strand reads
     * @param reverseCount total reverse-strand reads
     * @param forwardAltCount alt reads on the forward strand
     * @param reverseAltCount alt reads on the reverse strand
     * @param indelSize absolute ref/alt length difference (0 for SNVs)
     * @return an EStep whose responsibilities are the first two normalized hypothesis weights
     */
    @VisibleForTesting
    EStep strandArtifactProbability(final double prior, final int forwardCount, final int reverseCount,
                                    final int forwardAltCount, final int reverseAltCount, final int indelSize) {
        final double artifactLogPrior = Math.log(prior / 2.0);
        final double noArtifactLogPrior = Math.log(1.0 - prior);

        // Forward artifact: forward alt count from the learned artifact model, reverse from the sequencing-error model.
        final double forwardArtifactLogScore = artifactStrandLogLikelihood(forwardCount, forwardAltCount)
                + nonArtifactStrandLogLikelihood(reverseCount, reverseAltCount, indelSize) + artifactLogPrior;
        final double reverseArtifactLogScore = artifactStrandLogLikelihood(reverseCount, reverseAltCount)
                + nonArtifactStrandLogLikelihood(forwardCount, forwardAltCount, indelSize) + artifactLogPrior;

        // No artifact: both strands share one alt fraction with a Beta(1, 1) prior. The binomial-coefficient
        // terms convert the pooled beta-binomial into the probability of this particular strand split.
        final double noArtifactLogScore = CombinatoricsUtils.binomialCoefficientLog(forwardCount, forwardAltCount)
                + CombinatoricsUtils.binomialCoefficientLog(reverseCount, reverseAltCount)
                - CombinatoricsUtils.binomialCoefficientLog(forwardCount + reverseCount, forwardAltCount + reverseAltCount)
                + new BetaBinomialDistribution(null, 1.0, 1.0, forwardCount + reverseCount)
                        .logProbability(forwardAltCount + reverseAltCount)
                + noArtifactLogPrior;

        // Natural log -> log10 for normalizeLog10.
        // NOTE(review): flags (false, true) presumably request linear-space output, computed in place — confirm.
        final double[] responsibilities = MathUtils.normalizeLog10(new double[] {
                forwardArtifactLogScore * MathUtils.LOG10_E,
                reverseArtifactLogScore * MathUtils.LOG10_E,
                noArtifactLogScore * MathUtils.LOG10_E }, false, true);

        return new EStep(responsibilities[0], responsibilities[1], forwardCount, reverseCount,
                forwardAltCount, reverseAltCount);
    }

    @Override
    public String filterName() {
        return GATKVCFConstants.STRAND_ARTIFACT_FILTER_NAME;
    }

    @Override
    public List<String> requiredAnnotations() {
        return Collections.emptyList();
    }

    /** Log-likelihood of the alt count on an artifact strand under the currently learned shape. */
    private double artifactStrandLogLikelihood(final int count, final int altCount) {
        return artifactStrandLogLikelihood(count, altCount, alphaStrand, betaStrand);
    }

    /** Beta-binomial log-probability of {@code altCount} successes in {@code count} trials. */
    private static double artifactStrandLogLikelihood(final int count, final int altCount,
                                                      final double alpha, final double beta) {
        return new BetaBinomialDistribution(null, alpha, beta, count).logProbability(altCount);
    }

    /** Log-likelihood of the alt count on the non-artifact strand; beta is chosen by indel size. */
    private static double nonArtifactStrandLogLikelihood(final int count, final int altCount, final int indelSize) {
        final double betaSeq = indelSize == 0 ? BETA_SEQ_SNV
                : indelSize < LONG_INDEL_SIZE ? BETA_SEQ_SHORT_INDEL : BETA_SEQ_LONG_INDEL;
        return new BetaBinomialDistribution(null, ALPHA_SEQ, betaSeq, count).logProbability(altCount);
    }

    @Override
    public Optional<String> phredScaledPosteriorAnnotationName() {
        return Optional.of(GATKVCFConstants.STRAND_QUAL_KEY);
    }
}
