package org.broadinstitute.hellbender.tools.walkers.mutect.clustering;

import com.google.common.primitives.Doubles;
import htsjdk.variant.variantcontext.VariantContext;
import it.unimi.dsi.fastutil.ints.Int2DoubleArrayMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.OptionalDouble;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.mutable.MutableDouble;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.util.MathArrays;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.walkers.mutect.Mutect2Engine;
import org.broadinstitute.hellbender.tools.walkers.mutect.MutectStats;
import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.M2FiltersArgumentCollection;
import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.Mutect2FilteringEngine;
import org.broadinstitute.hellbender.tools.walkers.readorientation.BetaDistributionShape;
import org.broadinstitute.hellbender.utils.IndexRange;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.NaturalLogUtils;
import org.broadinstitute.hellbender.utils.Utils;

/* loaded from: input_file:org/broadinstitute/hellbender/tools/walkers/mutect/clustering/SomaticClusteringModel.class */
/**
 * Allele-fraction clustering model for somatic variants.
 *
 * <p>Cluster 0 is a background beta-binomial cluster and cluster 1 a high-allele-fraction beta-binomial
 * cluster. The first time {@link #learnAndClearAccumulatedData()} runs, up to {@code MAX_BINOMIAL_CLUSTERS}
 * binomial clusters may be split off the background. Besides the cluster weights, EM also learns:
 * <ul>
 *   <li>the log prior of a somatic variant for each indel length (0 meaning SNV);</li>
 *   <li>the log prior of a variant versus an artifact.</li>
 * </ul></p>
 *
 * <p>Typical use: call {@link #record} once per variant site, then call
 * {@link #learnAndClearAccumulatedData()}. Instances hold mutable state without any synchronization.</p>
 */
public class SomaticClusteringModel {
    /**
     * Indel lengths in {@code [-MAX_INDEL_SIZE_IN_PRIOR_MAP, MAX_INDEL_SIZE_IN_PRIOR_MAP]} get their own
     * learned prior. Other lengths get the smallest prior present when they are first queried.
     */
    private static final int MAX_INDEL_SIZE_IN_PRIOR_MAP = 10;
    private static final int NUM_INITIALIZATION_QUANTILES = 50;
    // NOTE(review): unused; appears to correspond to MIN_QUANTILE_INDEX_FOR_MAKING_CLUSTER / NUM_INITIALIZATION_QUANTILES
    private static final double MIN_QUANTILE_FOR_MAKING_CLUSTER = 0.1d;
    private static final int MIN_QUANTILE_INDEX_FOR_MAKING_CLUSTER = 5;
    private static final double INITIAL_HIGH_AF_WEIGHT = 0.01d;
    public static final double MAX_FRACTION_OF_BACKGROUND_TO_SPLIT_OFF = 0.9d;
    private static final int NUM_ITERATIONS = 5;
    private static final int MAX_BINOMIAL_CLUSTERS = 5;
    private static final BetaDistributionShape INITIAL_HIGH_AF_BETA = new BetaDistributionShape(10.0d, 1.0d);
    private static final BetaDistributionShape INITIAL_BACKGROUND_BETA = BetaDistributionShape.FLAT_BETA;
    private static final double OBVIOUS_ARTIFACT_PROBABILITY_THRESHOLD = 0.9d;
    /** Pseudocount added to every cluster and to the variant/artifact counts during EM. */
    private static final double REGULARIZING_PSEUDOCOUNT = 1.0d;
    /** Floor on the learned per-callable-site SNV rate before taking its log. */
    private static final double MIN_SNV_RATE = 1.0e-8;
    /** Floor on the learned per-callable-site indel rate before taking its log. */
    private static final double MIN_INDEL_RATE = 1.0e-9;

    protected final Logger logger = LogManager.getLogger(getClass());

