/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.shuffle.sort;

import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.channels.Channels;
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;
import java.util.Optional;
import javax.annotation.Nullable;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.annotation.Private;
import org.apache.spark.internal.config.package$;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.NioBufferedFileInputStream;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
import org.apache.spark.shuffle.sort.SerializedShuffleHandle;
import org.apache.spark.shuffle.sort.ShuffleExternalSorter;
import org.apache.spark.shuffle.sort.SortShuffleManager;
import org.apache.spark.shuffle.sort.SpillInfo;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.sparkproject.guava.annotations.VisibleForTesting;
import org.sparkproject.guava.io.ByteStreams;
import org.sparkproject.guava.io.Closeables;
import scala.Option;
import scala.Product2;
import scala.collection.Iterator;
import scala.collection.JavaConverters;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

@Private
/**
 * {@link ShuffleWriter} for the serialized sort-shuffle path.
 *
 * <p>Each record is serialized into a reusable in-memory buffer and inserted, tagged with its
 * partition id, into a {@link ShuffleExternalSorter}. Once every record has been inserted, the
 * sorter's spill files are merged into the final map output through the
 * {@link ShuffleExecutorComponents} API and a {@link MapStatus} is recorded for
 * {@link #stop(boolean)} to return.
 *
 * <p>No synchronization is performed here; NOTE(review): presumably one instance is driven by a
 * single task thread — confirm with callers.
 */
