/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle.sort;

import java.nio.channels.Channels;
import java.util.Optional;
import javax.annotation.Nullable;
import java.io.*;
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;
import java.util.Iterator;

import scala.Option;
import scala.Product2;
import scala.collection.JavaConverters;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.ByteStreams;
import com.google.common.io.Closeables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.*;
import org.apache.spark.annotation.Private;
import org.apache.spark.internal.config.package$;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.NioBufferedFileInputStream;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.Utils;

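/**
 * A shuffle map-task writer that works on serialized records: each record is serialized once
 * into a reusable buffer, copied into memory managed by {@link ShuffleExternalSorter} together
 * with its partition id, and the resulting spill files are merged into a single map output at
 * the end of {@link #write(scala.collection.Iterator)}.
 *
 * <p>Rough lifecycle (a sketch; in practice the writer is driven by the shuffle machinery
 * rather than called directly):
 * <pre>{@code
 *   writer.write(records);                         // serialize, sort by partition id, merge spills
 *   Option<MapStatus> status = writer.stop(true);  // report the merged partition lengths
 * }</pre>
 */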
@Private
public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {

  private static final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);

  private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();

  @VisibleForTesting
  static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;

  private final BlockManager blockManager;
  private final TaskMemoryManager memoryManager;
  private final SerializerInstance serializer;
  private final Partitioner partitioner;
  private final ShuffleWriteMetricsReporter writeMetrics;
  private final ShuffleExecutorComponents shuffleExecutorComponents;
  private final int shuffleId;
  private final long mapId;
  private final TaskContext taskContext;
  private final SparkConf sparkConf;
  private final boolean transferToEnabled;
  private final int initialSortBufferSize;
  private final int inputBufferSizeInBytes;

  @Nullable private MapStatus mapStatus;
  @Nullable private ShuffleExternalSorter sorter;
  private long peakMemoryUsedBytes = 0;

  /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
  private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
    MyByteArrayOutputStream(int size) { super(size); }
    public byte[] getBuf() { return buf; }
  }

  private MyByteArrayOutputStream serBuffer;
  private SerializationStream serOutputStream;

  /**
   * Are we in the process of stopping? Because map tasks can call stop() with success = true
   * and then call stop() with success = false if they get an exception, we want to make sure
   * we don't try deleting files, etc twice.
   */
  private boolean stopping = false;

  public UnsafeShuffleWriter(
      BlockManager blockManager,
      TaskMemoryManager memoryManager,
      SerializedShuffleHandle<K, V> handle,
      long mapId,
      TaskContext taskContext,
      SparkConf sparkConf,
      ShuffleWriteMetricsReporter writeMetrics,
      ShuffleExecutorComponents shuffleExecutorComponents) {
    final int numPartitions = handle.dependency().partitioner().numPartitions();
    if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
      throw new IllegalArgumentException(
        "UnsafeShuffleWriter can only be used for shuffles with at most " +
        SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() +
        " reduce partitions");
    }
    this.blockManager = blockManager;
    this.memoryManager = memoryManager;
    this.mapId = mapId;
    final ShuffleDependency<K, V, V> dep = handle.dependency();
    this.shuffleId = dep.shuffleId();
    this.serializer = dep.serializer().newInstance();
    this.partitioner = dep.partitioner();
    this.writeMetrics = writeMetrics;
    this.shuffleExecutorComponents = shuffleExecutorComponents;
    this.taskContext = taskContext;
    this.sparkConf = sparkConf;
    this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
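    // SHUFFLE_SORT_INIT_BUFFER_SIZE is an initial record capacity for the in-memory sorter,
    // while SHUFFLE_FILE_BUFFER_SIZE is configured in KiB, hence the conversion to bytes below.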
    this.initialSortBufferSize =
      (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE());
    this.inputBufferSizeInBytes =
      (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
    open();
  }

  private void updatePeakMemoryUsed() {
    // sorter can be null if this writer has already been closed
    if (sorter != null) {
      long mem = sorter.getPeakMemoryUsedBytes();
      if (mem > peakMemoryUsedBytes) {
        peakMemoryUsedBytes = mem;
      }
    }
  }

  /**
   * Return the peak memory used so far, in bytes.
   */
  public long getPeakMemoryUsedBytes() {
    updatePeakMemoryUsed();
    return peakMemoryUsedBytes;
  }

  /**
   * This convenience method should only be called in test code.
   */
  @VisibleForTesting
  public void write(Iterator<Product2<K, V>> records) throws IOException {
    write(JavaConverters.asScalaIteratorConverter(records).asScala());
  }

  @Override
  public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
    // Keep track of success so we know if we encountered an exception.
    // We do this rather than a standard try/catch/re-throw to handle
    // generic throwables.
    boolean success = false;
    try {
      while (records.hasNext()) {
        insertRecordIntoSorter(records.next());
      }
      closeAndWriteOutput();
      success = true;
    } finally {
      if (sorter != null) {
        try {
          sorter.cleanupResources();
        } catch (Exception e) {
          // Only throw this error if we won't be masking another
          // error.
          if (success) {
            throw e;
          } else {
            logger.error("In addition to a failure during writing, we failed during " +
              "cleanup.", e);
          }
        }
      }
    }
  }

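  // Creates the external sorter plus the reusable serialization buffer and stream used by
  // insertRecordIntoSorter(); called once from the constructor.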
  private void open() {
    assert (sorter == null);
    sorter = new ShuffleExternalSorter(
      memoryManager,
      blockManager,
      taskContext,
      initialSortBufferSize,
      partitioner.numPartitions(),
      sparkConf,
      writeMetrics);
    serBuffer = new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
    serOutputStream = serializer.serializeStream(serBuffer);
  }

  @VisibleForTesting
  void closeAndWriteOutput() throws IOException {
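    // Close the sorter, drop the (potentially large) serialization buffer, merge all spill
    // files into the final map output, and record the resulting MapStatus.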
    assert(sorter != null);
    updatePeakMemoryUsed();
    serBuffer = null;
    serOutputStream = null;
    final SpillInfo[] spills = sorter.closeAndGetSpills();
    sorter = null;
    final long[] partitionLengths;
    try {
      partitionLengths = mergeSpills(spills);
    } finally {
      for (SpillInfo spill : spills) {
        if (spill.file.exists() && !spill.file.delete()) {
          logger.error("Error while deleting spill file {}", spill.file.getPath());
        }
      }
    }
    mapStatus = MapStatus$.MODULE$.apply(
      blockManager.shuffleServerId(), partitionLengths, mapId);
  }

  @VisibleForTesting
  void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
    assert(sorter != null);
    final K key = record._1();
    final int partitionId = partitioner.getPartition(key);
    serBuffer.reset();
    serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
    serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
    serOutputStream.flush();

    final int serializedRecordSize = serBuffer.size();
    assert (serializedRecordSize > 0);

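    // Copy the serialized bytes out of the reusable buffer into memory managed by the sorter,
    // tagged with the record's partition id so records can later be grouped by partition.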
    sorter.insertRecord(
      serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
  }

  @VisibleForTesting
  void forceSorterToSpill() throws IOException {
    assert (sorter != null);
    sorter.spill();
  }

  /**
   * Merge zero or more spill files together, choosing the fastest merging strategy based on the
   * number of spills and the IO compression codec.
   *
   * @return the partition lengths in the merged file.
   */
  private long[] mergeSpills(SpillInfo[] spills) throws IOException {
    long[] partitionLengths;
    if (spills.length == 0) {
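      // No spill files were produced (no records were written), so commit an empty output for
      // every partition.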
      final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
        .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
      return mapWriter.commitAllPartitions();
    } else if (spills.length == 1) {
      Optional<SingleSpillShuffleMapOutputWriter> maybeSingleFileWriter =
        shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId);
      if (maybeSingleFileWriter.isPresent()) {
        // Here we don't need to perform any metrics updates, because the bytes written to this
        // output are already counted in the shuffle write metrics.
        partitionLengths = spills[0].partitionLengths;
        maybeSingleFileWriter.get().transferMapSpillFile(spills[0].file, partitionLengths);
      } else {
        partitionLengths = mergeSpillsUsingStandardWriter(spills);
      }
    } else {
      partitionLengths = mergeSpillsUsingStandardWriter(spills);
    }
    return partitionLengths;
  }

  private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOException {
    long[] partitionLengths;
    final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS());
    final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
    final boolean fastMergeEnabled =
      (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE());
    final boolean fastMergeIsSupported = !compressionEnabled ||
      CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
    final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
    final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
      .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
    try {
      // There are multiple spills to merge, so none of these spill files' lengths were counted
      // towards our shuffle write count or shuffle write time. If we use the slow merge path,
      // then the final output file's size won't necessarily be equal to the sum of the spill
      // files' sizes. To guard against this case, we look at the output file's actual size when
      // computing shuffle bytes written.
      //
      // We allow the individual merge methods to report their own IO times since different merge
      // strategies use different IO techniques. We count IO during merge towards the shuffle
      // write time, which appears to be consistent with the "not bypassing merge-sort" branch in
      // ExternalSorter.
      if (fastMergeEnabled && fastMergeIsSupported) {
        // Compression is disabled or we are using an IO compression codec that supports
        // decompression of concatenated compressed streams, so we can perform a fast spill merge
        // that doesn't need to interpret the spilled bytes.
        if (transferToEnabled && !encryptionEnabled) {
          logger.debug("Using transferTo-based fast merge");
          mergeSpillsWithTransferTo(spills, mapWriter);
        } else {
          logger.debug("Using fileStream-based fast merge");
          mergeSpillsWithFileStream(spills, mapWriter, null);
        }
      } else {
        logger.debug("Using slow merge");
        mergeSpillsWithFileStream(spills, mapWriter, compressionCodec);
      }
      // When closing a ShuffleExternalSorter that has already spilled once but also has
      // in-memory records, we write out the in-memory records to a file but do not count that
      // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
      // to be counted as shuffle write, but this would lead to double-counting of the final
      // SpillInfo's bytes, so we subtract its length here.
      writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
      partitionLengths = mapWriter.commitAllPartitions();
    } catch (Exception e) {
      try {
        mapWriter.abort(e);
      } catch (Exception e2) {
        logger.warn("Failed to abort writing the map output.", e2);
        e.addSuppressed(e2);
      }
      throw e;
    }
    return partitionLengths;
  }

  /**
   * Merges spill files using Java FileStreams. This code path is typically slower than the
   * NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[],
   * ShuffleMapOutputWriter)}, and it's mostly used in cases where the IO compression codec
   * does not support concatenation of compressed data, when encryption is enabled, or when
   * users have explicitly disabled use of {@code transferTo} in order to work around kernel bugs.
   *
   * @param spills the spills to merge.
   * @param mapWriter the map output writer to use for output.
   * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
   */
  private void mergeSpillsWithFileStream(
      SpillInfo[] spills,
      ShuffleMapOutputWriter mapWriter,
      @Nullable CompressionCodec compressionCodec) throws IOException {
    final int numPartitions = partitioner.numPartitions();
    final InputStream[] spillInputStreams = new InputStream[spills.length];

    boolean threwException = true;
    try {
      for (int i = 0; i < spills.length; i++) {
        spillInputStreams[i] = new NioBufferedFileInputStream(
          spills[i].file,
          inputBufferSizeInBytes);
      }
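      // For each reduce partition, append that partition's bytes from every spill file to the
      // partition writer's stream. When a compression codec is given, each spill's bytes are
      // decompressed on read and re-compressed on write; encryption is likewise unwrapped and
      // re-applied via the serializer manager.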
      for (int partition = 0; partition < numPartitions; partition++) {
        boolean copyThrewException = true;
        ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
        OutputStream partitionOutput = writer.openStream();
        try {
          partitionOutput = new TimeTrackingOutputStream(writeMetrics, partitionOutput);
          partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
          if (compressionCodec != null) {
            partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
          }
          for (int i = 0; i < spills.length; i++) {
            final long partitionLengthInSpill = spills[i].partitionLengths[partition];

            if (partitionLengthInSpill > 0) {
              InputStream partitionInputStream = null;
              boolean copySpillThrewException = true;
              try {
                partitionInputStream = new LimitedInputStream(spillInputStreams[i],
                  partitionLengthInSpill, false);
                partitionInputStream = blockManager.serializerManager().wrapForEncryption(
                  partitionInputStream);
                if (compressionCodec != null) {
                  partitionInputStream = compressionCodec.compressedInputStream(
                    partitionInputStream);
                }
                ByteStreams.copy(partitionInputStream, partitionOutput);
                copySpillThrewException = false;
              } finally {
                Closeables.close(partitionInputStream, copySpillThrewException);
              }
            }
          }
          copyThrewException = false;
        } finally {
          Closeables.close(partitionOutput, copyThrewException);
        }
        long numBytesWritten = writer.getNumBytesWritten();
        writeMetrics.incBytesWritten(numBytesWritten);
      }
      threwException = false;
    } finally {
      // To avoid masking exceptions that caused us to prematurely enter this finally block, only
      // throw exceptions during cleanup if threwException == false.
      for (InputStream stream : spillInputStreams) {
        Closeables.close(stream, threwException);
      }
    }
  }

  /**
   * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
   * This is only safe when the IO compression codec and serializer support concatenation of
   * serialized streams.
   *
   * @param spills the spills to merge.
   * @param mapWriter the map output writer to use for output.
   */
  private void mergeSpillsWithTransferTo(
      SpillInfo[] spills,
      ShuffleMapOutputWriter mapWriter) throws IOException {
    final int numPartitions = partitioner.numPartitions();
    final FileChannel[] spillInputChannels = new FileChannel[spills.length];
    final long[] spillInputChannelPositions = new long[spills.length];

    boolean threwException = true;
    try {
      for (int i = 0; i < spills.length; i++) {
        spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
      }
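      // For each reduce partition, copy that partition's byte range from every spill file to the
      // output channel with NIO (zero-copy transferTo where the channels support it), without
      // decompressing or otherwise interpreting the bytes.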
      for (int partition = 0; partition < numPartitions; partition++) {
        boolean copyThrewException = true;
        ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
        WritableByteChannelWrapper resolvedChannel = writer.openChannelWrapper()
          .orElseGet(() -> new StreamFallbackChannelWrapper(openStreamUnchecked(writer)));
        try {
          for (int i = 0; i < spills.length; i++) {
            long partitionLengthInSpill = spills[i].partitionLengths[partition];
            final FileChannel spillInputChannel = spillInputChannels[i];
            final long writeStartTime = System.nanoTime();
            Utils.copyFileStreamNIO(
              spillInputChannel,
              resolvedChannel.channel(),
              spillInputChannelPositions[i],
              partitionLengthInSpill);
            copyThrewException = false;
            spillInputChannelPositions[i] += partitionLengthInSpill;
            writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
          }
        } finally {
          Closeables.close(resolvedChannel, copyThrewException);
        }
        long numBytes = writer.getNumBytesWritten();
        writeMetrics.incBytesWritten(numBytes);
      }
      threwException = false;
    } finally {
      // To avoid masking exceptions that caused us to prematurely enter this finally block, only
      // throw exceptions during cleanup if threwException == false.
      for (int i = 0; i < spills.length; i++) {
        assert(spillInputChannelPositions[i] == spills[i].file.length());
        Closeables.close(spillInputChannels[i], threwException);
      }
    }
  }

  @Override
  public Option<MapStatus> stop(boolean success) {
    try {
      taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes());
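      // Record the peak execution memory used by this writer regardless of whether the task
      // succeeded.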

      if (stopping) {
        return Option.apply(null);
      } else {
        stopping = true;
        if (success) {
          if (mapStatus == null) {
            throw new IllegalStateException("Cannot call stop(true) without having called write()");
          }
          return Option.apply(mapStatus);
        } else {
          return Option.apply(null);
        }
      }
    } finally {
      if (sorter != null) {
        // If sorter is non-null, then this implies that we called stop() in response to an error,
        // so we need to clean up memory and spill files created by the sorter.
        sorter.cleanupResources();
      }
    }
  }

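  // Wraps the checked IOException from openStream() so the call can be used inside the lambda
  // passed to Optional.orElseGet() above.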
  private static OutputStream openStreamUnchecked(ShufflePartitionWriter writer) {
    try {
      return writer.openStream();
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

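  // Adapts a plain OutputStream to the WritableByteChannelWrapper interface, used as a fallback
  // when a partition writer does not expose a writable channel.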
  private static final class StreamFallbackChannelWrapper implements WritableByteChannelWrapper {
    private final WritableByteChannel channel;

    StreamFallbackChannelWrapper(OutputStream fallbackStream) {
      this.channel = Channels.newChannel(fallbackStream);
    }

    @Override
    public WritableByteChannel channel() {
      return channel;
    }

    @Override
    public void close() throws IOException {
      channel.close();
    }
  }
}