    private boolean clustersHaveBeenInitialized;
    private double logVariantVsArtifactPrior;
    /** Empty when the stats lack callable sites (or report fewer than 1); then per-length priors are never re-learned. */
    private final OptionalDouble callableSites;
    /** Log weight of each cluster, parallel to {@code clusters}. */
    private double[] logClusterWeights;
    /** Log prior of a somatic variant keyed by indel length (alt length minus ref length; 0 = SNV). */
    private final Map<Integer, Double> logVariantPriors = new Int2DoubleArrayMap();
    final List<Datum> data = new ArrayList<>();
    final List<AlleleFractionCluster> clusters = new ArrayList<>();
    /** Number of alt alleles skipped since the last learning step because their artifact probability was too high. */
    private final MutableInt obviousArtifactCount = new MutableInt(0);

    /**
     * Creates a model containing only the initial background and high-AF clusters.
     *
     * @param filteringArgs source of the initial SNV and indel log priors, and of the initial log prior of
     *                      variant versus artifact
     * @param mutectStats   Mutect stats. The first {@code Mutect2Engine.CALLABLE_SITES_NAME} entry, if present
     *                      and at least 1, enables re-learning the per-indel-length priors.
     */
    public SomaticClusteringModel(final M2FiltersArgumentCollection filteringArgs, final List<MutectStats> mutectStats) {
        IntStream.rangeClosed(-MAX_INDEL_SIZE_IN_PRIOR_MAP, MAX_INDEL_SIZE_IN_PRIOR_MAP)
                .forEach(n -> logVariantPriors.put(n, filteringArgs.getLogIndelPrior()));
        logVariantPriors.put(0, filteringArgs.getLogSnvPrior());
        logVariantVsArtifactPrior = filteringArgs.initialLogPriorOfVariantVersusArtifact;

        final OptionalDouble callableSitesInStats = mutectStats.stream()
                .filter(stat -> stat.getStatistic().equals(Mutect2Engine.CALLABLE_SITES_NAME))
                .mapToDouble(stat -> stat.getValue())
                .findFirst();
        final boolean tooFewCallableSites = callableSitesInStats.isPresent() && callableSitesInStats.getAsDouble() < 1.0d;
        if (tooFewCallableSites) {
            logger.warn("No callable sites found in Mutect stats.  Running without the full somatic clustering model.  Something is seriously wrong!");
        }
        callableSites = tooFewCallableSites ? OptionalDouble.empty() : callableSitesInStats;

        clusters.add(new BetaBinomialCluster(INITIAL_BACKGROUND_BETA));
        clusters.add(new BetaBinomialCluster(INITIAL_HIGH_AF_BETA));
        // The weights must sum to one, so the background gets log(1 - w) = log1p(-w), not log1p(w) = log(1 + w).
        logClusterWeights = new double[] {Math.log1p(-INITIAL_HIGH_AF_WEIGHT), Math.log(INITIAL_HIGH_AF_WEIGHT)};
    }

    /**
     * Accumulates the alt alleles of one variant site for the next learning step.
     *
     * <p>Symbolic alt alleles are skipped. Their depth is set to 0 <em>in place</em> in {@code tumorADs}, so it
     * is excluded from the total depth. Each remaining alt allele is handled as follows:
     * <ul>
     *   <li>artifact probability above {@code OBVIOUS_ARTIFACT_PROBABILITY_THRESHOLD}: only counted as an
     *       obvious artifact;</li>
     *   <li>otherwise, non-somatic probability at or below that threshold: recorded as a {@link Datum};</li>
     *   <li>otherwise: ignored.</li>
     * </ul></p>
     *
     * @param tumorADs                allele depths with the ref at index 0 and alt allele {@code n} at index
     *                                {@code n + 1}; modified in place
     * @param tumorLogOdds            one log odds per alt allele
     * @param artifactProbabilities   one artifact probability per alt allele
     * @param nonSomaticProbabilities one non-somatic probability per alt allele
     * @param vc                      the variant whose alleles are described
     */
    public void record(final int[] tumorADs, final double[] tumorLogOdds, final List<Double> artifactProbabilities,
                       final List<Double> nonSomaticProbabilities, final VariantContext vc) {
        // n indexes alt alleles, while tumorADs has the ref first, so alt allele n's depth lives at n + 1.
        new IndexRange(0, vc.getNAlleles() - 1)
                .filter(n -> vc.getAlternateAllele(n).isSymbolic())
                .forEach(n -> tumorADs[n + 1] = 0);
        final int totalAD = (int) MathUtils.sum(tumorADs);

        for (int n = 0; n < tumorLogOdds.length; n++) {
            if (vc.getAlternateAllele(n).isSymbolic()) {
                continue;
            }
            final double artifactProb = artifactProbabilities.get(n);
            if (artifactProb > OBVIOUS_ARTIFACT_PROBABILITY_THRESHOLD) {
                obviousArtifactCount.increment();
            } else {
                final double nonSomaticProb = nonSomaticProbabilities.get(n);
                if (nonSomaticProb <= OBVIOUS_ARTIFACT_PROBABILITY_THRESHOLD) {
                    data.add(new Datum(tumorLogOdds[n], artifactProb, nonSomaticProb, tumorADs[n + 1], totalAD, indelLength(vc, n)));
                }
            }
        }
    }

