/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
0018 package org.apache.spark.shuffle.sort;
0019
0020 import javax.annotation.Nullable;
0021 import java.io.File;
0022 import java.io.IOException;
0023 import java.util.LinkedList;
0024
0025 import scala.Tuple2;
0026
0027 import com.google.common.annotations.VisibleForTesting;
0028 import org.slf4j.Logger;
0029 import org.slf4j.LoggerFactory;
0030
0031 import org.apache.spark.SparkConf;
0032 import org.apache.spark.TaskContext;
0033 import org.apache.spark.executor.ShuffleWriteMetrics;
0034 import org.apache.spark.internal.config.package$;
0035 import org.apache.spark.memory.MemoryConsumer;
0036 import org.apache.spark.memory.SparkOutOfMemoryError;
0037 import org.apache.spark.memory.TaskMemoryManager;
0038 import org.apache.spark.memory.TooLargePageException;
0039 import org.apache.spark.serializer.DummySerializerInstance;
0040 import org.apache.spark.serializer.SerializerInstance;
0041 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
0042 import org.apache.spark.storage.BlockManager;
0043 import org.apache.spark.storage.DiskBlockObjectWriter;
0044 import org.apache.spark.storage.FileSegment;
0045 import org.apache.spark.storage.TempShuffleBlockId;
0046 import org.apache.spark.unsafe.Platform;
0047 import org.apache.spark.unsafe.UnsafeAlignedOffset;
0048 import org.apache.spark.unsafe.array.LongArray;
0049 import org.apache.spark.unsafe.memory.MemoryBlock;
0050 import org.apache.spark.util.Utils;
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067
/**
 * An external sorter that is specialized for sort-based shuffle.
 *
 * <p>Incoming record bytes are appended to memory pages allocated from the
 * {@link TaskMemoryManager}; for each record, a packed pointer (encoded page address plus
 * partition id) is inserted into a {@link ShuffleInMemorySorter}. When the memory manager forces
 * a spill, or when the number of buffered records crosses {@code numElementsForSpillThreshold},
 * the buffered records are sorted by partition id and written to a temp shuffle file by
 * {@link #writeSortedFile(boolean)}.
 *
 * <p>This sorter does not merge its spill files itself: callers retrieve the per-spill metadata
 * via {@link #closeAndGetSpills()} and presumably perform the merge downstream
 * (NOTE(review): confirm against the caller, e.g. UnsafeShuffleWriter).
 */
final class ShuffleExternalSorter extends MemoryConsumer {

