package org.broadinstitute.hellbender.tools.copynumber.models;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.util.OverlapDetector;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.copynumber.arguments.CopyNumberArgumentValidationUtils;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.AllelicCountCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.CopyRatioCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.ModeledSegmentCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.SimpleIntervalCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.SampleLocatableMetadata;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.SimpleLocatableMetadata;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.AllelicCount;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.CopyRatio;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.ModeledSegment;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

/* loaded from: input_file:org/broadinstitute/hellbender/tools/copynumber/models/MultidimensionalModeller.class */
/**
 * Fits a copy-ratio model and an allele-fraction model over a shared segmentation and exposes the per-segment
 * posterior summaries as {@link ModeledSegment}s. The segmentation can optionally be smoothed by repeatedly merging
 * adjacent segments whose posteriors are similar (see {@link #smoothSegments}).
 */
public final class MultidimensionalModeller {
    private static final Logger logger = LogManager.getLogger(MultidimensionalModeller.class);

    private final SampleLocatableMetadata metadata;
    private final CopyRatioCollection denoisedCopyRatios;
    private final OverlapDetector<CopyRatio> copyRatioMidpointOverlapDetector;
    private final AllelicCountCollection allelicCounts;
    private final OverlapDetector<AllelicCount> allelicCountOverlapDetector;
    private final AlleleFractionPrior alleleFractionPrior;

    private CopyRatioModeller copyRatioModeller;
    private AlleleFractionModeller alleleFractionModeller;
    private SimpleIntervalCollection currentSegments;
    private final List<ModeledSegment> modeledSegments = new ArrayList<>();
    // true when modeledSegments came from an MCMC fit of currentSegments;
    // false after a merge-only smoothing iteration whose posteriors were merged analytically
    private boolean isModelFit;

    private final int numSamplesCopyRatio;
    private final int numBurnInCopyRatio;
    private final int numSamplesAlleleFraction;
    private final int numBurnInAlleleFraction;

    /**
     * Helpers for merging adjacent modeled segments whose posterior summaries are similar.
     */
    public static final class SimilarSegmentUtils {
        private SimilarSegmentUtils() {
        }

        /**
         * Greedily merges adjacent segments on the same contig whose copy-ratio and minor-allele-fraction posterior
         * summaries are both similar. A merged segment is compared again against the next segment, so a run of
         * mutually similar neighbors collapses into a single segment.
         *
         * @param segments                                segments to merge; not modified
         * @param credibleIntervalThresholdCopyRatio      similarity threshold for log2 copy-ratio summaries
         * @param credibleIntervalThresholdAlleleFraction similarity threshold for minor-allele-fraction summaries
         * @return a new mutable list containing the merged segments
         */
        public static List<ModeledSegment> mergeSimilarSegments(final List<ModeledSegment> segments,
                                                                final double credibleIntervalThresholdCopyRatio,
                                                                final double credibleIntervalThresholdAlleleFraction) {
            final List<ModeledSegment> mergedSegments = new ArrayList<>(segments.size());
            for (final ModeledSegment segment : segments) {
                final int lastIndex = mergedSegments.size() - 1;
                if (lastIndex >= 0) {
                    final ModeledSegment previous = mergedSegments.get(lastIndex);
                    if (previous.getContig().equals(segment.getContig())
                            && areSimilar(previous, segment,
                                    credibleIntervalThresholdCopyRatio, credibleIntervalThresholdAlleleFraction)) {
                        // absorb the current segment into the (possibly already merged) previous one
                        mergedSegments.set(lastIndex, merge(previous, segment));
                        continue;
                    }
                }
                mergedSegments.add(segment);
            }
            return mergedSegments;
        }