    /** Returns the log prior of a somatic variant for alt allele {@code altIndex} of {@code vc}. */
    public double getLogPriorOfSomaticVariant(final VariantContext vc, final int altIndex) {
        return getLogPriorOfSomaticVariant(indelLength(vc, altIndex));
    }

    /** Returns the current log prior of a variant versus an artifact. */
    public double getLogPriorOfVariantVersusArtifact() {
        return logVariantVsArtifactPrior;
    }

    /**
     * Returns the posterior probability that {@code datum} is a sequencing error.
     *
     * <p>The log-sum-exp of the weighted {@code correctedLogLikelihood} of every cluster is passed, together
     * with the log prior for the datum's indel length, to
     * {@code Mutect2FilteringEngine.posteriorProbabilityOfError}.</p>
     */
    public double probabilityOfSequencingError(final Datum datum) {
        final double weightedCorrectedLogLikelihood = NaturalLogUtils.logSumExp(new IndexRange(0, clusters.size())
                .mapToDouble(n -> logClusterWeights[n] + clusters.get(n).correctedLogLikelihood(datum)));
        return Mutect2FilteringEngine.posteriorProbabilityOfError(weightedCorrectedLogLikelihood,
                getLogPriorOfSomaticVariant(datum.getIndelLength()));
    }

    /** Returns the probability that {@code datum} is neither an artifact, a non-sequencing error, nor a sequencing error. */
    private double probabilityOfSomaticVariant(final Datum datum) {
        return (1.0d - datum.getArtifactProb()) * (1.0d - datum.getNonSequencingErrorProb()) * (1.0d - probabilityOfSequencingError(datum));
    }

