Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.shuffle.sort;
0019 
0020 import java.nio.channels.Channels;
0021 import java.util.Optional;
0022 import javax.annotation.Nullable;
0023 import java.io.*;
0024 import java.nio.channels.FileChannel;
0025 import java.nio.channels.WritableByteChannel;
0026 import java.util.Iterator;
0027 
0028 import scala.Option;
0029 import scala.Product2;
0030 import scala.collection.JavaConverters;
0031 import scala.reflect.ClassTag;
0032 import scala.reflect.ClassTag$;
0033 
0034 import com.google.common.annotations.VisibleForTesting;
0035 import com.google.common.io.ByteStreams;
0036 import com.google.common.io.Closeables;
0037 import org.slf4j.Logger;
0038 import org.slf4j.LoggerFactory;
0039 
0040 import org.apache.spark.*;
0041 import org.apache.spark.annotation.Private;
0042 import org.apache.spark.internal.config.package$;
0043 import org.apache.spark.io.CompressionCodec;
0044 import org.apache.spark.io.CompressionCodec$;
0045 import org.apache.spark.io.NioBufferedFileInputStream;
0046 import org.apache.spark.memory.TaskMemoryManager;
0047 import org.apache.spark.network.util.LimitedInputStream;
0048 import org.apache.spark.scheduler.MapStatus;
0049 import org.apache.spark.scheduler.MapStatus$;
0050 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
0051 import org.apache.spark.serializer.SerializationStream;
0052 import org.apache.spark.serializer.SerializerInstance;
0053 import org.apache.spark.shuffle.ShuffleWriter;
0054 import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
0055 import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
0056 import org.apache.spark.shuffle.api.ShufflePartitionWriter;
0057 import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
0058 import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
0059 import org.apache.spark.storage.BlockManager;
0060 import org.apache.spark.storage.TimeTrackingOutputStream;
0061 import org.apache.spark.unsafe.Platform;
0062 import org.apache.spark.util.Utils;
0063 
0064 @Private
0065 public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
0066 
0067   private static final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
0068 
0069   private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
0070 
0071   @VisibleForTesting
0072   static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
0073 
0074   private final BlockManager blockManager;
0075   private final TaskMemoryManager memoryManager;
0076   private final SerializerInstance serializer;
0077   private final Partitioner partitioner;
0078   private final ShuffleWriteMetricsReporter writeMetrics;
0079   private final ShuffleExecutorComponents shuffleExecutorComponents;
0080   private final int shuffleId;
0081   private final long mapId;
0082   private final TaskContext taskContext;
0083   private final SparkConf sparkConf;
0084   private final boolean transferToEnabled;
0085   private final int initialSortBufferSize;
0086   private final int inputBufferSizeInBytes;
0087 
0088   @Nullable private MapStatus mapStatus;
0089   @Nullable private ShuffleExternalSorter sorter;
0090   private long peakMemoryUsedBytes = 0;
0091 
0092   /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
0093   private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
0094     MyByteArrayOutputStream(int size) { super(size); }
0095     public byte[] getBuf() { return buf; }
0096   }
0097 
0098   private MyByteArrayOutputStream serBuffer;
0099   private SerializationStream serOutputStream;
0100 
0101   /**
0102    * Are we in the process of stopping? Because map tasks can call stop() with success = true
0103    * and then call stop() with success = false if they get an exception, we want to make sure
0104    * we don't try deleting files, etc twice.
0105    */
0106   private boolean stopping = false;
0107 
0108   public UnsafeShuffleWriter(
0109       BlockManager blockManager,
0110       TaskMemoryManager memoryManager,
0111       SerializedShuffleHandle<K, V> handle,
0112       long mapId,
0113       TaskContext taskContext,
0114       SparkConf sparkConf,
0115       ShuffleWriteMetricsReporter writeMetrics,
0116       ShuffleExecutorComponents shuffleExecutorComponents) {
0117     final int numPartitions = handle.dependency().partitioner().numPartitions();
0118     if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
0119       throw new IllegalArgumentException(
0120         "UnsafeShuffleWriter can only be used for shuffles with at most " +
0121         SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() +
0122         " reduce partitions");
0123     }
0124     this.blockManager = blockManager;
0125     this.memoryManager = memoryManager;
0126     this.mapId = mapId;
0127     final ShuffleDependency<K, V, V> dep = handle.dependency();
0128     this.shuffleId = dep.shuffleId();
0129     this.serializer = dep.serializer().newInstance();
0130     this.partitioner = dep.partitioner();
0131     this.writeMetrics = writeMetrics;
0132     this.shuffleExecutorComponents = shuffleExecutorComponents;
0133     this.taskContext = taskContext;
0134     this.sparkConf = sparkConf;
0135     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
0136     this.initialSortBufferSize =
0137       (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE());
0138     this.inputBufferSizeInBytes =
0139       (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
0140     open();
0141   }
0142 
0143   private void updatePeakMemoryUsed() {
0144     // sorter can be null if this writer is closed
0145     if (sorter != null) {
0146       long mem = sorter.getPeakMemoryUsedBytes();
0147       if (mem > peakMemoryUsedBytes) {
0148         peakMemoryUsedBytes = mem;
0149       }
0150     }
0151   }
0152 
0153   /**
0154    * Return the peak memory used so far, in bytes.
0155    */
0156   public long getPeakMemoryUsedBytes() {
0157     updatePeakMemoryUsed();
0158     return peakMemoryUsedBytes;
0159   }
0160 
0161   /**
0162    * This convenience method should only be called in test code.
0163    */
0164   @VisibleForTesting
0165   public void write(Iterator<Product2<K, V>> records) throws IOException {
0166     write(JavaConverters.asScalaIteratorConverter(records).asScala());
0167   }
0168 
0169   @Override
0170   public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
0171     // Keep track of success so we know if we encountered an exception
0172     // We do this rather than a standard try/catch/re-throw to handle
0173     // generic throwables.
0174     boolean success = false;
0175     try {
0176       while (records.hasNext()) {
0177         insertRecordIntoSorter(records.next());
0178       }
0179       closeAndWriteOutput();
0180       success = true;
0181     } finally {
0182       if (sorter != null) {
0183         try {
0184           sorter.cleanupResources();
0185         } catch (Exception e) {
0186           // Only throw this error if we won't be masking another
0187           // error.
0188           if (success) {
0189             throw e;
0190           } else {
0191             logger.error("In addition to a failure during writing, we failed during " +
0192                          "cleanup.", e);
0193           }
0194         }
0195       }
0196     }
0197   }
0198 
0199   private void open() {
0200     assert (sorter == null);
0201     sorter = new ShuffleExternalSorter(
0202       memoryManager,
0203       blockManager,
0204       taskContext,
0205       initialSortBufferSize,
0206       partitioner.numPartitions(),
0207       sparkConf,
0208       writeMetrics);
0209     serBuffer = new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
0210     serOutputStream = serializer.serializeStream(serBuffer);
0211   }
0212 
0213   @VisibleForTesting
0214   void closeAndWriteOutput() throws IOException {
0215     assert(sorter != null);
0216     updatePeakMemoryUsed();
0217     serBuffer = null;
0218     serOutputStream = null;
0219     final SpillInfo[] spills = sorter.closeAndGetSpills();
0220     sorter = null;
0221     final long[] partitionLengths;
0222     try {
0223       partitionLengths = mergeSpills(spills);
0224     } finally {
0225       for (SpillInfo spill : spills) {
0226         if (spill.file.exists() && !spill.file.delete()) {
0227           logger.error("Error while deleting spill file {}", spill.file.getPath());
0228         }
0229       }
0230     }
0231     mapStatus = MapStatus$.MODULE$.apply(
0232       blockManager.shuffleServerId(), partitionLengths, mapId);
0233   }
0234 
0235   @VisibleForTesting
0236   void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
0237     assert(sorter != null);
0238     final K key = record._1();
0239     final int partitionId = partitioner.getPartition(key);
0240     serBuffer.reset();
0241     serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
0242     serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
0243     serOutputStream.flush();
0244 
0245     final int serializedRecordSize = serBuffer.size();
0246     assert (serializedRecordSize > 0);
0247 
0248     sorter.insertRecord(
0249       serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
0250   }
0251 
0252   @VisibleForTesting
0253   void forceSorterToSpill() throws IOException {
0254     assert (sorter != null);
0255     sorter.spill();
0256   }
0257 
0258   /**
0259    * Merge zero or more spill files together, choosing the fastest merging strategy based on the
0260    * number of spills and the IO compression codec.
0261    *
0262    * @return the partition lengths in the merged file.
0263    */
0264   private long[] mergeSpills(SpillInfo[] spills) throws IOException {
0265     long[] partitionLengths;
0266     if (spills.length == 0) {
0267       final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
0268           .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
0269       return mapWriter.commitAllPartitions();
0270     } else if (spills.length == 1) {
0271       Optional<SingleSpillShuffleMapOutputWriter> maybeSingleFileWriter =
0272           shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId);
0273       if (maybeSingleFileWriter.isPresent()) {
0274         // Here, we don't need to perform any metrics updates because the bytes written to this
0275         // output file would have already been counted as shuffle bytes written.
0276         partitionLengths = spills[0].partitionLengths;
0277         maybeSingleFileWriter.get().transferMapSpillFile(spills[0].file, partitionLengths);
0278       } else {
0279         partitionLengths = mergeSpillsUsingStandardWriter(spills);
0280       }
0281     } else {
0282       partitionLengths = mergeSpillsUsingStandardWriter(spills);
0283     }
0284     return partitionLengths;
0285   }
0286 
0287   private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOException {
0288     long[] partitionLengths;
0289     final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS());
0290     final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
0291     final boolean fastMergeEnabled =
0292         (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE());
0293     final boolean fastMergeIsSupported = !compressionEnabled ||
0294         CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
0295     final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
0296     final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
0297         .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
0298     try {
0299       // There are multiple spills to merge, so none of these spill files' lengths were counted
0300       // towards our shuffle write count or shuffle write time. If we use the slow merge path,
0301       // then the final output file's size won't necessarily be equal to the sum of the spill
0302       // files' sizes. To guard against this case, we look at the output file's actual size when
0303       // computing shuffle bytes written.
0304       //
0305       // We allow the individual merge methods to report their own IO times since different merge
0306       // strategies use different IO techniques.  We count IO during merge towards the shuffle
0307       // write time, which appears to be consistent with the "not bypassing merge-sort" branch in
0308       // ExternalSorter.
0309       if (fastMergeEnabled && fastMergeIsSupported) {
0310         // Compression is disabled or we are using an IO compression codec that supports
0311         // decompression of concatenated compressed streams, so we can perform a fast spill merge
0312         // that doesn't need to interpret the spilled bytes.
0313         if (transferToEnabled && !encryptionEnabled) {
0314           logger.debug("Using transferTo-based fast merge");
0315           mergeSpillsWithTransferTo(spills, mapWriter);
0316         } else {
0317           logger.debug("Using fileStream-based fast merge");
0318           mergeSpillsWithFileStream(spills, mapWriter, null);
0319         }
0320       } else {
0321         logger.debug("Using slow merge");
0322         mergeSpillsWithFileStream(spills, mapWriter, compressionCodec);
0323       }
0324       // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
0325       // in-memory records, we write out the in-memory records to a file but do not count that
0326       // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
0327       // to be counted as shuffle write, but this will lead to double-counting of the final
0328       // SpillInfo's bytes.
0329       writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
0330       partitionLengths = mapWriter.commitAllPartitions();
0331     } catch (Exception e) {
0332       try {
0333         mapWriter.abort(e);
0334       } catch (Exception e2) {
0335         logger.warn("Failed to abort writing the map output.", e2);
0336         e.addSuppressed(e2);
0337       }
0338       throw e;
0339     }
0340     return partitionLengths;
0341   }
0342 
0343   /**
0344    * Merges spill files using Java FileStreams. This code path is typically slower than
0345    * the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[],
0346    * ShuffleMapOutputWriter)}, and it's mostly used in cases where the IO compression codec
0347    * does not support concatenation of compressed data, when encryption is enabled, or when
0348    * users have explicitly disabled use of {@code transferTo} in order to work around kernel bugs.
0349    * This code path might also be faster in cases where individual partition size in a spill
0350    * is small and UnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small
0351    * disk ios which is inefficient. In those case, Using large buffers for input and output
0352    * files helps reducing the number of disk ios, making the file merging faster.
0353    *
0354    * @param spills the spills to merge.
0355    * @param mapWriter the map output writer to use for output.
0356    * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
0357    * @return the partition lengths in the merged file.
0358    */
0359   private void mergeSpillsWithFileStream(
0360       SpillInfo[] spills,
0361       ShuffleMapOutputWriter mapWriter,
0362       @Nullable CompressionCodec compressionCodec) throws IOException {
0363     final int numPartitions = partitioner.numPartitions();
0364     final InputStream[] spillInputStreams = new InputStream[spills.length];
0365 
0366     boolean threwException = true;
0367     try {
0368       for (int i = 0; i < spills.length; i++) {
0369         spillInputStreams[i] = new NioBufferedFileInputStream(
0370           spills[i].file,
0371           inputBufferSizeInBytes);
0372       }
0373       for (int partition = 0; partition < numPartitions; partition++) {
0374         boolean copyThrewException = true;
0375         ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
0376         OutputStream partitionOutput = writer.openStream();
0377         try {
0378           partitionOutput = new TimeTrackingOutputStream(writeMetrics, partitionOutput);
0379           partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
0380           if (compressionCodec != null) {
0381             partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
0382           }
0383           for (int i = 0; i < spills.length; i++) {
0384             final long partitionLengthInSpill = spills[i].partitionLengths[partition];
0385 
0386             if (partitionLengthInSpill > 0) {
0387               InputStream partitionInputStream = null;
0388               boolean copySpillThrewException = true;
0389               try {
0390                 partitionInputStream = new LimitedInputStream(spillInputStreams[i],
0391                     partitionLengthInSpill, false);
0392                 partitionInputStream = blockManager.serializerManager().wrapForEncryption(
0393                     partitionInputStream);
0394                 if (compressionCodec != null) {
0395                   partitionInputStream = compressionCodec.compressedInputStream(
0396                       partitionInputStream);
0397                 }
0398                 ByteStreams.copy(partitionInputStream, partitionOutput);
0399                 copySpillThrewException = false;
0400               } finally {
0401                 Closeables.close(partitionInputStream, copySpillThrewException);
0402               }
0403             }
0404           }
0405           copyThrewException = false;
0406         } finally {
0407           Closeables.close(partitionOutput, copyThrewException);
0408         }
0409         long numBytesWritten = writer.getNumBytesWritten();
0410         writeMetrics.incBytesWritten(numBytesWritten);
0411       }
0412       threwException = false;
0413     } finally {
0414       // To avoid masking exceptions that caused us to prematurely enter the finally block, only
0415       // throw exceptions during cleanup if threwException == false.
0416       for (InputStream stream : spillInputStreams) {
0417         Closeables.close(stream, threwException);
0418       }
0419     }
0420   }
0421 
0422   /**
0423    * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
0424    * This is only safe when the IO compression codec and serializer support concatenation of
0425    * serialized streams.
0426    *
0427    * @param spills the spills to merge.
0428    * @param mapWriter the map output writer to use for output.
0429    * @return the partition lengths in the merged file.
0430    */
0431   private void mergeSpillsWithTransferTo(
0432       SpillInfo[] spills,
0433       ShuffleMapOutputWriter mapWriter) throws IOException {
0434     final int numPartitions = partitioner.numPartitions();
0435     final FileChannel[] spillInputChannels = new FileChannel[spills.length];
0436     final long[] spillInputChannelPositions = new long[spills.length];
0437 
0438     boolean threwException = true;
0439     try {
0440       for (int i = 0; i < spills.length; i++) {
0441         spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
0442       }
0443       for (int partition = 0; partition < numPartitions; partition++) {
0444         boolean copyThrewException = true;
0445         ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
0446         WritableByteChannelWrapper resolvedChannel = writer.openChannelWrapper()
0447             .orElseGet(() -> new StreamFallbackChannelWrapper(openStreamUnchecked(writer)));
0448         try {
0449           for (int i = 0; i < spills.length; i++) {
0450             long partitionLengthInSpill = spills[i].partitionLengths[partition];
0451             final FileChannel spillInputChannel = spillInputChannels[i];
0452             final long writeStartTime = System.nanoTime();
0453             Utils.copyFileStreamNIO(
0454                 spillInputChannel,
0455                 resolvedChannel.channel(),
0456                 spillInputChannelPositions[i],
0457                 partitionLengthInSpill);
0458             copyThrewException = false;
0459             spillInputChannelPositions[i] += partitionLengthInSpill;
0460             writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
0461           }
0462         } finally {
0463           Closeables.close(resolvedChannel, copyThrewException);
0464         }
0465         long numBytes = writer.getNumBytesWritten();
0466         writeMetrics.incBytesWritten(numBytes);
0467       }
0468       threwException = false;
0469     } finally {
0470       // To avoid masking exceptions that caused us to prematurely enter the finally block, only
0471       // throw exceptions during cleanup if threwException == false.
0472       for (int i = 0; i < spills.length; i++) {
0473         assert(spillInputChannelPositions[i] == spills[i].file.length());
0474         Closeables.close(spillInputChannels[i], threwException);
0475       }
0476     }
0477   }
0478 
0479   @Override
0480   public Option<MapStatus> stop(boolean success) {
0481     try {
0482       taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes());
0483 
0484       if (stopping) {
0485         return Option.apply(null);
0486       } else {
0487         stopping = true;
0488         if (success) {
0489           if (mapStatus == null) {
0490             throw new IllegalStateException("Cannot call stop(true) without having called write()");
0491           }
0492           return Option.apply(mapStatus);
0493         } else {
0494           return Option.apply(null);
0495         }
0496       }
0497     } finally {
0498       if (sorter != null) {
0499         // If sorter is non-null, then this implies that we called stop() in response to an error,
0500         // so we need to clean up memory and spill files created by the sorter
0501         sorter.cleanupResources();
0502       }
0503     }
0504   }
0505 
0506   private static OutputStream openStreamUnchecked(ShufflePartitionWriter writer) {
0507     try {
0508       return writer.openStream();
0509     } catch (IOException e) {
0510       throw new RuntimeException(e);
0511     }
0512   }
0513 
0514   private static final class StreamFallbackChannelWrapper implements WritableByteChannelWrapper {
0515     private final WritableByteChannel channel;
0516 
0517     StreamFallbackChannelWrapper(OutputStream fallbackStream) {
0518       this.channel = Channels.newChannel(fallbackStream);
0519     }
0520 
0521     @Override
0522     public WritableByteChannel channel() {
0523       return channel;
0524     }
0525 
0526     @Override
0527     public void close() throws IOException {
0528       channel.close();
0529     }
0530   }
0531 }