0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.util.collection.unsafe.sort;
0019
0020 import java.io.File;
0021 import java.io.IOException;
0022 import java.util.Arrays;
0023 import java.util.LinkedList;
0024 import java.util.UUID;
0025
0026 import org.hamcrest.Matchers;
0027 import scala.Tuple2$;
0028
0029 import org.junit.After;
0030 import org.junit.Before;
0031 import org.junit.Test;
0032 import org.mockito.Mock;
0033 import org.mockito.MockitoAnnotations;
0034
0035 import org.apache.spark.SparkConf;
0036 import org.apache.spark.TaskContext;
0037 import org.apache.spark.executor.ShuffleWriteMetrics;
0038 import org.apache.spark.executor.TaskMetrics;
0039 import org.apache.spark.internal.config.package$;
0040 import org.apache.spark.memory.TestMemoryManager;
0041 import org.apache.spark.memory.SparkOutOfMemoryError;
0042 import org.apache.spark.memory.TaskMemoryManager;
0043 import org.apache.spark.serializer.JavaSerializer;
0044 import org.apache.spark.serializer.SerializerInstance;
0045 import org.apache.spark.serializer.SerializerManager;
0046 import org.apache.spark.storage.*;
0047 import org.apache.spark.unsafe.Platform;
0048 import org.apache.spark.util.Utils;
0049
0050 import static org.hamcrest.Matchers.greaterThan;
0051 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
0052 import static org.junit.Assert.*;
0053 import static org.mockito.Answers.RETURNS_SMART_NULLS;
0054 import static org.mockito.Mockito.*;
0055
0056 public class UnsafeExternalSorterSuite {
0057
0058 private final SparkConf conf = new SparkConf();
0059
0060 final LinkedList<File> spillFilesCreated = new LinkedList<>();
0061 final TestMemoryManager memoryManager =
0062 new TestMemoryManager(conf.clone().set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false));
0063 final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
0064 final SerializerManager serializerManager = new SerializerManager(
0065 new JavaSerializer(conf),
0066 conf.clone().set(package$.MODULE$.SHUFFLE_SPILL_COMPRESS(), false));
0067
0068 final PrefixComparator prefixComparator = PrefixComparators.LONG;
0069
0070
0071 final RecordComparator recordComparator = new RecordComparator() {
0072 @Override
0073 public int compare(
0074 Object leftBaseObject,
0075 long leftBaseOffset,
0076 int leftBaseLength,
0077 Object rightBaseObject,
0078 long rightBaseOffset,
0079 int rightBaseLength) {
0080 return 0;
0081 }
0082 };
0083
0084 File tempDir;
0085 @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
0086 @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
0087 @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
0088
0089 protected boolean shouldUseRadixSort() { return false; }
0090
0091 private final long pageSizeBytes = conf.getSizeAsBytes(
0092 package$.MODULE$.BUFFER_PAGESIZE().key(), "4m");
0093
0094 private final int spillThreshold =
0095 (int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD());
0096
0097 @Before
0098 public void setUp() {
0099 MockitoAnnotations.initMocks(this);
0100 tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test");
0101 spillFilesCreated.clear();
0102 taskContext = mock(TaskContext.class);
0103 when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
0104 when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
0105 when(diskBlockManager.createTempLocalBlock()).thenAnswer(invocationOnMock -> {
0106 TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
0107 File file = File.createTempFile("spillFile", ".spill", tempDir);
0108 spillFilesCreated.add(file);
0109 return Tuple2$.MODULE$.apply(blockId, file);
0110 });
0111 when(blockManager.getDiskWriter(
0112 any(BlockId.class),
0113 any(File.class),
0114 any(SerializerInstance.class),
0115 anyInt(),
0116 any(ShuffleWriteMetrics.class))).thenAnswer(invocationOnMock -> {
0117 Object[] args = invocationOnMock.getArguments();
0118
0119 return new DiskBlockObjectWriter(
0120 (File) args[1],
0121 serializerManager,
0122 (SerializerInstance) args[2],
0123 (Integer) args[3],
0124 false,
0125 (ShuffleWriteMetrics) args[4],
0126 (BlockId) args[0]
0127 );
0128 });
0129 }
0130
0131 @After
0132 public void tearDown() {
0133 try {
0134 assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
0135 } finally {
0136 Utils.deleteRecursively(tempDir);
0137 tempDir = null;
0138 }
0139 }
0140
0141 private void assertSpillFilesWereCleanedUp() {
0142 for (File spillFile : spillFilesCreated) {
0143 assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
0144 spillFile.exists());
0145 }
0146 }
0147
0148 private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
0149 final int[] arr = new int[]{ value };
0150 sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value, false);
0151 }
0152
0153 private static void insertRecord(
0154 UnsafeExternalSorter sorter,
0155 int[] record,
0156 long prefix) throws IOException {
0157 sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix, false);
0158 }
0159
0160 private UnsafeExternalSorter newSorter() throws IOException {
0161 return UnsafeExternalSorter.create(
0162 taskMemoryManager,
0163 blockManager,
0164 serializerManager,
0165 taskContext,
0166 () -> recordComparator,
0167 prefixComparator,
0168 1024,
0169 pageSizeBytes,
0170 spillThreshold,
0171 shouldUseRadixSort());
0172 }
0173
0174 @Test
0175 public void testSortingOnlyByPrefix() throws Exception {
0176 final UnsafeExternalSorter sorter = newSorter();
0177 insertNumber(sorter, 5);
0178 insertNumber(sorter, 1);
0179 insertNumber(sorter, 3);
0180 sorter.spill();
0181 insertNumber(sorter, 4);
0182 sorter.spill();
0183 insertNumber(sorter, 2);
0184
0185 UnsafeSorterIterator iter = sorter.getSortedIterator();
0186
0187 for (int i = 1; i <= 5; i++) {
0188 iter.loadNext();
0189 assertEquals(i, iter.getKeyPrefix());
0190 assertEquals(4, iter.getRecordLength());
0191 assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset()));
0192 }
0193
0194 sorter.cleanupResources();
0195 assertSpillFilesWereCleanedUp();
0196 }
0197
0198 @Test
0199 public void testSortingEmptyArrays() throws Exception {
0200 final UnsafeExternalSorter sorter = newSorter();
0201 sorter.insertRecord(null, 0, 0, 0, false);
0202 sorter.insertRecord(null, 0, 0, 0, false);
0203 sorter.spill();
0204 sorter.insertRecord(null, 0, 0, 0, false);
0205 sorter.spill();
0206 sorter.insertRecord(null, 0, 0, 0, false);
0207 sorter.insertRecord(null, 0, 0, 0, false);
0208
0209 UnsafeSorterIterator iter = sorter.getSortedIterator();
0210
0211 for (int i = 1; i <= 5; i++) {
0212 iter.loadNext();
0213 assertEquals(0, iter.getKeyPrefix());
0214 assertEquals(0, iter.getRecordLength());
0215 }
0216
0217 sorter.cleanupResources();
0218 assertSpillFilesWereCleanedUp();
0219 }
0220
0221 @Test
0222 public void testSortTimeMetric() throws Exception {
0223 final UnsafeExternalSorter sorter = newSorter();
0224 long prevSortTime = sorter.getSortTimeNanos();
0225 assertEquals(0, prevSortTime);
0226
0227 sorter.insertRecord(null, 0, 0, 0, false);
0228 sorter.spill();
0229 assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
0230 prevSortTime = sorter.getSortTimeNanos();
0231
0232 sorter.spill();
0233 assertEquals(prevSortTime, sorter.getSortTimeNanos());
0234
0235 sorter.insertRecord(null, 0, 0, 0, false);
0236 UnsafeSorterIterator iter = sorter.getSortedIterator();
0237 assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
0238
0239 sorter.cleanupResources();
0240 assertSpillFilesWereCleanedUp();
0241 }
0242
0243 @Test
0244 public void spillingOccursInResponseToMemoryPressure() throws Exception {
0245 final UnsafeExternalSorter sorter = newSorter();
0246
0247 final int numRecords = (int) (pageSizeBytes / (4 + 4));
0248 for (int i = 0; i < numRecords; i++) {
0249 insertNumber(sorter, numRecords - i);
0250 }
0251 assertEquals(1, sorter.getNumberOfAllocatedPages());
0252 memoryManager.markExecutionAsOutOfMemoryOnce();
0253
0254 insertNumber(sorter, 0);
0255
0256 assertThat(tempDir.listFiles().length, greaterThanOrEqualTo(1));
0257
0258 UnsafeSorterIterator iter = sorter.getSortedIterator();
0259
0260 int i = 0;
0261 while (iter.hasNext()) {
0262 iter.loadNext();
0263 assertEquals(i, iter.getKeyPrefix());
0264 assertEquals(4, iter.getRecordLength());
0265 assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset()));
0266 i++;
0267 }
0268 assertEquals(numRecords + 1, i);
0269 sorter.cleanupResources();
0270 assertSpillFilesWereCleanedUp();
0271 }
0272
0273 @Test
0274 public void testFillingPage() throws Exception {
0275 final UnsafeExternalSorter sorter = newSorter();
0276 byte[] record = new byte[16];
0277 while (sorter.getNumberOfAllocatedPages() < 2) {
0278 sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0, false);
0279 }
0280 sorter.cleanupResources();
0281 assertSpillFilesWereCleanedUp();
0282 }
0283
0284 @Test
0285 public void sortingRecordsThatExceedPageSize() throws Exception {
0286 final UnsafeExternalSorter sorter = newSorter();
0287 final int[] largeRecord = new int[(int) pageSizeBytes + 16];
0288 Arrays.fill(largeRecord, 456);
0289 final int[] smallRecord = new int[100];
0290 Arrays.fill(smallRecord, 123);
0291
0292 insertRecord(sorter, largeRecord, 456);
0293 sorter.spill();
0294 insertRecord(sorter, smallRecord, 123);
0295 sorter.spill();
0296 insertRecord(sorter, smallRecord, 123);
0297 insertRecord(sorter, largeRecord, 456);
0298
0299 UnsafeSorterIterator iter = sorter.getSortedIterator();
0300
0301 assertTrue(iter.hasNext());
0302 iter.loadNext();
0303 assertEquals(123, iter.getKeyPrefix());
0304 assertEquals(smallRecord.length * 4, iter.getRecordLength());
0305 assertEquals(123, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset()));
0306
0307 assertTrue(iter.hasNext());
0308 iter.loadNext();
0309 assertEquals(123, iter.getKeyPrefix());
0310 assertEquals(smallRecord.length * 4, iter.getRecordLength());
0311 assertEquals(123, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset()));
0312
0313 assertTrue(iter.hasNext());
0314 iter.loadNext();
0315 assertEquals(456, iter.getKeyPrefix());
0316 assertEquals(largeRecord.length * 4, iter.getRecordLength());
0317 assertEquals(456, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset()));
0318
0319 assertTrue(iter.hasNext());
0320 iter.loadNext();
0321 assertEquals(456, iter.getKeyPrefix());
0322 assertEquals(largeRecord.length * 4, iter.getRecordLength());
0323 assertEquals(456, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset()));
0324
0325 assertFalse(iter.hasNext());
0326 sorter.cleanupResources();
0327 assertSpillFilesWereCleanedUp();
0328 }
0329
0330 @Test
0331 public void forcedSpillingWithReadIterator() throws Exception {
0332 final UnsafeExternalSorter sorter = newSorter();
0333 long[] record = new long[100];
0334 int recordSize = record.length * 8;
0335 int n = (int) pageSizeBytes / recordSize * 3;
0336 for (int i = 0; i < n; i++) {
0337 record[0] = (long) i;
0338 sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false);
0339 }
0340 assertTrue(sorter.getNumberOfAllocatedPages() >= 2);
0341 UnsafeExternalSorter.SpillableIterator iter =
0342 (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
0343 int lastv = 0;
0344 for (int i = 0; i < n / 3; i++) {
0345 iter.hasNext();
0346 iter.loadNext();
0347 assertTrue(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i);
0348 lastv = i;
0349 }
0350 assertTrue(iter.spill() > 0);
0351 assertEquals(0, iter.spill());
0352 assertTrue(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == lastv);
0353 for (int i = n / 3; i < n; i++) {
0354 iter.hasNext();
0355 iter.loadNext();
0356 assertEquals(i, Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()));
0357 }
0358 sorter.cleanupResources();
0359 assertSpillFilesWereCleanedUp();
0360 }
0361
0362 @Test
0363 public void forcedSpillingWithNotReadIterator() throws Exception {
0364 final UnsafeExternalSorter sorter = newSorter();
0365 long[] record = new long[100];
0366 int recordSize = record.length * 8;
0367 int n = (int) pageSizeBytes / recordSize * 3;
0368 for (int i = 0; i < n; i++) {
0369 record[0] = (long) i;
0370 sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false);
0371 }
0372 assertTrue(sorter.getNumberOfAllocatedPages() >= 2);
0373 UnsafeExternalSorter.SpillableIterator iter =
0374 (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
0375 assertTrue(iter.spill() > 0);
0376 assertEquals(0, iter.spill());
0377 for (int i = 0; i < n; i++) {
0378 iter.hasNext();
0379 iter.loadNext();
0380 assertEquals(i, Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()));
0381 }
0382 sorter.cleanupResources();
0383 assertSpillFilesWereCleanedUp();
0384 }
0385
0386 @Test
0387 public void forcedSpillingWithoutComparator() throws Exception {
0388 final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
0389 taskMemoryManager,
0390 blockManager,
0391 serializerManager,
0392 taskContext,
0393 null,
0394 null,
0395 1024,
0396 pageSizeBytes,
0397 spillThreshold,
0398 shouldUseRadixSort());
0399 long[] record = new long[100];
0400 int recordSize = record.length * 8;
0401 int n = (int) pageSizeBytes / recordSize * 3;
0402 int batch = n / 4;
0403 for (int i = 0; i < n; i++) {
0404 record[0] = (long) i;
0405 sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false);
0406 if (i % batch == batch - 1) {
0407 sorter.spill();
0408 }
0409 }
0410 UnsafeSorterIterator iter = sorter.getIterator(0);
0411 for (int i = 0; i < n; i++) {
0412 iter.hasNext();
0413 iter.loadNext();
0414 assertEquals(i, Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()));
0415 }
0416 sorter.cleanupResources();
0417 assertSpillFilesWereCleanedUp();
0418 }
0419
0420 @Test
0421 public void testDiskSpilledBytes() throws Exception {
0422 final UnsafeExternalSorter sorter = newSorter();
0423 long[] record = new long[100];
0424 int recordSize = record.length * 8;
0425 int n = (int) pageSizeBytes / recordSize * 3;
0426 for (int i = 0; i < n; i++) {
0427 record[0] = (long) i;
0428 sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false);
0429 }
0430
0431
0432 assertTrue(sorter.getNumberOfAllocatedPages() >= 2);
0433 assertTrue(taskContext.taskMetrics().diskBytesSpilled() == 0);
0434 UnsafeExternalSorter.SpillableIterator iter =
0435 (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
0436 assertTrue(iter.spill() > 0);
0437 assertTrue(taskContext.taskMetrics().diskBytesSpilled() > 0);
0438 assertEquals(0, iter.spill());
0439
0440 assertTrue(taskContext.taskMetrics().diskBytesSpilled() > 0);
0441 sorter.cleanupResources();
0442 assertSpillFilesWereCleanedUp();
0443 }
0444
0445 @Test
0446 public void testPeakMemoryUsed() throws Exception {
0447 final long recordLengthBytes = 8;
0448 final long pageSizeBytes = 256;
0449 final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
0450 final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
0451 taskMemoryManager,
0452 blockManager,
0453 serializerManager,
0454 taskContext,
0455 () -> recordComparator,
0456 prefixComparator,
0457 1024,
0458 pageSizeBytes,
0459 spillThreshold,
0460 shouldUseRadixSort());
0461
0462
0463
0464 long previousPeakMemory = sorter.getPeakMemoryUsedBytes();
0465 long newPeakMemory;
0466 try {
0467 for (int i = 0; i < numRecordsPerPage * 10; i++) {
0468 insertNumber(sorter, i);
0469 newPeakMemory = sorter.getPeakMemoryUsedBytes();
0470 if (i % numRecordsPerPage == 0) {
0471
0472 assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
0473 } else {
0474 assertEquals(previousPeakMemory, newPeakMemory);
0475 }
0476 previousPeakMemory = newPeakMemory;
0477 }
0478
0479
0480 sorter.spill();
0481 newPeakMemory = sorter.getPeakMemoryUsedBytes();
0482 assertEquals(previousPeakMemory, newPeakMemory);
0483 for (int i = 0; i < numRecordsPerPage; i++) {
0484 insertNumber(sorter, i);
0485 }
0486 newPeakMemory = sorter.getPeakMemoryUsedBytes();
0487 assertEquals(previousPeakMemory, newPeakMemory);
0488 } finally {
0489 sorter.cleanupResources();
0490 assertSpillFilesWereCleanedUp();
0491 }
0492 }
0493
0494 @Test
0495 public void testGetIterator() throws Exception {
0496 final UnsafeExternalSorter sorter = newSorter();
0497 for (int i = 0; i < 100; i++) {
0498 insertNumber(sorter, i);
0499 }
0500 verifyIntIterator(sorter.getIterator(0), 0, 100);
0501 verifyIntIterator(sorter.getIterator(79), 79, 100);
0502
0503 sorter.spill();
0504 for (int i = 100; i < 200; i++) {
0505 insertNumber(sorter, i);
0506 }
0507 sorter.spill();
0508 verifyIntIterator(sorter.getIterator(79), 79, 200);
0509
0510 for (int i = 200; i < 300; i++) {
0511 insertNumber(sorter, i);
0512 }
0513 verifyIntIterator(sorter.getIterator(79), 79, 300);
0514 verifyIntIterator(sorter.getIterator(139), 139, 300);
0515 verifyIntIterator(sorter.getIterator(279), 279, 300);
0516 sorter.cleanupResources();
0517 assertSpillFilesWereCleanedUp();
0518 }
0519
0520 @Test
0521 public void testOOMDuringSpill() throws Exception {
0522 final UnsafeExternalSorter sorter = newSorter();
0523
0524
0525
0526
0527
0528
0529 for (int i = 0; sorter.hasSpaceForAnotherRecord(); ++i) {
0530 insertNumber(sorter, i);
0531 }
0532
0533
0534
0535
0536
0537
0538
0539 memoryManager.markconsequentOOM(2);
0540 try {
0541 insertNumber(sorter, 1024);
0542 fail("expected OutOfMmoryError but it seems operation surprisingly succeeded");
0543 }
0544
0545 catch (SparkOutOfMemoryError oom){
0546 String oomStackTrace = Utils.exceptionString(oom);
0547 assertThat("expected SparkOutOfMemoryError in " +
0548 "org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset",
0549 oomStackTrace,
0550 Matchers.containsString(
0551 "org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset"));
0552 }
0553 }
0554
0555 private void verifyIntIterator(UnsafeSorterIterator iter, int start, int end)
0556 throws IOException {
0557 for (int i = start; i < end; i++) {
0558 assert (iter.hasNext());
0559 iter.loadNext();
0560 assert (Platform.getInt(iter.getBaseObject(), iter.getBaseOffset()) == i);
0561 }
0562 }
0563 }