        /**
         * Returns true if either summary has a NaN median, or if the distance between the medians is less than
         * {@code credibleIntervalThreshold} times the decile-10-to-decile-90 width of either summary.
         */
        private static boolean areSimilar(final ModeledSegment.SimplePosteriorSummary summary1,
                                          final ModeledSegment.SimplePosteriorSummary summary2,
                                          final double credibleIntervalThreshold) {
            if (Double.isNaN(summary1.getDecile50()) || Double.isNaN(summary2.getDecile50())) {
                // a NaN median is treated as no evidence against merging
                return true;
            }
            final double absoluteDifference = Math.abs(summary1.getDecile50() - summary2.getDecile50());
            return absoluteDifference < credibleIntervalThreshold * (summary1.getDecile90() - summary1.getDecile10())
                    || absoluteDifference < credibleIntervalThreshold * (summary2.getDecile90() - summary2.getDecile10());
        }

        /**
         * Returns true if both the log2 copy-ratio and the minor-allele-fraction summaries of the two segments are similar.
         */
        private static boolean areSimilar(final ModeledSegment segment1,
                                          final ModeledSegment segment2,
                                          final double credibleIntervalThresholdCopyRatio,
                                          final double credibleIntervalThresholdAlleleFraction) {
            return areSimilar(segment1.getLog2CopyRatioSimplePosteriorSummary(),
                    segment2.getLog2CopyRatioSimplePosteriorSummary(), credibleIntervalThresholdCopyRatio)
                    && areSimilar(segment1.getMinorAlleleFractionSimplePosteriorSummary(),
                    segment2.getMinorAlleleFractionSimplePosteriorSummary(), credibleIntervalThresholdAlleleFraction);
        }

        /**
         * Combines two posterior summaries by inverse-variance weighting. Each summary's standard deviation is taken
         * as half its decile-10-to-decile-90 width. The result is centered on the weighted mean with bounds at
         * plus/minus one combined standard deviation. If exactly one median is NaN, the other summary is returned.
         * If both medians are NaN, the first summary is returned.
         */
        private static ModeledSegment.SimplePosteriorSummary merge(final ModeledSegment.SimplePosteriorSummary summary1,
                                                                   final ModeledSegment.SimplePosteriorSummary summary2) {
            if (Double.isNaN(summary2.getDecile50())) {
                return summary1;
            }
            if (Double.isNaN(summary1.getDecile50())) {
                return summary2;
            }
            // NOTE(review): two zero-width summaries give a 0/0 NaN mean here; confirm that cannot occur upstream
            final double standardDeviation1 = (summary1.getDecile90() - summary1.getDecile10()) / 2.0;
            final double standardDeviation2 = (summary2.getDecile90() - summary2.getDecile10()) / 2.0;
            final double variance =
                    1.0 / (1.0 / Math.pow(standardDeviation1, 2.0) + 1.0 / Math.pow(standardDeviation2, 2.0));
            final double mean = (summary1.getDecile50() / Math.pow(standardDeviation1, 2.0)
                    + summary2.getDecile50() / Math.pow(standardDeviation2, 2.0)) * variance;
            final double standardDeviation = Math.sqrt(variance);
            return new ModeledSegment.SimplePosteriorSummary(mean, mean - standardDeviation, mean + standardDeviation);
        }

        /**
         * Merges two segments: their intervals are joined, point counts summed, and posterior summaries combined.
         */
        private static ModeledSegment merge(final ModeledSegment segment1,
                                            final ModeledSegment segment2) {
            return new ModeledSegment(
                    mergeSegments(segment1.getInterval(), segment2.getInterval()),
                    segment1.getNumPointsCopyRatio() + segment2.getNumPointsCopyRatio(),
                    segment1.getNumPointsAlleleFraction() + segment2.getNumPointsAlleleFraction(),
                    merge(segment1.getLog2CopyRatioSimplePosteriorSummary(),
                            segment2.getLog2CopyRatioSimplePosteriorSummary()),
                    merge(segment1.getMinorAlleleFractionSimplePosteriorSummary(),
                            segment2.getMinorAlleleFractionSimplePosteriorSummary()));
        }

