package org.broadinstitute.hellbender.tools.spark.pathseq;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMReadGroupRecord;
import htsjdk.samtools.SAMSequenceRecord;
import java.io.IOException;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.programgroups.MetagenomicsProgramGroup;
import org.broadinstitute.hellbender.engine.GATKPath;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.spark.pathseq.loggers.PSScoreFileLogger;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadsWriteFormat;
import scala.Tuple2;

@CommandLineProgramProperties(summary = "Classify reads and estimate abundances of each taxon in the reference. This is the third and final step of the PathSeq pipeline.", oneLineSummary = "Step 3: Classifies pathogen-aligned reads and generates abundance scores", programGroup = MetagenomicsProgramGroup.class)
@DocumentedFeature
public class PathSeqScoreSpark extends GATKSparkTool {
    private static final long serialVersionUID = 1L;
    public static final String PAIRED_INPUT_LONG_NAME = "paired-input";
    public static final String UNPAIRED_INPUT_LONG_NAME = "unpaired-input";

    @Argument(doc = "Input queryname-sorted BAM containing only paired reads", fullName = PAIRED_INPUT_LONG_NAME, optional = true)
    public String pairedInput = null;

    @Argument(doc = "Input BAM containing only unpaired reads", fullName = UNPAIRED_INPUT_LONG_NAME, optional = true)
    public String unpairedInput = null;

    @Argument(doc = "Output BAM", fullName = "output", shortName = "O", optional = true)
    public String outputPath = null;

    @ArgumentCollection
    public PSScoreArgumentCollection scoreArgs = new PSScoreArgumentCollection();

    /** Running sum of the reducer recommendations for every input file that was actually loaded. */
    private int recommendedNumReducers = 0;

    /**
     * Loads the primary reads and the header of one input BAM.
     *
     * @param path        input location; may be {@code null} when the corresponding argument was not given
     * @param readsSource source used to read both the reads and the header
     * @return a (reads, header) pair; both elements are {@code null} when {@code path} is {@code null}
     *         or does not exist (the latter is logged as a warning and skipped)
     */
    private Tuple2<JavaRDD<GATKRead>, SAMFileHeader> readInputWithHeader(final String path,
                                                                          final ReadsSparkSource readsSource) {
        if (path == null) {
            return new Tuple2<>(null, null);
        }
        if (!BucketUtils.fileExists(path)) {
            logger.warn("Could not find file " + path + ". Skipping...");
            return new Tuple2<>(null, null);
        }
        // Only files that are actually loaded contribute to the output reducer count.
        recommendedNumReducers += PSUtils.pathseqGetRecommendedNumReducers(path, numReducers, getTargetPartitionSize());
        final GATKPath inputPath = new GATKPath(path);
        final JavaRDD<GATKRead> reads = PSUtils.primaryReads(
                readsSource.getParallelReads(inputPath, null, null, bamPartitionSplitSize, useNio));
        final SAMFileHeader header = readsSource.getHeader(inputPath, null);
        return new Tuple2<>(reads, header);
    }

    /**
     * Merges the unpaired header's sequences and read groups into the paired header.
     *
     * <p>If only one header is non-null it is returned unchanged. Otherwise the paired header is
     * MUTATED IN PLACE: every sequence record and read group of the unpaired header whose
     * name/ID is not already present is appended to it, and the paired header is returned.
     * All other header fields (including sort order) come from the paired header.
     *
     * @param pairedHeader   header of the paired-reads BAM, or {@code null}
     * @param unpairedHeader header of the unpaired-reads BAM, or {@code null}
     * @return the combined header
     * @throws UserException.BadInput if both headers are {@code null}
     */
    static SAMFileHeader joinBamHeaders(final SAMFileHeader pairedHeader, final SAMFileHeader unpairedHeader) {
        if (pairedHeader == null) {
            if (unpairedHeader == null) {
                throw new UserException.BadInput("No headers were loaded");
            }
            return unpairedHeader;
        }
        if (unpairedHeader == null || pairedHeader.equals(unpairedHeader)) {
            return pairedHeader;
        }
        for (final SAMSequenceRecord sequence : unpairedHeader.getSequenceDictionary().getSequences()) {
            if (pairedHeader.getSequenceDictionary().getSequence(sequence.getSequenceName()) == null) {
                pairedHeader.addSequence(sequence);
            }
        }
        for (final SAMReadGroupRecord readGroup : unpairedHeader.getReadGroups()) {
            if (pairedHeader.getReadGroup(readGroup.getReadGroupId()) == null) {
                pairedHeader.addReadGroup(readGroup);
            }
        }
        return pairedHeader;
    }

    /**
     * Loads the paired and/or unpaired inputs, scores the reads with {@link PSScorer}, optionally
     * writes score metrics, and optionally writes the scored reads to {@link #outputPath}.
     *
     * @throws UserException.BadInput if {@code --input} is used, if the paired BAM is not
     *         queryname-sorted, or if no input header could be loaded
     * @throws UserException.CouldNotCreateOutputFile if writing the output BAM fails
     */
    @Override
    protected void runTool(final JavaSparkContext ctx) {
        if (!readArguments.getReadPathSpecifiers().isEmpty()) {
            throw new UserException.BadInput("Please use --paired-input or --unpaired-input instead of --input");
        }
        final ReadsSparkSource readsSource = new ReadsSparkSource(ctx, readArguments.getReadValidationStringency());

        final Tuple2<JavaRDD<GATKRead>, SAMFileHeader> pairedData = readInputWithHeader(pairedInput, readsSource);
        final Tuple2<JavaRDD<GATKRead>, SAMFileHeader> unpairedData = readInputWithHeader(unpairedInput, readsSource);
        // Elements may be null when the corresponding input was not supplied or not found.
        final JavaRDD<GATKRead> pairedReads = pairedData._1();
        final SAMFileHeader pairedHeader = pairedData._2();
        final JavaRDD<GATKRead> unpairedReads = unpairedData._1();
        final SAMFileHeader unpairedHeader = unpairedData._2();

        if (pairedHeader != null && pairedHeader.getSortOrder() != SAMFileHeader.SortOrder.queryname) {
            throw new UserException.BadInput("Paired input BAM must be sorted by queryname");
        }

        final SAMFileHeader header = joinBamHeaders(pairedHeader, unpairedHeader);
        final JavaRDD<GATKRead> scoredReads = new PSScorer(scoreArgs).scoreReads(ctx, pairedReads, unpairedReads, header);

        if (scoreArgs.scoreMetricsFileUri != null) {
            // try-with-resources closes the logger exactly once, even if logging throws.
            try (final PSScoreFileLogger scoreLogger = new PSScoreFileLogger(getMetricsFile(), scoreArgs.scoreMetricsFileUri)) {
                scoreLogger.logReadCounts(scoredReads);
            }
        }

        if (outputPath != null) {
            try {
                ReadsSparkSink.writeReads(ctx, outputPath, null, scoredReads, header,
                        shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE,
                        recommendedNumReducers, shardedPartsDir, true, splittingIndexGranularity);
            } catch (final IOException e) {
                throw new UserException.CouldNotCreateOutputFile(outputPath, "writing failed", e);
            }
        }
    }
}
