0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.util.collection.unsafe.sort;
0019
0020 import javax.annotation.Nullable;
0021 import java.io.File;
0022 import java.io.IOException;
0023 import java.util.LinkedList;
0024 import java.util.Queue;
0025 import java.util.function.Supplier;
0026
0027 import com.google.common.annotations.VisibleForTesting;
0028 import org.apache.spark.memory.SparkOutOfMemoryError;
0029 import org.slf4j.Logger;
0030 import org.slf4j.LoggerFactory;
0031
0032 import org.apache.spark.TaskContext;
0033 import org.apache.spark.executor.ShuffleWriteMetrics;
0034 import org.apache.spark.memory.MemoryConsumer;
0035 import org.apache.spark.memory.TaskMemoryManager;
0036 import org.apache.spark.memory.TooLargePageException;
0037 import org.apache.spark.serializer.SerializerManager;
0038 import org.apache.spark.storage.BlockManager;
0039 import org.apache.spark.unsafe.Platform;
0040 import org.apache.spark.unsafe.UnsafeAlignedOffset;
0041 import org.apache.spark.unsafe.array.LongArray;
0042 import org.apache.spark.unsafe.memory.MemoryBlock;
0043 import org.apache.spark.util.Utils;
0044
0045
0046
0047
0048 public final class UnsafeExternalSorter extends MemoryConsumer {
0049
0050 private static final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
0051
0052 @Nullable
0053 private final PrefixComparator prefixComparator;
0054
0055
0056
0057
0058
0059
0060
0061 @Nullable
0062 private final Supplier<RecordComparator> recordComparatorSupplier;
0063
0064 private final TaskMemoryManager taskMemoryManager;
0065 private final BlockManager blockManager;
0066 private final SerializerManager serializerManager;
0067 private final TaskContext taskContext;
0068
0069
0070 private final int fileBufferSizeBytes;
0071
0072
0073
0074
0075 private final int numElementsForSpillThreshold;
0076
0077
0078
0079
0080
0081
0082
0083 private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<>();
0084
0085 private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
0086
0087
0088 @Nullable private volatile UnsafeInMemorySorter inMemSorter;
0089
0090 private MemoryBlock currentPage = null;
0091 private long pageCursor = -1;
0092 private long peakMemoryUsedBytes = 0;
0093 private long totalSpillBytes = 0L;
0094 private long totalSortTimeNanos = 0L;
0095 private volatile SpillableIterator readingIterator = null;
0096
0097 public static UnsafeExternalSorter createWithExistingInMemorySorter(
0098 TaskMemoryManager taskMemoryManager,
0099 BlockManager blockManager,
0100 SerializerManager serializerManager,
0101 TaskContext taskContext,
0102 Supplier<RecordComparator> recordComparatorSupplier,
0103 PrefixComparator prefixComparator,
0104 int initialSize,
0105 long pageSizeBytes,
0106 int numElementsForSpillThreshold,
0107 UnsafeInMemorySorter inMemorySorter) throws IOException {
0108 UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
0109 serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize,
0110 pageSizeBytes, numElementsForSpillThreshold, inMemorySorter, false );
0111 sorter.spill(Long.MAX_VALUE, sorter);
0112
0113 sorter.inMemSorter = null;
0114 return sorter;
0115 }
0116
0117 public static UnsafeExternalSorter create(
0118 TaskMemoryManager taskMemoryManager,
0119 BlockManager blockManager,
0120 SerializerManager serializerManager,
0121 TaskContext taskContext,
0122 Supplier<RecordComparator> recordComparatorSupplier,
0123 PrefixComparator prefixComparator,
0124 int initialSize,
0125 long pageSizeBytes,
0126 int numElementsForSpillThreshold,
0127 boolean canUseRadixSort) {
0128 return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
0129 taskContext, recordComparatorSupplier, prefixComparator, initialSize, pageSizeBytes,
0130 numElementsForSpillThreshold, null, canUseRadixSort);
0131 }
0132
0133 private UnsafeExternalSorter(
0134 TaskMemoryManager taskMemoryManager,
0135 BlockManager blockManager,
0136 SerializerManager serializerManager,
0137 TaskContext taskContext,
0138 Supplier<RecordComparator> recordComparatorSupplier,
0139 PrefixComparator prefixComparator,
0140 int initialSize,
0141 long pageSizeBytes,
0142 int numElementsForSpillThreshold,
0143 @Nullable UnsafeInMemorySorter existingInMemorySorter,
0144 boolean canUseRadixSort) {
0145 super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode());
0146 this.taskMemoryManager = taskMemoryManager;
0147 this.blockManager = blockManager;
0148 this.serializerManager = serializerManager;
0149 this.taskContext = taskContext;
0150 this.recordComparatorSupplier = recordComparatorSupplier;
0151 this.prefixComparator = prefixComparator;
0152
0153
0154 this.fileBufferSizeBytes = 32 * 1024;
0155
0156 if (existingInMemorySorter == null) {
0157 RecordComparator comparator = null;
0158 if (recordComparatorSupplier != null) {
0159 comparator = recordComparatorSupplier.get();
0160 }
0161 this.inMemSorter = new UnsafeInMemorySorter(
0162 this,
0163 taskMemoryManager,
0164 comparator,
0165 prefixComparator,
0166 initialSize,
0167 canUseRadixSort);
0168 } else {
0169 this.inMemSorter = existingInMemorySorter;
0170 }
0171 this.peakMemoryUsedBytes = getMemoryUsage();
0172 this.numElementsForSpillThreshold = numElementsForSpillThreshold;
0173
0174
0175
0176
0177 taskContext.addTaskCompletionListener(context -> {
0178 cleanupResources();
0179 });
0180 }
0181
0182
0183
0184
0185
0186 @VisibleForTesting
0187 public void closeCurrentPage() {
0188 if (currentPage != null) {
0189 pageCursor = currentPage.getBaseOffset() + currentPage.size();
0190 }
0191 }
0192
0193
0194
0195
0196 @Override
0197 public long spill(long size, MemoryConsumer trigger) throws IOException {
0198 if (trigger != this) {
0199 if (readingIterator != null) {
0200 return readingIterator.spill();
0201 }
0202 return 0L;
0203 }
0204
0205 if (inMemSorter == null || inMemSorter.numRecords() <= 0) {
0206 return 0L;
0207 }
0208
0209 logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
0210 Thread.currentThread().getId(),
0211 Utils.bytesToString(getMemoryUsage()),
0212 spillWriters.size(),
0213 spillWriters.size() > 1 ? " times" : " time");
0214
0215 ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
0216
0217 final UnsafeSorterSpillWriter spillWriter =
0218 new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
0219 inMemSorter.numRecords());
0220 spillWriters.add(spillWriter);
0221 spillIterator(inMemSorter.getSortedIterator(), spillWriter);
0222
0223 final long spillSize = freeMemory();
0224
0225
0226
0227 inMemSorter.reset();
0228
0229
0230
0231
0232 taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
0233 taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten());
0234 totalSpillBytes += spillSize;
0235 return spillSize;
0236 }
0237
0238
0239
0240
0241
0242 private long getMemoryUsage() {
0243 long totalPageSize = 0;
0244 for (MemoryBlock page : allocatedPages) {
0245 totalPageSize += page.size();
0246 }
0247 return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
0248 }
0249
0250 private void updatePeakMemoryUsed() {
0251 long mem = getMemoryUsage();
0252 if (mem > peakMemoryUsedBytes) {
0253 peakMemoryUsedBytes = mem;
0254 }
0255 }
0256
0257
0258
0259
0260 public long getPeakMemoryUsedBytes() {
0261 updatePeakMemoryUsed();
0262 return peakMemoryUsedBytes;
0263 }
0264
0265
0266
0267
0268 public long getSortTimeNanos() {
0269 UnsafeInMemorySorter sorter = inMemSorter;
0270 if (sorter != null) {
0271 return sorter.getSortTimeNanos();
0272 }
0273 return totalSortTimeNanos;
0274 }
0275
0276
0277
0278
0279 public long getSpillSize() {
0280 return totalSpillBytes;
0281 }
0282
0283 @VisibleForTesting
0284 public int getNumberOfAllocatedPages() {
0285 return allocatedPages.size();
0286 }
0287
0288
0289
0290
0291
0292
0293 private long freeMemory() {
0294 updatePeakMemoryUsed();
0295 long memoryFreed = 0;
0296 for (MemoryBlock block : allocatedPages) {
0297 memoryFreed += block.size();
0298 freePage(block);
0299 }
0300 allocatedPages.clear();
0301 currentPage = null;
0302 pageCursor = 0;
0303 return memoryFreed;
0304 }
0305
0306
0307
0308
0309 private void deleteSpillFiles() {
0310 for (UnsafeSorterSpillWriter spill : spillWriters) {
0311 File file = spill.getFile();
0312 if (file != null && file.exists()) {
0313 if (!file.delete()) {
0314 logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
0315 }
0316 }
0317 }
0318 }
0319
0320
0321
0322
0323 public void cleanupResources() {
0324 synchronized (this) {
0325 deleteSpillFiles();
0326 freeMemory();
0327 if (inMemSorter != null) {
0328 inMemSorter.free();
0329 inMemSorter = null;
0330 }
0331 }
0332 }
0333
0334
0335
0336
0337
0338
0339 private void growPointerArrayIfNecessary() throws IOException {
0340 assert(inMemSorter != null);
0341 if (!inMemSorter.hasSpaceForAnotherRecord()) {
0342 long used = inMemSorter.getMemoryUsage();
0343 LongArray array;
0344 try {
0345
0346 array = allocateArray(used / 8 * 2);
0347 } catch (TooLargePageException e) {
0348
0349 spill();
0350 return;
0351 } catch (SparkOutOfMemoryError e) {
0352
0353 if (!inMemSorter.hasSpaceForAnotherRecord()) {
0354 logger.error("Unable to grow the pointer array");
0355 throw e;
0356 }
0357 return;
0358 }
0359
0360 if (inMemSorter.hasSpaceForAnotherRecord()) {
0361 freeArray(array);
0362 } else {
0363 inMemSorter.expandPointerArray(array);
0364 }
0365 }
0366 }
0367
0368
0369
0370
0371
0372
0373
0374
0375
0376
0377 private void acquireNewPageIfNecessary(int required) {
0378 if (currentPage == null ||
0379 pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) {
0380
0381 currentPage = allocatePage(required);
0382 pageCursor = currentPage.getBaseOffset();
0383 allocatedPages.add(currentPage);
0384 }
0385 }
0386
0387
0388
0389
0390 public void insertRecord(
0391 Object recordBase, long recordOffset, int length, long prefix, boolean prefixIsNull)
0392 throws IOException {
0393
0394 assert(inMemSorter != null);
0395 if (inMemSorter.numRecords() >= numElementsForSpillThreshold) {
0396 logger.info("Spilling data because number of spilledRecords crossed the threshold " +
0397 numElementsForSpillThreshold);
0398 spill();
0399 }
0400
0401 growPointerArrayIfNecessary();
0402 int uaoSize = UnsafeAlignedOffset.getUaoSize();
0403
0404 final int required = length + uaoSize;
0405 acquireNewPageIfNecessary(required);
0406
0407 final Object base = currentPage.getBaseObject();
0408 final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
0409 UnsafeAlignedOffset.putSize(base, pageCursor, length);
0410 pageCursor += uaoSize;
0411 Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
0412 pageCursor += length;
0413 inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
0414 }
0415
0416
0417
0418
0419
0420
0421
0422
0423
0424 public void insertKVRecord(Object keyBase, long keyOffset, int keyLen,
0425 Object valueBase, long valueOffset, int valueLen, long prefix, boolean prefixIsNull)
0426 throws IOException {
0427
0428 growPointerArrayIfNecessary();
0429 int uaoSize = UnsafeAlignedOffset.getUaoSize();
0430 final int required = keyLen + valueLen + (2 * uaoSize);
0431 acquireNewPageIfNecessary(required);
0432
0433 final Object base = currentPage.getBaseObject();
0434 final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
0435 UnsafeAlignedOffset.putSize(base, pageCursor, keyLen + valueLen + uaoSize);
0436 pageCursor += uaoSize;
0437 UnsafeAlignedOffset.putSize(base, pageCursor, keyLen);
0438 pageCursor += uaoSize;
0439 Platform.copyMemory(keyBase, keyOffset, base, pageCursor, keyLen);
0440 pageCursor += keyLen;
0441 Platform.copyMemory(valueBase, valueOffset, base, pageCursor, valueLen);
0442 pageCursor += valueLen;
0443
0444 assert(inMemSorter != null);
0445 inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
0446 }
0447
0448
0449
0450
0451 public void merge(UnsafeExternalSorter other) throws IOException {
0452 other.spill();
0453 spillWriters.addAll(other.spillWriters);
0454
0455 other.spillWriters.clear();
0456 other.cleanupResources();
0457 }
0458
0459
0460
0461
0462
0463 public UnsafeSorterIterator getSortedIterator() throws IOException {
0464 assert(recordComparatorSupplier != null);
0465 if (spillWriters.isEmpty()) {
0466 assert(inMemSorter != null);
0467 readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
0468 return readingIterator;
0469 } else {
0470 final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger(
0471 recordComparatorSupplier.get(), prefixComparator, spillWriters.size());
0472 for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
0473 spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager));
0474 }
0475 if (inMemSorter != null) {
0476 readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
0477 spillMerger.addSpillIfNotEmpty(readingIterator);
0478 }
0479 return spillMerger.getSortedIterator();
0480 }
0481 }
0482
0483 @VisibleForTesting boolean hasSpaceForAnotherRecord() {
0484 return inMemSorter.hasSpaceForAnotherRecord();
0485 }
0486
0487 private static void spillIterator(UnsafeSorterIterator inMemIterator,
0488 UnsafeSorterSpillWriter spillWriter) throws IOException {
0489 while (inMemIterator.hasNext()) {
0490 inMemIterator.loadNext();
0491 final Object baseObject = inMemIterator.getBaseObject();
0492 final long baseOffset = inMemIterator.getBaseOffset();
0493 final int recordLength = inMemIterator.getRecordLength();
0494 spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix());
0495 }
0496 spillWriter.close();
0497 }
0498
0499
0500
0501
0502 class SpillableIterator extends UnsafeSorterIterator {
0503 private UnsafeSorterIterator upstream;
0504 private UnsafeSorterIterator nextUpstream = null;
0505 private MemoryBlock lastPage = null;
0506 private boolean loaded = false;
0507 private int numRecords = 0;
0508
0509 SpillableIterator(UnsafeSorterIterator inMemIterator) {
0510 this.upstream = inMemIterator;
0511 this.numRecords = inMemIterator.getNumRecords();
0512 }
0513
0514 @Override
0515 public int getNumRecords() {
0516 return numRecords;
0517 }
0518
0519 public long spill() throws IOException {
0520 synchronized (this) {
0521 if (!(upstream instanceof UnsafeInMemorySorter.SortedIterator && nextUpstream == null
0522 && numRecords > 0)) {
0523 return 0L;
0524 }
0525
0526 UnsafeInMemorySorter.SortedIterator inMemIterator =
0527 ((UnsafeInMemorySorter.SortedIterator) upstream).clone();
0528
0529 ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
0530
0531 final UnsafeSorterSpillWriter spillWriter =
0532 new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords);
0533 spillIterator(inMemIterator, spillWriter);
0534 spillWriters.add(spillWriter);
0535 nextUpstream = spillWriter.getReader(serializerManager);
0536
0537 long released = 0L;
0538 synchronized (UnsafeExternalSorter.this) {
0539
0540
0541
0542 for (MemoryBlock page : allocatedPages) {
0543 if (!loaded || page.pageNumber !=
0544 ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) {
0545 released += page.size();
0546 freePage(page);
0547 } else {
0548 lastPage = page;
0549 }
0550 }
0551 allocatedPages.clear();
0552 }
0553
0554
0555 assert(inMemSorter != null);
0556 released += inMemSorter.getMemoryUsage();
0557 totalSortTimeNanos += inMemSorter.getSortTimeNanos();
0558 inMemSorter.free();
0559 inMemSorter = null;
0560 taskContext.taskMetrics().incMemoryBytesSpilled(released);
0561 taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten());
0562 totalSpillBytes += released;
0563 return released;
0564 }
0565 }
0566
0567 @Override
0568 public boolean hasNext() {
0569 return numRecords > 0;
0570 }
0571
0572 @Override
0573 public void loadNext() throws IOException {
0574 MemoryBlock pageToFree = null;
0575 try {
0576 synchronized (this) {
0577 loaded = true;
0578 if (nextUpstream != null) {
0579
0580 if(lastPage != null) {
0581
0582
0583
0584
0585
0586 pageToFree = lastPage;
0587 lastPage = null;
0588 }
0589 upstream = nextUpstream;
0590 nextUpstream = null;
0591 }
0592 numRecords--;
0593 upstream.loadNext();
0594 }
0595 } finally {
0596 if (pageToFree != null) {
0597 freePage(pageToFree);
0598 }
0599 }
0600 }
0601
0602 @Override
0603 public Object getBaseObject() {
0604 return upstream.getBaseObject();
0605 }
0606
0607 @Override
0608 public long getBaseOffset() {
0609 return upstream.getBaseOffset();
0610 }
0611
0612 @Override
0613 public int getRecordLength() {
0614 return upstream.getRecordLength();
0615 }
0616
0617 @Override
0618 public long getKeyPrefix() {
0619 return upstream.getKeyPrefix();
0620 }
0621 }
0622
0623
0624
0625
0626
0627
0628
0629
0630
0631
0632 public UnsafeSorterIterator getIterator(int startIndex) throws IOException {
0633 if (spillWriters.isEmpty()) {
0634 assert(inMemSorter != null);
0635 UnsafeSorterIterator iter = inMemSorter.getSortedIterator();
0636 moveOver(iter, startIndex);
0637 return iter;
0638 } else {
0639 LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
0640 int i = 0;
0641 for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
0642 if (i + spillWriter.recordsSpilled() > startIndex) {
0643 UnsafeSorterIterator iter = spillWriter.getReader(serializerManager);
0644 moveOver(iter, startIndex - i);
0645 queue.add(iter);
0646 }
0647 i += spillWriter.recordsSpilled();
0648 }
0649 if (inMemSorter != null) {
0650 UnsafeSorterIterator iter = inMemSorter.getSortedIterator();
0651 moveOver(iter, startIndex - i);
0652 queue.add(iter);
0653 }
0654 return new ChainedIterator(queue);
0655 }
0656 }
0657
0658 private void moveOver(UnsafeSorterIterator iter, int steps)
0659 throws IOException {
0660 if (steps > 0) {
0661 for (int i = 0; i < steps; i++) {
0662 if (iter.hasNext()) {
0663 iter.loadNext();
0664 } else {
0665 throw new ArrayIndexOutOfBoundsException("Failed to move the iterator " + steps +
0666 " steps forward");
0667 }
0668 }
0669 }
0670 }
0671
0672
0673
0674
0675 static class ChainedIterator extends UnsafeSorterIterator {
0676
0677 private final Queue<UnsafeSorterIterator> iterators;
0678 private UnsafeSorterIterator current;
0679 private int numRecords;
0680
0681 ChainedIterator(Queue<UnsafeSorterIterator> iterators) {
0682 assert iterators.size() > 0;
0683 this.numRecords = 0;
0684 for (UnsafeSorterIterator iter: iterators) {
0685 this.numRecords += iter.getNumRecords();
0686 }
0687 this.iterators = iterators;
0688 this.current = iterators.remove();
0689 }
0690
0691 @Override
0692 public int getNumRecords() {
0693 return numRecords;
0694 }
0695
0696 @Override
0697 public boolean hasNext() {
0698 while (!current.hasNext() && !iterators.isEmpty()) {
0699 current = iterators.remove();
0700 }
0701 return current.hasNext();
0702 }
0703
0704 @Override
0705 public void loadNext() throws IOException {
0706 while (!current.hasNext() && !iterators.isEmpty()) {
0707 current = iterators.remove();
0708 }
0709 current.loadNext();
0710 }
0711
0712 @Override
0713 public Object getBaseObject() { return current.getBaseObject(); }
0714
0715 @Override
0716 public long getBaseOffset() { return current.getBaseOffset(); }
0717
0718 @Override
0719 public int getRecordLength() { return current.getRecordLength(); }
0720
0721 @Override
0722 public long getKeyPrefix() { return current.getKeyPrefix(); }
0723 }
0724 }