        /**
         * Returns the smallest interval spanning both inputs.
         *
         * @throws IllegalArgumentException if the intervals lie on different contigs
         */
        private static SimpleInterval mergeSegments(final SimpleInterval segment1,
                                                    final SimpleInterval segment2) {
            Utils.validateArg(segment1.getContig().equals(segment2.getContig()),
                    String.format("Cannot join segments %s and %s on different chromosomes.",
                            segment1.toString(), segment2.toString()));
            return new SimpleInterval(segment1.getContig(),
                    Math.min(segment1.getStart(), segment2.getStart()),
                    Math.max(segment1.getEnd(), segment2.getEnd()));
        }
    }

    /**
     * Validates the inputs and immediately fits both models over {@code segments}.
     *
     * @param segments                 initial segmentation; must be non-empty
     * @param denoisedCopyRatios       denoised copy ratios
     * @param allelicCounts            allelic counts
     * @param alleleFractionPrior      prior for the allele-fraction model
     * @param numSamplesCopyRatio      total copy-ratio MCMC samples; must exceed {@code numBurnInCopyRatio}
     * @param numBurnInCopyRatio       copy-ratio burn-in samples; must be non-negative
     * @param numSamplesAlleleFraction total allele-fraction MCMC samples; must exceed {@code numBurnInAlleleFraction}
     * @param numBurnInAlleleFraction  allele-fraction burn-in samples; must be non-negative
     */
    public MultidimensionalModeller(final SimpleIntervalCollection segments,
                                    final CopyRatioCollection denoisedCopyRatios,
                                    final AllelicCountCollection allelicCounts,
                                    final AlleleFractionPrior alleleFractionPrior,
                                    final int numSamplesCopyRatio,
                                    final int numBurnInCopyRatio,
                                    final int numSamplesAlleleFraction,
                                    final int numBurnInAlleleFraction) {
        Utils.nonNull(segments);
        Utils.nonNull(denoisedCopyRatios);
        Utils.nonNull(allelicCounts);
        Utils.nonNull(alleleFractionPrior);
        ParamUtils.isPositiveOrZero(numBurnInCopyRatio,
                "Number of burn-in copy-ratio samples must be non-negative.");
        Utils.validateArg(numBurnInCopyRatio < numSamplesCopyRatio,
                "Number of copy-ratio samples must be greater than number of burn-in copy-ratio samples.");
        ParamUtils.isPositiveOrZero(numBurnInAlleleFraction,
                "Number of burn-in allele-fraction samples must be non-negative.");
        Utils.validateArg(numBurnInAlleleFraction < numSamplesAlleleFraction,
                "Number of allele-fraction samples must be greater than number of burn-in allele-fraction samples.");
        this.metadata = (SampleLocatableMetadata) CopyNumberArgumentValidationUtils.getValidatedMetadata(
                denoisedCopyRatios, allelicCounts);
        CopyNumberArgumentValidationUtils.getValidatedSequenceDictionary(segments, denoisedCopyRatios, allelicCounts);
        ParamUtils.isPositive(segments.size(), "Number of segments must be positive.");
        this.currentSegments = segments;
        this.denoisedCopyRatios = denoisedCopyRatios;
        this.copyRatioMidpointOverlapDetector = denoisedCopyRatios.getMidpointOverlapDetector();
        this.allelicCounts = allelicCounts;
        this.allelicCountOverlapDetector = allelicCounts.getOverlapDetector();
        this.alleleFractionPrior = alleleFractionPrior;
        this.numSamplesCopyRatio = numSamplesCopyRatio;
        this.numBurnInCopyRatio = numBurnInCopyRatio;
        this.numSamplesAlleleFraction = numSamplesAlleleFraction;
        this.numBurnInAlleleFraction = numBurnInAlleleFraction;
        logger.info("Fitting initial model...");
        fitModel();
    }

    /**
     * Returns the current modeled segments, paired with the validated sample metadata.
     */
    public ModeledSegmentCollection getModeledSegments() {
        return new ModeledSegmentCollection(metadata, modeledSegments);
    }