    /**
     * Greedily splits binomial clusters off the background cluster, running EM after each split.
     *
     * <p>Each step locates peaks in the background responsibility over allele-fraction quantiles and splits
     * off a binomial cluster at the peak with the largest mass. Splitting stops when:
     * <ul>
     *   <li>there is no peak or no mass;</li>
     *   <li>the biggest peak lies below quantile index {@code MIN_QUANTILE_INDEX_FOR_MAKING_CLUSTER};</li>
     *   <li>{@code MAX_BINOMIAL_CLUSTERS} clusters have been added; or</li>
     *   <li>the penalized likelihood (a BIC-like score) decreases, in which case the last split is undone.</li>
     * </ul></p>
     */
    private void initializeClusters() {
        Utils.validate(!clustersHaveBeenInitialized, "Clusters have already been initialized.");
        final double[] somaticProbs = data.stream().mapToDouble(this::probabilityOfSomaticVariant).toArray();

        double previousBIC = Double.NEGATIVE_INFINITY;
        for (int newClusterCount = 0; newClusterCount < MAX_BINOMIAL_CLUSTERS; newClusterCount++) {
            final double[] oldLogClusterWeights = Arrays.copyOf(logClusterWeights, logClusterWeights.length);
            final double[] backgroundResponsibilities = MathArrays.ebeMultiply(somaticProbs, data.stream()
                    .mapToDouble(datum -> backgroundProbGivenSomatic(datum.getTotalCount(), datum.getAltCount()))
                    .toArray());

            final double[] alleleFractionQuantiles = calculateAlleleFractionQuantiles();
            final double[] quantileBackgroundResponsibilities =
                    calculateQuantileBackgroundResponsibilities(alleleFractionQuantiles, backgroundResponsibilities);
            final List<Pair<Double, Double>> peaksAndMasses =
                    calculatePeaksAndMasses(alleleFractionQuantiles, quantileBackgroundResponsibilities);
            if (peaksAndMasses.isEmpty()) {
                break;
            }

            // max() keeps the first of equally massive peaks, as a stable descending sort would
            final Pair<Double, Double> biggestPeakAndMass =
                    peaksAndMasses.stream().max(Comparator.comparingDouble(Pair::getRight)).get();
            final double minimumPeakAlleleFraction = alleleFractionQuantiles[
                    Math.min(MIN_QUANTILE_INDEX_FOR_MAKING_CLUSTER, alleleFractionQuantiles.length - 1)];
            if (biggestPeakAndMass.getLeft() < minimumPeakAlleleFraction) {
                break;
            }

            final double totalMass = peaksAndMasses.stream().mapToDouble(Pair::getRight).sum();
            if (!(totalMass > 0.0d)) {
                // Nothing to split off; continuing would produce a 0/0 = NaN fraction and NaN cluster weights.
                break;
            }
            final double fractionToSplitOff = Math.min(MAX_FRACTION_OF_BACKGROUND_TO_SPLIT_OFF, biggestPeakAndMass.getRight() / totalMass);
            // The background keeps (1 - f) of its weight and the new cluster receives f, so the total is unchanged.
            final double newClusterLogWeight = Math.log(fractionToSplitOff) + logClusterWeights[0];
            final double newBackgroundLogWeight = Math.log1p(-fractionToSplitOff) + logClusterWeights[0];

            clusters.add(new BinomialCluster(biggestPeakAndMass.getLeft()));
            logClusterWeights = Arrays.copyOf(logClusterWeights, logClusterWeights.length + 1);
            logClusterWeights[0] = newBackgroundLogWeight;
            logClusterWeights[logClusterWeights.length - 1] = newClusterLogWeight;

            for (int iteration = 0; iteration < NUM_ITERATIONS; iteration++) {
                performEMIteration(false);
            }

            final double[] logLikelihoods = data.stream()
                    .mapToDouble(datum -> logLikelihoodGivenSomatic(datum.getTotalCount(), datum.getAltCount()))
                    .toArray();
            final double bic = MathUtils.sum(MathArrays.ebeMultiply(somaticProbs, logLikelihoods))
                    - 2 * clusters.size() * Math.log(MathUtils.sum(somaticProbs));
            if (bic < previousBIC) {
                // NOTE(review): only the weights are restored here. Parameters that the EM iterations above
                // learned for the remaining clusters are kept.
                clusters.remove(clusters.size() - 1);
                logClusterWeights = oldLogClusterWeights;
                break;
            }
            previousBIC = bic;
        }
        clustersHaveBeenInitialized = true;
    }

    /** Returns the allele fraction of {@code datum}, using 0 when it has no depth (avoids a NaN/division fault). */
    private static double alleleFraction(final Datum datum) {
        return datum.getTotalCount() == 0 ? 0.0d : (double) datum.getAltCount() / datum.getTotalCount();
    }