  private static final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);

  /** Size in bytes (1 MiB) of the scratch buffer used to copy record bytes from pages to disk. */
  @VisibleForTesting
  static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;

  /** Number of shuffle output partitions; also the length of each spill's partitionLengths. */
  private final int numPartitions;
  private final TaskMemoryManager taskMemoryManager;
  private final BlockManager blockManager;
  private final TaskContext taskContext;
  /** Reporter for final-output write metrics; spills report through it selectively (see below). */
  private final ShuffleWriteMetricsReporter writeMetrics;

  /**
   * Force this sorter to spill when it holds this many records in memory, regardless of memory
   * pressure. Configured via SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.
   */
  private final int numElementsForSpillThreshold;

  /** The buffer size to use when writing spills using DiskBlockObjectWriter, in bytes. */
  private final int fileBufferSizeBytes;

  /** Chunk size, in bytes, used when copying record data out of pages to the disk writer. */
  private final int diskWriteBufferSize;

  /**
   * Memory pages that hold the bytes of inserted records. Pages are released in
   * {@link #freeMemory()} (on spill and on close) rather than being reused across spills.
   */
  private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<>();

  /** Metadata (file, block id, per-partition lengths) for every spill file written so far. */
  private final LinkedList<SpillInfo> spills = new LinkedList<>();

  /** Peak memory used by this sorter so far, in bytes; refreshed by updatePeakMemoryUsed(). */
  private long peakMemoryUsedBytes;

  // These variables are reset after spilling:
  /** In-memory sorter of packed record pointers; null once the sorter has been closed/freed. */
  @Nullable private ShuffleInMemorySorter inMemSorter;
  /** Page currently receiving new record bytes; null until the first allocation. */
  @Nullable private MemoryBlock currentPage = null;
  /** Absolute write position within currentPage; -1 until the first page is acquired. */
  private long pageCursor = -1;

  /**
   * @param memoryManager the task's memory manager, used to allocate pages and pointer arrays
   * @param blockManager used to create temp shuffle blocks and disk writers for spills
   * @param taskContext used to report spill metrics on the task
   * @param initialSize initial capacity of the in-memory pointer sorter
   * @param numPartitions number of shuffle output partitions
   * @param conf Spark configuration (buffer sizes, spill threshold, radix-sort toggle)
   * @param writeMetrics reporter for shuffle write metrics of the final output
   */
  ShuffleExternalSorter(
      TaskMemoryManager memoryManager,
      BlockManager blockManager,
      TaskContext taskContext,
      int initialSize,
      int numPartitions,
      SparkConf conf,
      ShuffleWriteMetricsReporter writeMetrics) {
    // Page size is capped so record pointers can still be packed into a PackedRecordPointer.
    super(memoryManager,
      (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()),
      memoryManager.getTungstenMemoryMode());
    this.taskMemoryManager = memoryManager;
    this.blockManager = blockManager;
    this.taskContext = taskContext;
    this.numPartitions = numPartitions;
    // Config value is in kibibytes; convert to bytes.
    this.fileBufferSizeBytes =
      (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
    this.numElementsForSpillThreshold =
      (int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD());
    this.writeMetrics = writeMetrics;
    this.inMemSorter = new ShuffleInMemorySorter(
      this, initialSize, (boolean) conf.get(package$.MODULE$.SHUFFLE_SORT_USE_RADIXSORT()));
    this.peakMemoryUsedBytes = getMemoryUsage();
    this.diskWriteBufferSize =
      (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE());
  }

  /**
   * Sorts the in-memory records and writes them to a single on-disk file, partition by
   * partition, recording each partition's byte length in a {@link SpillInfo}. This method does
   * NOT free the memory pages holding the records; the caller is responsible for that (see
   * {@link #spill(long, MemoryConsumer)} and {@link #closeAndGetSpills()}).
   *
   * @param isLastFile if true, this is the final output file and its bytes/records are counted
   *                   directly in {@code writeMetrics}; otherwise the bytes are counted as disk
   *                   spill on the task's metrics instead.
   */
  private void writeSortedFile(boolean isLastFile) {

    // This call performs the actual in-memory sort by partition id.
    final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
      inMemSorter.getSortedIterator();

    // Nothing buffered: avoid creating an empty spill file.
    if (!sortedRecords.hasNext()) {
      return;
    }

    final ShuffleWriteMetricsReporter writeMetricsToUse;

    if (isLastFile) {
      // We're writing the final non-spill file, so we _do_ want to count these bytes and
      // records towards shuffle write metrics directly.
      writeMetricsToUse = writeMetrics;
    } else {
      // This is a spill: capture metrics in a throwaway ShuffleWriteMetrics so the bytes can be
      // reported as disk spill (and records forwarded) after the write completes. The casts near
      // the end of this method rely on this concrete type.
      writeMetricsToUse = new ShuffleWriteMetrics();
    }

    // Scratch buffer used to move record bytes from memory pages into the disk writer in chunks
    // of at most diskWriteBufferSize bytes.
    final byte[] writeBuffer = new byte[diskWriteBufferSize];

    // Allocate a temp shuffle block + file to hold this sorted output.
    final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
      blockManager.diskBlockManager().createTempShuffleBlock();
    final File file = spilledFileInfo._2();
    final TempShuffleBlockId blockId = spilledFileInfo._1();
    final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);

    // DiskBlockObjectWriter requires a SerializerInstance, but we copy raw bytes out of pages
    // ourselves, so a no-op serializer suffices.
    final SerializerInstance ser = DummySerializerInstance.INSTANCE;

    int currentPartition = -1;
    final FileSegment committedSegment;
    try (DiskBlockObjectWriter writer =
        blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse)) {

      final int uaoSize = UnsafeAlignedOffset.getUaoSize();
      while (sortedRecords.hasNext()) {
        sortedRecords.loadNext();
        final int partition = sortedRecords.packedRecordPointer.getPartitionId();
        // Records arrive sorted by partition id, so the partition never decreases.
        assert (partition >= currentPartition);
        if (partition != currentPartition) {
          // Switched to a new partition: commit the previous partition's segment and record its
          // byte length (skipped for the sentinel value before the first record).
          if (currentPartition != -1) {
            final FileSegment fileSegment = writer.commitAndGet();
            spillInfo.partitionLengths[currentPartition] = fileSegment.length();
          }
          currentPartition = partition;
        }

        // Decode the packed pointer back into (page, offset); each record is stored as a
        // uaoSize-byte length prefix followed by the record bytes.
        final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
        final Object recordPage = taskMemoryManager.getPage(recordPointer);
        final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer);
        int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage);
        long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length
        while (dataRemaining > 0) {
          final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining);
          Platform.copyMemory(
            recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer);
          writer.write(writeBuffer, 0, toTransfer);
          recordReadPosition += toTransfer;
          dataRemaining -= toTransfer;
        }
        writer.recordWritten();
      }

      // Commit the last partition's segment; its length is recorded below.
      committedSegment = writer.commitAndGet();
    }
    // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
    // then the file might be empty. Note that it might be better to avoid calling
    // writeSortedFile() in that case.
    if (currentPartition != -1) {
      spillInfo.partitionLengths[currentPartition] = committedSegment.length();
      spills.add(spillInfo);
    }

    if (!isLastFile) {
      // Spill bookkeeping: records are still forwarded to the shuffle write metrics (they will
      // presumably end up in the final output eventually — NOTE(review): confirm), while the
      // bytes are reported as disk spill on the task's metrics. The casts are safe because
      // writeMetricsToUse was constructed above as a ShuffleWriteMetrics in this branch.
      writeMetrics.incRecordsWritten(
        ((ShuffleWriteMetrics)writeMetricsToUse).recordsWritten());
      taskContext.taskMetrics().incDiskBytesSpilled(
        ((ShuffleWriteMetrics)writeMetricsToUse).bytesWritten());
    }
  }

  /**
   * Sort and spill the current records in response to memory pressure.
   *
   * @return the number of bytes of page memory freed, or 0 when there is nothing to spill
   *         (the request came from another consumer, or no records are buffered)
   */
  @Override
  public long spill(long size, MemoryConsumer trigger) throws IOException {
    if (trigger != this || inMemSorter == null || inMemSorter.numRecords() == 0) {
      return 0L;
    }

    logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
      Thread.currentThread().getId(),
      Utils.bytesToString(getMemoryUsage()),
      spills.size(),
      spills.size() > 1 ? " times" : " time");

    writeSortedFile(false);
    final long spillSize = freeMemory();
    // NOTE(review): reset() runs only after freeMemory() — presumably so the record pages are
    // released before the sorter re-allocates its pointer array; confirm before reordering.
    inMemSorter.reset();
    taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
    return spillSize;
  }

  /** Current memory footprint: all allocated pages plus the in-memory sorter's pointer array. */
  private long getMemoryUsage() {
    long totalPageSize = 0;
    for (MemoryBlock page : allocatedPages) {
      totalPageSize += page.size();
    }
    return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
  }

  /** Refreshes {@link #peakMemoryUsedBytes} if current usage exceeds the recorded peak. */
  private void updatePeakMemoryUsed() {
    long mem = getMemoryUsage();
    if (mem > peakMemoryUsedBytes) {
      peakMemoryUsedBytes = mem;
    }
  }

  /** Returns the peak memory used so far, in bytes. */
  long getPeakMemoryUsedBytes() {
    updatePeakMemoryUsed();
    return peakMemoryUsedBytes;
  }

  /**
   * Frees all allocated record pages and resets the paging state (currentPage/pageCursor).
   * Does not touch the in-memory sorter's pointer array.
   *
   * @return the number of bytes freed
   */
  private long freeMemory() {
    updatePeakMemoryUsed();
    long memoryFreed = 0;
    for (MemoryBlock block : allocatedPages) {
      memoryFreed += block.size();
      freePage(block);
    }
    allocatedPages.clear();
    currentPage = null;
    pageCursor = 0;
    return memoryFreed;
  }

  /** Force all memory to be freed and all spill files to be deleted (best-effort on files). */
  public void cleanupResources() {
    freeMemory();
    if (inMemSorter != null) {
      inMemSorter.free();
      inMemSorter = null;
    }
    for (SpillInfo spill : spills) {
      // Deletion failures are logged, not thrown: cleanup should not mask the original error.
      if (spill.file.exists() && !spill.file.delete()) {
        logger.error("Unable to delete spill file {}", spill.file.getPath());
      }
    }
  }

  /**
   * Checks whether there is enough space to insert an additional record into the sorter, and
   * grows the pointer array if additional space is required. If no space can be obtained, the
   * in-memory data is spilled to disk instead.
   */
  private void growPointerArrayIfNecessary() throws IOException {
    assert(inMemSorter != null);
    if (!inMemSorter.hasSpaceForAnotherRecord()) {
      long used = inMemSorter.getMemoryUsage();
      LongArray array;
      try {
        // Request an array twice the current size (used bytes / 8 = current long count).
        array = allocateArray(used / 8 * 2);
      } catch (TooLargePageException e) {
        // The doubled array exceeds the maximum page size; spill to shed records instead.
        spill();
        return;
      } catch (SparkOutOfMemoryError e) {
        // Allocation may itself have triggered a spill of this consumer, which resets the
        // sorter; only rethrow if that still didn't create room for another record.
        if (!inMemSorter.hasSpaceForAnotherRecord()) {
          logger.error("Unable to grow the pointer array");
          throw e;
        }
        return;
      }
      // The allocation above may have spilled this consumer as a side effect, freeing enough
      // room in the existing array — in that case the new one is unneeded.
      if (inMemSorter.hasSpaceForAnotherRecord()) {
        freeArray(array);
      } else {
        inMemSorter.expandPointerArray(array);
      }
    }
  }

  /**
   * Allocates more memory in order to insert an additional record. This will request additional
   * memory from the memory manager and spill if the requested memory can not be obtained.
   *
   * @param required the required space in the data page, in bytes, including space for storing
   *                 the record size
   */
  private void acquireNewPageIfNecessary(int required) {
    // pageCursor is an absolute address, so the page end is baseOffset + size.
    if (currentPage == null ||
      pageCursor + required > currentPage.getBaseOffset() + currentPage.size() ) {
      currentPage = allocatePage(required);
      pageCursor = currentPage.getBaseOffset();
      allocatedPages.add(currentPage);
    }
  }

  /**
   * Write a record to the shuffle sorter: copies the record bytes (prefixed by their length)
   * into a data page and registers the packed (address, partitionId) pointer with the in-memory
   * sorter.
   *
   * @param recordBase   base object of the source record (on-heap array or null for off-heap)
   * @param recordOffset offset of the record within recordBase
   * @param length       record length in bytes
   * @param partitionId  shuffle partition this record belongs to
   */
  public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId)
    throws IOException {

    assert(inMemSorter != null);
    // Hard cap on buffered records: force a spill regardless of available memory.
    if (inMemSorter.numRecords() >= numElementsForSpillThreshold) {
      logger.info("Spilling data because number of spilledRecords crossed the threshold " +
        numElementsForSpillThreshold);
      spill();
    }

    growPointerArrayIfNecessary();
    final int uaoSize = UnsafeAlignedOffset.getUaoSize();
    // Need space for the record plus its uaoSize-byte length prefix.
    final int required = length + uaoSize;
    acquireNewPageIfNecessary(required);

    assert(currentPage != null);
    final Object base = currentPage.getBaseObject();
    // Encode the pointer before advancing pageCursor so it addresses the length prefix.
    final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
    UnsafeAlignedOffset.putSize(base, pageCursor, length);
    pageCursor += uaoSize;
    Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
    pageCursor += length;
    inMemSorter.insertRecord(recordAddress, partitionId);
  }

  /**
   * Close the sorter, causing any buffered data to be sorted and written out to disk.
   *
   * @return metadata for all of the spill files written by this sorter, including the final one.
   *         If no records were ever inserted, then this will return an empty array.
   */
  public SpillInfo[] closeAndGetSpills() throws IOException {
    if (inMemSorter != null) {
      // Count the final file's bytes/records towards shuffle write metrics (isLastFile = true).
      writeSortedFile(true);
      freeMemory();
      inMemSorter.free();
      inMemSorter = null;
    }
    return spills.toArray(new SpillInfo[spills.size()]);
  }

}