    /**
     * Runs MCMC for both models over {@link #currentSegments} and rebuilds {@link #modeledSegments}.
     * Each modeled segment records how many copy-ratio midpoints and allelic counts overlap it.
     */
    private void fitModel() {
        logger.info("Fitting copy-ratio model...");
        copyRatioModeller = new CopyRatioModeller(denoisedCopyRatios, currentSegments);
        copyRatioModeller.fitMCMC(numSamplesCopyRatio, numBurnInCopyRatio);

        logger.info("Fitting allele-fraction model...");
        alleleFractionModeller = new AlleleFractionModeller(allelicCounts, currentSegments, alleleFractionPrior);
        alleleFractionModeller.fitMCMC(numSamplesAlleleFraction, numBurnInAlleleFraction);

        modeledSegments.clear();
        final List<ModeledSegment.SimplePosteriorSummary> segmentMeansPosteriorSummaries =
                copyRatioModeller.getSegmentMeansPosteriorSummaries();
        final List<ModeledSegment.SimplePosteriorSummary> minorAlleleFractionsPosteriorSummaries =
                alleleFractionModeller.getMinorAlleleFractionsPosteriorSummaries();
        final List<SimpleInterval> segmentIntervals = currentSegments.getRecords();
        for (int segmentIndex = 0; segmentIndex < currentSegments.size(); segmentIndex++) {
            final SimpleInterval segment = segmentIntervals.get(segmentIndex);
            modeledSegments.add(new ModeledSegment(
                    segment,
                    copyRatioMidpointOverlapDetector.getOverlaps(segment).size(),
                    allelicCountOverlapDetector.getOverlaps(segment).size(),
                    segmentMeansPosteriorSummaries.get(segmentIndex),
                    minorAlleleFractionsPosteriorSummaries.get(segmentIndex)));
        }
        isModelFit = true;
    }

    /**
     * Iteratively merges adjacent segments whose copy-ratio and allele-fraction posteriors are similar.
     * Iteration stops early once an iteration leaves the number of segments unchanged. The model is refit on every
     * iteration that is a multiple of {@code numSmoothingIterationsPerFit} (never inside the loop when it is zero).
     * The model is also refit at the end if the most recent iteration only merged posteriors without refitting.
     *
     * @param maxNumSmoothingIterations                        maximum number of iterations; must be non-negative
     * @param numSmoothingIterationsPerFit                     refit period in iterations; 0 disables periodic
     *                                                         refits; must be non-negative
     * @param smoothingCredibleIntervalThresholdCopyRatio      segments are similar in copy ratio when their medians
     *                                                         differ by less than this multiple of either segment's
     *                                                         decile-10-to-decile-90 width; must be non-negative
     * @param smoothingCredibleIntervalThresholdAlleleFraction analogous threshold for minor allele fraction;
     *                                                         must be non-negative
     */
    public void smoothSegments(final int maxNumSmoothingIterations,
                               final int numSmoothingIterationsPerFit,
                               final double smoothingCredibleIntervalThresholdCopyRatio,
                               final double smoothingCredibleIntervalThresholdAlleleFraction) {
        ParamUtils.isPositiveOrZero(maxNumSmoothingIterations,
                "The maximum number of smoothing iterations must be non-negative.");
        // previously the per-fit message was attached to the copy-ratio threshold check and the per-fit count itself
        // was never validated
        ParamUtils.isPositiveOrZero(numSmoothingIterationsPerFit,
                "The number of smoothing iterations per fit must be non-negative.");
        ParamUtils.isPositiveOrZero(smoothingCredibleIntervalThresholdCopyRatio,
                "The copy-ratio credible-interval threshold for segmentation smoothing must be non-negative.");
        ParamUtils.isPositiveOrZero(smoothingCredibleIntervalThresholdAlleleFraction,
                "The allele-fraction credible-interval threshold for segmentation smoothing must be non-negative.");
        logger.info(String.format("Initial number of segments before smoothing: %d", modeledSegments.size()));
        for (int iteration = 1; iteration <= maxNumSmoothingIterations; iteration++) {
            logger.info(String.format("Smoothing iteration: %d", iteration));
            final int previousNumSegments = modeledSegments.size();
            final boolean isRefitIteration =
                    numSmoothingIterationsPerFit > 0 && iteration % numSmoothingIterationsPerFit == 0;
            performSmoothingIteration(smoothingCredibleIntervalThresholdCopyRatio,
                    smoothingCredibleIntervalThresholdAlleleFraction, isRefitIteration);
            if (modeledSegments.size() == previousNumSegments) {
                break;  // converged: nothing merged this iteration
            }
        }
        if (!isModelFit) {
            fitModel();
        }
        logger.info(String.format("Final number of segments after smoothing: %d", modeledSegments.size()));
    }