    /**
     * Returns distinct allele fractions at roughly {@code NUM_INITIALIZATION_QUANTILES} equally spaced quantiles.
     * Each datum is weighted by its somatic probability.
     */
    private double[] calculateAlleleFractionQuantiles() {
        final List<Pair<Double, Double>> sortedFractionsAndSomaticProbs = new ArrayList<>(data.size());
        for (final Datum datum : data) {
            // floating-point division: integer division would collapse nearly every allele fraction to 0
            sortedFractionsAndSomaticProbs.add(ImmutablePair.of(alleleFraction(datum), probabilityOfSomaticVariant(datum)));
        }
        sortedFractionsAndSomaticProbs.sort(Comparator.comparingDouble(Pair::getLeft));

        final double massPerQuantile = sortedFractionsAndSomaticProbs.stream().mapToDouble(Pair::getRight).sum()
                / NUM_INITIALIZATION_QUANTILES;
        if (!(massPerQuantile > 0.0d)) {
            // With zero (or NaN) total mass no threshold can ever be crossed; return early for clarity.
            return new double[0];
        }

        final List<Double> quantiles = new ArrayList<>(NUM_INITIALIZATION_QUANTILES);
        double cumulativeMass = 0.0d;
        double nextThreshold = massPerQuantile;
        for (final Pair<Double, Double> fractionAndProb : sortedFractionsAndSomaticProbs) {
            cumulativeMass += fractionAndProb.getRight();
            if (cumulativeMass > nextThreshold) {
                quantiles.add(fractionAndProb.getLeft());
                // one heavy datum may cross several quantile thresholds at once
                while (cumulativeMass > nextThreshold) {
                    nextThreshold += massPerQuantile;
                }
            }
        }
        return quantiles.stream().distinct().mapToDouble(Double::doubleValue).toArray();
    }

    /**
     * For each allele-fraction quantile {@code f}, sums over data {@code responsibility * (n + 1) * Binom(k | n, f)}.
     * Here {@code n} is the total count and {@code k} the alt count. The factor {@code (n + 1)} makes the
     * binomial probability a density in {@code f} that integrates to 1 over [0, 1].
     */
    private double[] calculateQuantileBackgroundResponsibilities(final double[] alleleFractionQuantiles,
                                                                 final double[] backgroundResponsibilities) {
        final double[] result = new double[alleleFractionQuantiles.length];
        for (int n = 0; n < data.size(); n++) {
            final Datum datum = data.get(n);
            final double responsibility = backgroundResponsibilities[n];
            final double[] densities = MathUtils.applyToArray(alleleFractionQuantiles,
                    f -> MathUtils.binomialProbability(datum.getTotalCount(), datum.getAltCount(), f));
            MathUtils.applyToArrayInPlace(densities, d -> d * responsibility * (datum.getTotalCount() + 1));
            MathUtils.addToArrayInPlace(result, densities);
        }
        return result;
    }

    /**
     * Splits the responsibility curve, sampled at the quantiles, into segments ending at local minima or at the
     * last quantile.
     *
     * @return for each segment, the pair (allele fraction of its highest-density quantile, trapezoidal mass)
     */
    private List<Pair<Double, Double>> calculatePeaksAndMasses(final double[] alleleFractionQuantiles,
                                                               final double[] quantileBackgroundResponsibilities) {
        final List<Pair<Double, Double>> peaksAndMasses = new ArrayList<>();
        final int numQuantiles = alleleFractionQuantiles.length;
        double segmentMass = 0.0d;
        double peakAlleleFraction = 0.0d;
        double peakDensity = 0.0d;
        for (int n = 0; n < numQuantiles; n++) {
            final double leftDensity = n == 0 ? 0.0d : quantileBackgroundResponsibilities[n - 1];
            final double density = quantileBackgroundResponsibilities[n];
            final double rightDensity = n == numQuantiles - 1 ? 0.0d : quantileBackgroundResponsibilities[n + 1];
            final double leftAlleleFraction = n == 0 ? 0.0d : alleleFractionQuantiles[n - 1];
            final double alleleFraction = alleleFractionQuantiles[n];

            // trapezoidal rule between the previous quantile (or 0) and this one
            segmentMass += (alleleFraction - leftAlleleFraction) * (leftDensity + density) / 2.0d;
            if (density > peakDensity) {
                peakAlleleFraction = alleleFraction;
                peakDensity = density;
            }

            final int vsLeft = Double.compare(density, leftDensity);
            final int vsRight = Double.compare(density, rightDensity);
            final boolean isLocalMinimum = (vsLeft < 0 && vsRight <= 0) || (vsLeft <= 0 && vsRight < 0);
            if ((isLocalMinimum && n > 0) || n == numQuantiles - 1) {
                peaksAndMasses.add(ImmutablePair.of(peakAlleleFraction, segmentMass));
                segmentMass = 0.0d;
                peakAlleleFraction = alleleFraction;
                peakDensity = density;
            }
        }
        return peaksAndMasses;
    }