@Private
public class UnsafeShuffleWriter<K, V>
extends ShuffleWriter<K, V> {
    private static final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
    private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
    /** Initial capacity in bytes (0x100000 = 1 MiB) of the per-record serialization buffer. */
    @VisibleForTesting
    static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 0x100000;
    private final BlockManager blockManager;
    private final TaskMemoryManager memoryManager;
    private final SerializerInstance serializer;
    private final Partitioner partitioner;
    private final ShuffleWriteMetricsReporter writeMetrics;
    private final ShuffleExecutorComponents shuffleExecutorComponents;
    private final int shuffleId;
    private final long mapId;
    private final TaskContext taskContext;
    private final SparkConf sparkConf;
    private final boolean transferToEnabled;
    private final int initialSortBufferSize;
    private final int inputBufferSizeInBytes;
    /** Set by {@link #closeAndWriteOutput()}; null until the map output has been written. */
    @Nullable
    private MapStatus mapStatus;
    /** Active sorter; set by {@link #open()} and cleared once its spills have been handed off. */
    @Nullable
    private ShuffleExternalSorter sorter;
    private long peakMemoryUsedBytes = 0L;
    /** Reusable buffer holding the serialized bytes of the record currently being inserted. */
    private MyByteArrayOutputStream serBuffer;
    /** Serialization stream writing into {@link #serBuffer}. */
    private SerializationStream serOutputStream;
    /** Set by the first {@link #stop(boolean)} call; later calls return an empty Option. */
    private boolean stopping = false;

    /**
     * Creates a writer for one map task and immediately opens its sorter and serialization buffer.
     *
     * @throws IllegalArgumentException if the shuffle has more reduce partitions than
     *     {@link SortShuffleManager#MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()} allows
     */
    public UnsafeShuffleWriter(BlockManager blockManager, TaskMemoryManager memoryManager, SerializedShuffleHandle<K, V> handle, long mapId, TaskContext taskContext, SparkConf sparkConf, ShuffleWriteMetricsReporter writeMetrics, ShuffleExecutorComponents shuffleExecutorComponents) {
        int numPartitions = handle.dependency().partitioner().numPartitions();
        if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
            throw new IllegalArgumentException("UnsafeShuffleWriter can only be used for shuffles with at most " + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + " reduce partitions");
        }
        this.blockManager = blockManager;
        this.memoryManager = memoryManager;
        this.mapId = mapId;
        ShuffleDependency<K, V, V> dep = handle.dependency();
        this.shuffleId = dep.shuffleId();
        this.serializer = dep.serializer().newInstance();
        this.partitioner = dep.partitioner();
        this.writeMetrics = writeMetrics;
        this.shuffleExecutorComponents = shuffleExecutorComponents;
        this.taskContext = taskContext;
        this.sparkConf = sparkConf;
        this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
        this.initialSortBufferSize = (int)((Long)sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE())).longValue();
        // SHUFFLE_FILE_BUFFER_SIZE is multiplied by 1024 here, i.e. it is treated as KiB.
        this.inputBufferSizeInBytes = (int)((Long)sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE())).longValue() * 1024;
        this.open();
    }

    /** Raises {@link #peakMemoryUsedBytes} to the sorter's current peak, if a sorter is active. */
    private void updatePeakMemoryUsed() {
        if (this.sorter == null) {
            return;
        }
        long mem = this.sorter.getPeakMemoryUsedBytes();
        if (mem > this.peakMemoryUsedBytes) {
            this.peakMemoryUsedBytes = mem;
        }
    }

    /** Returns the highest sorter memory usage observed so far, refreshing it first. */
    public long getPeakMemoryUsedBytes() {
        this.updatePeakMemoryUsed();
        return this.peakMemoryUsedBytes;
    }

    /**
     * Java-iterator convenience overload of {@link #write(Iterator)}.
     *
     * <p>This does not override anything in {@link ShuffleWriter}; it only adapts the iterator.
     */
    @VisibleForTesting
    public void write(java.util.Iterator<Product2<K, V>> records) throws IOException {
        this.write(JavaConverters.asScalaIteratorConverter(records).asScala());
    }

    /**
     * Inserts every record into the sorter, then merges the spills and records the map status.
     *
     * <p>If the sorter is still set when this method exits, its resources are cleaned up. A
     * cleanup failure is rethrown only when writing succeeded. Otherwise it is logged, so it
     * does not mask the error that caused the write to fail.
     */
    @Override
    public void write(Iterator<Product2<K, V>> records) throws IOException {
        // Track success explicitly (instead of catch/rethrow) so arbitrary Throwables propagate
        // untouched while cleanup still runs.
        boolean success = false;
        try {
            while (records.hasNext()) {
                this.insertRecordIntoSorter(records.next());
            }
            this.closeAndWriteOutput();
            success = true;
        } finally {
            if (this.sorter != null) {
                try {
                    this.sorter.cleanupResources();
                } catch (Exception e) {
                    if (success) {
                        throw e;
                    }
                    logger.error("In addition to a failure during writing, we failed during cleanup.", (Throwable)e);
                }
            }
        }
    }

    /** Creates the sorter and the serialization buffer/stream; must only be called once. */
    private void open() {
        assert (this.sorter == null);
        this.sorter = new ShuffleExternalSorter(this.memoryManager, this.blockManager, this.taskContext, this.initialSortBufferSize, this.partitioner.numPartitions(), this.sparkConf, this.writeMetrics);
        this.serBuffer = new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
        this.serOutputStream = this.serializer.serializeStream(this.serBuffer);
    }

    /**
     * Closes the sorter, merges its spill files into the final map output and records the
     * resulting {@link MapStatus}.
     *
     * <p>Spill files are deleted whether or not the merge succeeds. A failed deletion is logged,
     * not thrown.
     */
    @VisibleForTesting
    void closeAndWriteOutput() throws IOException {
        assert (this.sorter != null);
        this.updatePeakMemoryUsed();
        this.serBuffer = null;
        this.serOutputStream = null;
        SpillInfo[] spills = this.sorter.closeAndGetSpills();
        // The spills are now owned by this method; write() / stop() must not clean them up again.
        this.sorter = null;
        long[] partitionLengths;
        try {
            partitionLengths = this.mergeSpills(spills);
        } finally {
            for (SpillInfo spill : spills) {
                if (spill.file.exists() && !spill.file.delete()) {
                    logger.error("Error while deleting spill file {}", spill.file.getPath());
                }
            }
        }
        this.mapStatus = MapStatus$.MODULE$.apply(this.blockManager.shuffleServerId(), partitionLengths, this.mapId);
    }

    /**
     * Serializes one key/value pair into {@link #serBuffer} and inserts the bytes into the sorter
     * under the record's partition id.
     */
    @VisibleForTesting
    void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
        assert (this.sorter != null);
        Object key = record._1();
        int partitionId = this.partitioner.getPartition(key);
        this.serBuffer.reset();
        this.serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
        this.serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
        this.serOutputStream.flush();
        int serializedRecordSize = this.serBuffer.size();
        assert (serializedRecordSize > 0);
        this.sorter.insertRecord(this.serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
    }

    /** Test hook: forces the active sorter to spill its in-memory records. */
    @VisibleForTesting
    void forceSorterToSpill() throws IOException {
        assert (this.sorter != null);
        this.sorter.spill();
    }

    /**
     * Produces the final map output from the given spills.
     *
     * @return the length in bytes of each partition in the committed output
     */
    private long[] mergeSpills(SpillInfo[] spills) throws IOException {
        if (spills.length == 0) {
            // Nothing was spilled: commit an output writer that received no partition data.
            ShuffleMapOutputWriter mapWriter = this.shuffleExecutorComponents.createMapOutputWriter(this.shuffleId, this.mapId, this.partitioner.numPartitions());
            return mapWriter.commitAllPartitions();
        }
        if (spills.length == 1) {
            Optional<SingleSpillShuffleMapOutputWriter> maybeSingleFileWriter = this.shuffleExecutorComponents.createSingleFileMapOutputWriter(this.shuffleId, this.mapId);
            if (maybeSingleFileWriter.isPresent()) {
                // Hand the lone spill file over directly, reusing its recorded partition lengths.
                long[] partitionLengths = spills[0].partitionLengths;
                maybeSingleFileWriter.get().transferMapSpillFile(spills[0].file, partitionLengths);
                return partitionLengths;
            }
        }
        return this.mergeSpillsUsingStandardWriter(spills);
    }

    /**
     * Merges the spills partition by partition through a {@link ShuffleMapOutputWriter}.
     *
     * <p>Uses the transferTo path when fast merge applies and encryption is off; otherwise uses
     * the stream path. On any failure the writer is aborted and the exception is rethrown, with
     * any abort failure attached as a suppressed exception.
     */
    private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOException {
        boolean compressionEnabled = (Boolean)this.sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS());
        CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(this.sparkConf);
        boolean fastMergeEnabled = (Boolean)this.sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE());
        // Raw byte concatenation is only valid if the codec can concatenate compressed streams.
        boolean fastMergeIsSupported = !compressionEnabled || CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
        boolean encryptionEnabled = this.blockManager.serializerManager().encryptionEnabled();
        ShuffleMapOutputWriter mapWriter = this.shuffleExecutorComponents.createMapOutputWriter(this.shuffleId, this.mapId, this.partitioner.numPartitions());
        try {
            if (fastMergeEnabled && fastMergeIsSupported) {
                if (this.transferToEnabled && !encryptionEnabled) {
                    logger.debug("Using transferTo-based fast merge");
                    this.mergeSpillsWithTransferTo(spills, mapWriter);
                } else {
                    logger.debug("Using fileStream-based fast merge");
                    this.mergeSpillsWithFileStream(spills, mapWriter, null);
                }
            } else {
                logger.debug("Using slow merge");
                this.mergeSpillsWithFileStream(spills, mapWriter, compressionCodec);
            }
            // NOTE(review): presumably the sorter already counted the last spill's bytes as
            // shuffle write, so they are subtracted to avoid double counting — confirm in
            // ShuffleExternalSorter.
            this.writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
            return mapWriter.commitAllPartitions();
        } catch (Exception e) {
            try {
                mapWriter.abort(e);
            } catch (Exception e2) {
                logger.warn("Failed to abort writing the map output.", (Throwable)e2);
                e.addSuppressed(e2);
            }
            throw e;
        }
    }

    /**
     * Merges spills by streaming each partition's bytes through Java streams.
     *
     * @param compressionCodec when non-null, each spill segment is decompressed and the output is
     *     recompressed with this codec; when null, bytes are copied without recompression
     */
    private void mergeSpillsWithFileStream(SpillInfo[] spills, ShuffleMapOutputWriter mapWriter, @Nullable CompressionCodec compressionCodec) throws IOException {
        int numPartitions = this.partitioner.numPartitions();
        InputStream[] spillInputStreams = new InputStream[spills.length];
        // Only let close() failures propagate when no earlier exception is in flight.
        boolean threwException = true;
        try {
            for (int i = 0; i < spills.length; ++i) {
                spillInputStreams[i] = new NioBufferedFileInputStream(spills[i].file, this.inputBufferSizeInBytes);
            }
            for (int partition = 0; partition < numPartitions; ++partition) {
                boolean copyThrewException = true;
                ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
                OutputStream partitionOutput = writer.openStream();
                try {
                    partitionOutput = new TimeTrackingOutputStream(this.writeMetrics, partitionOutput);
                    partitionOutput = this.blockManager.serializerManager().wrapForEncryption(partitionOutput);
                    if (compressionCodec != null) {
                        partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
                    }
                    for (int i = 0; i < spills.length; ++i) {
                        long partitionLengthInSpill = spills[i].partitionLengths[partition];
                        if (partitionLengthInSpill <= 0L) {
                            continue;
                        }
                        InputStream partitionInputStream = null;
                        boolean copySpillThrewException = true;
                        try {
                            // LimitedInputStream(..., false) does not close the shared spill stream.
                            partitionInputStream = new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false);
                            partitionInputStream = this.blockManager.serializerManager().wrapForEncryption(partitionInputStream);
                            if (compressionCodec != null) {
                                partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
                            }
                            ByteStreams.copy(partitionInputStream, partitionOutput);
                            copySpillThrewException = false;
                        } finally {
                            Closeables.close(partitionInputStream, copySpillThrewException);
                        }
                    }
                    copyThrewException = false;
                } finally {
                    Closeables.close(partitionOutput, copyThrewException);
                }
                this.writeMetrics.incBytesWritten(writer.getNumBytesWritten());
            }
            threwException = false;
        } finally {
            for (InputStream stream : spillInputStreams) {
                Closeables.close(stream, threwException);
            }
        }
    }

    /**
     * Merges spills by copying each partition's byte range from every spill file's
     * {@link FileChannel} into the partition's output channel.
     *
     * <p>If the partition writer offers no channel, its output stream is wrapped as a channel.
     */
    private void mergeSpillsWithTransferTo(SpillInfo[] spills, ShuffleMapOutputWriter mapWriter) throws IOException {
        int numPartitions = this.partitioner.numPartitions();
        FileChannel[] spillInputChannels = new FileChannel[spills.length];
        // Read offset into each spill file; advanced by each partition's length as it is copied.
        long[] spillInputChannelPositions = new long[spills.length];
        boolean threwException = true;
        try {
            for (int i = 0; i < spills.length; ++i) {
                spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
            }
            for (int partition = 0; partition < numPartitions; ++partition) {
                boolean copyThrewException = true;
                ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
                WritableByteChannelWrapper resolvedChannel = writer.openChannelWrapper()
                        .orElseGet(() -> new StreamFallbackChannelWrapper(UnsafeShuffleWriter.openStreamUnchecked(writer)));
                try {
                    for (int i = 0; i < spills.length; ++i) {
                        long partitionLengthInSpill = spills[i].partitionLengths[partition];
                        long writeStartTime = System.nanoTime();
                        Utils.copyFileStreamNIO(spillInputChannels[i], resolvedChannel.channel(), spillInputChannelPositions[i], partitionLengthInSpill);
                        spillInputChannelPositions[i] += partitionLengthInSpill;
                        this.writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
                    }
                    // Only after *every* spill was copied is this partition's copy a success.
                    copyThrewException = false;
                } finally {
                    Closeables.close(resolvedChannel, copyThrewException);
                }
                this.writeMetrics.incBytesWritten(writer.getNumBytesWritten());
            }
            threwException = false;
        } finally {
            for (FileChannel channel : spillInputChannels) {
                Closeables.close(channel, threwException);
            }
        }
        // Consistency check, reached only on success so it can never mask an earlier exception
        // or leave channels unclosed: every spill file must have been read to its end.
        for (int i = 0; i < spills.length; ++i) {
            assert (spillInputChannelPositions[i] == spills[i].file.length());
        }
    }

    /**
     * Finishes the writer and reports the map status.
     *
     * <p>Only the first call has an effect; later calls return an empty Option. If the sorter
     * was never handed off by {@link #closeAndWriteOutput()}, its resources are cleaned up.
     *
     * @throws IllegalStateException if {@code success} is true but no output was written
     */
    @Override
    public Option<MapStatus> stop(boolean success) {
        try {
            this.taskContext.taskMetrics().incPeakExecutionMemory(this.getPeakMemoryUsedBytes());
            if (this.stopping) {
                return Option.apply(null);
            }
            this.stopping = true;
            if (!success) {
                return Option.apply(null);
            }
            if (this.mapStatus == null) {
                throw new IllegalStateException("Cannot call stop(true) without having called write()");
            }
            return Option.apply(this.mapStatus);
        } finally {
            // A non-null sorter means closeAndWriteOutput() never took ownership of the spills.
            if (this.sorter != null) {
                this.sorter.cleanupResources();
            }
        }
    }

    /**
     * Opens the writer's stream, rethrowing {@link IOException} as an unchecked exception so it
     * can be used from a lambda.
     */
    private static OutputStream openStreamUnchecked(ShufflePartitionWriter writer) {
        try {
            return writer.openStream();
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Presents an {@link OutputStream} as a {@link WritableByteChannelWrapper}. */
    private static final class StreamFallbackChannelWrapper
    implements WritableByteChannelWrapper {
        private final WritableByteChannel channel;

        StreamFallbackChannelWrapper(OutputStream fallbackStream) {
            this.channel = Channels.newChannel(fallbackStream);
        }

        @Override
        public WritableByteChannel channel() {
            return this.channel;
        }

        @Override
        public void close() throws IOException {
            this.channel.close();
        }
    }

    /** {@link ByteArrayOutputStream} that exposes its backing array without copying. */
    private static final class MyByteArrayOutputStream
    extends ByteArrayOutputStream {
        MyByteArrayOutputStream(int size) {
            super(size);
        }

        /** Returns the live backing array; only the first {@link #size()} bytes are valid. */
        public byte[] getBuf() {
            return this.buf;
        }
    }
}