    /**
     * Performs one merge pass over {@link #modeledSegments} and updates {@link #currentSegments} to the merged intervals.
     * If {@code isRefitIteration} is set, both models are refit on the merged intervals. Otherwise the analytically
     * merged segments are kept as-is and the model is marked as not fit.
     */
    private void performSmoothingIteration(final double credibleIntervalThresholdCopyRatio,
                                           final double credibleIntervalThresholdAlleleFraction,
                                           final boolean isRefitIteration) {
        logger.info("Number of segments before smoothing iteration: " + modeledSegments.size());
        // mergeSimilarSegments returns a new list, so clearing modeledSegments below does not affect it
        final List<ModeledSegment> mergedSegments = SimilarSegmentUtils.mergeSimilarSegments(
                modeledSegments, credibleIntervalThresholdCopyRatio, credibleIntervalThresholdAlleleFraction);
        logger.info("Number of segments after smoothing iteration: " + mergedSegments.size());
        currentSegments = new SimpleIntervalCollection(
                new SimpleLocatableMetadata(metadata.getSequenceDictionary()),
                mergedSegments.stream().map(ModeledSegment::getInterval).collect(Collectors.toList()));
        if (isRefitIteration) {
            fitModel();
        } else {
            modeledSegments.clear();
            modeledSegments.addAll(mergedSegments);
            isModelFit = false;
        }
    }

    /**
     * Writes the global-parameter posterior summaries of both models, refitting first if the model is not fit.
     *
     * @param copyRatioParameterFile      destination for the copy-ratio global-parameter deciles
     * @param alleleFractionParameterFile destination for the allele-fraction global-parameter deciles
     */
    public void writeModelParameterFiles(final File copyRatioParameterFile,
                                         final File alleleFractionParameterFile) {
        Utils.nonNull(copyRatioParameterFile);
        Utils.nonNull(alleleFractionParameterFile);
        ensureModelIsFit();
        logger.info(String.format("Writing posterior summaries for copy-ratio global parameters to %s...",
                copyRatioParameterFile.getAbsolutePath()));
        copyRatioModeller.getGlobalParameterDeciles().write(copyRatioParameterFile);
        logger.info(String.format("Writing posterior summaries for allele-fraction global parameters to %s...",
                alleleFractionParameterFile.getAbsolutePath()));
        alleleFractionModeller.getGlobalParameterDeciles().write(alleleFractionParameterFile);
    }

    @VisibleForTesting
    CopyRatioModeller getCopyRatioModeller() {
        return copyRatioModeller;
    }

    @VisibleForTesting
    AlleleFractionModeller getAlleleFractionModeller() {
        return alleleFractionModeller;
    }

    /**
     * Refits the model (with a warning) if the last smoothing iteration left it unfit.
     */
    private void ensureModelIsFit() {
        if (isModelFit) {
            return;
        }
        logger.warn("Attempted to write results to file when model was not completely fit. Performing model fit now.");
        fitModel();
    }
}