    /**
     * Returns the log prior of a somatic variant with the given indel length.
     *
     * <p>An indel length not yet in the map is assigned, and cached with, the smallest prior currently
     * stored.</p>
     */
    private double getLogPriorOfSomaticVariant(final int indelLength) {
        if (!logVariantPriors.containsKey(indelLength)) {
            final double smallestLogPrior = logVariantPriors.values().stream().mapToDouble(Double::doubleValue).min().getAsDouble();
            logVariantPriors.put(indelLength, smallestLogPrior);
        }
        // SNVs get an extra log(1/3), presumably spreading the SNV prior over the three possible alt bases
        return logVariantPriors.get(indelLength) + (indelLength == 0 ? MathUtils.LOG_ONE_THIRD : 0.0d);
    }

    /**
     * Learns from the data accumulated by {@link #record} and then discards it.
     *
     * <p>Clusters are initialized first if this is the first call. Then {@code NUM_ITERATIONS} EM iterations
     * run, including updates of the priors.</p>
     */
    public void learnAndClearAccumulatedData() {
        if (!clustersHaveBeenInitialized) {
            initializeClusters();
        }
        for (int iteration = 0; iteration < NUM_ITERATIONS; iteration++) {
            performEMIteration(true);
        }
        data.clear();
        obviousArtifactCount.setValue(0);
    }

    /**
     * Runs one EM iteration.
     *
     * <p>Each datum's somatic probability is split among the clusters in proportion to their weighted
     * likelihood. Cluster weights are then re-estimated with a pseudocount, and every cluster learns from its
     * share.</p>
     *
     * @param updateSomaticPriors if true, also re-estimates the variant-versus-artifact log prior. When callable
     *                            sites are known, it also re-estimates the per-indel-length log priors.
     */
    private void performEMIteration(final boolean updateSomaticPriors) {
        final Map<Integer, MutableDouble> somaticCountsByIndelLength =
                IntStream.rangeClosed(-MAX_INDEL_SIZE_IN_PRIOR_MAP, MAX_INDEL_SIZE_IN_PRIOR_MAP).boxed()
                        .collect(Collectors.toMap(n -> n, n -> new MutableDouble(0.0d)));
        final List<double[]> responsibilities = new ArrayList<>(data.size());
        final double[] totalResponsibilities = new double[clusters.size()];

        for (final Datum datum : data) {
            final double somaticProb = probabilityOfSomaticVariant(datum);
            somaticCountsByIndelLength.computeIfAbsent(datum.getIndelLength(), k -> new MutableDouble(0.0d)).add(somaticProb);
            final double[] clusterResponsibilities = MathArrays.scale(somaticProb,
                    NaturalLogUtils.normalizeFromLogToLinearSpace(logWeightedClusterLikelihoods(datum.getTotalCount(), datum.getAltCount())));
            MathUtils.addToArrayInPlace(totalResponsibilities, clusterResponsibilities);
            responsibilities.add(clusterResponsibilities);
        }

        MathUtils.applyToArrayInPlace(totalResponsibilities, x -> x + REGULARIZING_PSEUDOCOUNT);
        logClusterWeights = MathUtils.applyToArrayInPlace(MathUtils.normalizeSumToOne(totalResponsibilities), Math::log);

        final double artifactCount = obviousArtifactCount.intValue() + data.stream().mapToDouble(Datum::getArtifactProb).sum();
        final double somaticCount = somaticCountsByIndelLength.values().stream().mapToDouble(MutableDouble::doubleValue).sum();

        if (updateSomaticPriors) {
            logVariantVsArtifactPrior = Math.log((somaticCount + REGULARIZING_PSEUDOCOUNT)
                    / (somaticCount + artifactCount + 2.0d * REGULARIZING_PSEUDOCOUNT));
            if (callableSites.isPresent()) {
                final double numCallableSites = callableSites.getAsDouble();
                IntStream.rangeClosed(-MAX_INDEL_SIZE_IN_PRIOR_MAP, MAX_INDEL_SIZE_IN_PRIOR_MAP).forEach(n -> {
                    final double count = somaticCountsByIndelLength.getOrDefault(n, new MutableDouble(0.0d)).doubleValue();
                    final double floor = n == 0 ? MIN_SNV_RATE : MIN_INDEL_RATE;
                    logVariantPriors.put(n, Math.log(Math.max(count / numCallableSites, floor)));
                });
            }
        }

        new IndexRange(0, clusters.size()).forEach(n -> clusters.get(n).learn(data,
                responsibilities.stream().mapToDouble(clusterResponsibilities -> clusterResponsibilities[n]).toArray()));
    }

    /** Returns the log-likelihood of the counts under the weighted mixture of all clusters. */
    public double logLikelihoodGivenSomatic(final int totalCount, final int altCount) {
        return NaturalLogUtils.logSumExp(logWeightedClusterLikelihoods(totalCount, altCount));
    }

    /** Returns the posterior probability that the counts belong to the background cluster (index 0). */
    private double backgroundProbGivenSomatic(final int totalCount, final int altCount) {
        return NaturalLogUtils.normalizeFromLogToLinearSpace(logWeightedClusterLikelihoods(totalCount, altCount))[0];
    }

    /** For each cluster, returns its log weight plus its log-likelihood of the counts. */
    private double[] logWeightedClusterLikelihoods(final int totalCount, final int altCount) {
        return new IndexRange(0, clusters.size())
                .mapToDouble(n -> logClusterWeights[n] + clusters.get(n).logLikelihood(totalCount, altCount));
    }

    /**
     * Returns human-readable (key, value) pairs describing the learned model.
     *
     * <p>The list contains the per-indel-length log priors, then the two beta-binomial clusters, then the
     * binomial clusters in descending weight order.</p>
     */
    public List<Pair<String, String>> clusteringMetadata() {
        final List<Pair<String, String>> result = new ArrayList<>();
        IntStream.rangeClosed(-MAX_INDEL_SIZE_IN_PRIOR_MAP, MAX_INDEL_SIZE_IN_PRIOR_MAP).forEach(n -> {
            final String variantType = n == 0 ? "SNV" : (n < 0 ? "deletion" : "insertion") + " of length " + Math.abs(n);
            result.add(ImmutablePair.of("Ln prior of " + variantType, Double.toString(logVariantPriors.get(n))));
        });
        result.add(ImmutablePair.of("Background beta-binomial cluster", describeCluster(0)));
        result.add(ImmutablePair.of("High-AF beta-binomial cluster", describeCluster(1)));
        IntStream.range(2, clusters.size()).boxed()
                .sorted(Comparator.comparingDouble(n -> -logClusterWeights[n]))
                .forEach(n -> result.add(ImmutablePair.of("Binomial cluster", describeCluster(n))));
        return result;
    }

    /** Formats the weight and description of the cluster at {@code index}. */
    private String describeCluster(final int index) {
        return String.format("weight = %.4f, %s", Math.exp(logClusterWeights[index]), clusters.get(index).toString());
    }

    /** Returns the alt allele's length minus the reference length (0 for SNVs, negative for deletions). */
    public static int indelLength(final VariantContext vc, final int altIndex) {
        return vc.getAlternateAllele(altIndex).length() - vc.getReference().length();
    }
}
