/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle.sort;

import java.io.*;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.util.*;

import scala.Option;
import scala.Product2;
import scala.Tuple2;
import scala.Tuple2$;
import scala.collection.Iterator;

import com.google.common.collect.HashMultiset;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.stubbing.Answer;

import org.apache.spark.HashPartitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.internal.config.package$;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.LZ4CompressionCodec;
import org.apache.spark.io.LZFCompressionCodec;
import org.apache.spark.io.SnappyCompressionCodec;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.serializer.*;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThan;
import static org.junit.Assert.*;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;

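/**
 * Test suite for {@link UnsafeShuffleWriter}. The block manager, block resolver, and task
 * context are Mockito mocks wired up in {@link #setUp()} so that shuffle output and spill
 * files land in a temporary directory where they can be read back and asserted on.
 */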
public class UnsafeShuffleWriterSuite {

  static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
  static final int NUM_PARTITIONS = 4;
  TestMemoryManager memoryManager;
  TaskMemoryManager taskMemoryManager;
  final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITIONS);
  File mergedOutputFile;
  File tempDir;
  long[] partitionSizesInMergedFile;
  final LinkedList<File> spillFilesCreated = new LinkedList<>();
  SparkConf conf;
  final Serializer serializer = new KryoSerializer(new SparkConf());
  TaskMetrics taskMetrics;

  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
  @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
  @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
  @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
  @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;

  @After
  public void tearDown() {
    Utils.deleteRecursively(tempDir);
    final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
    if (leakedMemory != 0) {
      fail("Test leaked " + leakedMemory + " bytes of managed memory");
    }
  }

  @Before
  @SuppressWarnings("unchecked")
  public void setUp() throws IOException {
    MockitoAnnotations.initMocks(this);
    tempDir = Utils.createTempDir(null, "test");
    mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
    partitionSizesInMergedFile = null;
    spillFilesCreated.clear();
    conf = new SparkConf()
      .set(package$.MODULE$.BUFFER_PAGESIZE().key(), "1m")
      .set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false);
    taskMetrics = new TaskMetrics();
    memoryManager = new TestMemoryManager(conf);
    taskMemoryManager = new TaskMemoryManager(memoryManager, 0);

    // Tests that need different settings (e.g. encryption) override this serializer manager
    // with their own; this is the default for everything else.
    SerializerManager manager = new SerializerManager(serializer, conf);
    when(blockManager.serializerManager()).thenReturn(manager);

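    // Have the mock BlockManager hand out real DiskBlockObjectWriters so that shuffle and
    // spill data is actually written to files under tempDir.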
    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
    when(blockManager.getDiskWriter(
      any(BlockId.class),
      any(File.class),
      any(SerializerInstance.class),
      anyInt(),
      any(ShuffleWriteMetrics.class))).thenAnswer(invocationOnMock -> {
        Object[] args = invocationOnMock.getArguments();
        return new DiskBlockObjectWriter(
          (File) args[1],
          blockManager.serializerManager(),
          (SerializerInstance) args[2],
          (Integer) args[3],
          false,
          (ShuffleWriteMetrics) args[4],
          (BlockId) args[0]
        );
      });

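    // Simulate IndexShuffleBlockResolver's commit protocol: record the partition lengths and
    // move the temporary output file into place as the final merged data file.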
    when(shuffleBlockResolver.getDataFile(anyInt(), anyLong())).thenReturn(mergedOutputFile);

    Answer<?> renameTempAnswer = invocationOnMock -> {
      partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
      File tmp = (File) invocationOnMock.getArguments()[3];
      if (!mergedOutputFile.delete()) {
        throw new RuntimeException("Failed to delete old merged output file.");
      }
      if (tmp != null) {
        Files.move(tmp.toPath(), mergedOutputFile.toPath());
      } else if (!mergedOutputFile.createNewFile()) {
        throw new RuntimeException("Failed to create empty merged output file.");
      }
      return null;
    };

    doAnswer(renameTempAnswer)
      .when(shuffleBlockResolver)
      .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), any(File.class));

    doAnswer(renameTempAnswer)
      .when(shuffleBlockResolver)
      .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), eq(null));

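    // Track every temp shuffle block created so that tests can assert spill files were
    // cleaned up (see assertSpillFilesWereCleanedUp()).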
    when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> {
      TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
      File file = File.createTempFile("spillFile", ".spill", tempDir);
      spillFilesCreated.add(file);
      return Tuple2$.MODULE$.apply(blockId, file);
    });

    when(taskContext.taskMetrics()).thenReturn(taskMetrics);
    when(shuffleDep.serializer()).thenReturn(serializer);
    when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
    when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager);
  }

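  /**
   * Creates the writer under test. transferToEnabled toggles between NIO channel-based and
   * stream-based merging of spill files via the spark.file.transferTo flag.
   */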
  private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabled) {
    conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
    return new UnsafeShuffleWriter<>(
      blockManager,
      taskMemoryManager,
      new SerializedShuffleHandle<>(0, shuffleDep),
      0L,
      taskContext,
      conf,
      taskContext.taskMetrics().shuffleWriteMetrics(),
      new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver));
  }

  private void assertSpillFilesWereCleanedUp() {
    for (File spillFile : spillFilesCreated) {
      assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
        spillFile.exists());
    }
  }

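  /**
   * Reads the merged output file back, one partition at a time, undoing encryption and
   * compression as configured, and checks that every record landed in its expected partition.
   */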
  private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
    final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<>();
    long startOffset = 0;
    for (int i = 0; i < NUM_PARTITIONS; i++) {
      final long partitionSize = partitionSizesInMergedFile[i];
      if (partitionSize > 0) {
        FileInputStream fin = new FileInputStream(mergedOutputFile);
        fin.getChannel().position(startOffset);
        InputStream in = new LimitedInputStream(fin, partitionSize);
        in = blockManager.serializerManager().wrapForEncryption(in);
        if ((boolean) conf.get(package$.MODULE$.SHUFFLE_COMPRESS())) {
          in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
        }
        try (DeserializationStream recordsStream =
            serializer.newInstance().deserializeStream(in)) {
          Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
          while (records.hasNext()) {
            Tuple2<Object, Object> record = records.next();
            assertEquals(i, hashPartitioner.getPartition(record._1()));
            recordsList.add(record);
          }
        }
        startOffset += partitionSize;
      }
    }
    return recordsList;
  }

  @Test(expected=IllegalStateException.class)
  public void mustCallWriteBeforeSuccessfulStop() throws IOException {
    createWriter(false).stop(true);
  }

  @Test
  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
    createWriter(false).stop(false);
  }

  static class PandaException extends RuntimeException {
  }

  @Test(expected=PandaException.class)
  public void writeFailurePropagates() throws Exception {
    class BadRecords extends scala.collection.AbstractIterator<Product2<Object, Object>> {
      @Override public boolean hasNext() {
        throw new PandaException();
      }
      @Override public Product2<Object, Object> next() {
        return null;
      }
    }
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
    writer.write(new BadRecords());
  }

  @Test
  public void writeEmptyIterator() throws Exception {
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
    writer.write(new ArrayList<Product2<Object, Object>>().iterator());
    final Option<MapStatus> mapStatus = writer.stop(true);
    assertTrue(mapStatus.isDefined());
    assertTrue(mergedOutputFile.exists());
    assertEquals(0, spillFilesCreated.size());
    assertArrayEquals(new long[NUM_PARTITIONS], partitionSizesInMergedFile);
    assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten());
    assertEquals(0, taskMetrics.shuffleWriteMetrics().bytesWritten());
    assertEquals(0, taskMetrics.diskBytesSpilled());
    assertEquals(0, taskMetrics.memoryBytesSpilled());
  }

  @Test
  public void writeWithoutSpilling() throws Exception {
    // In this example, each partition should have exactly one record:
    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
    for (int i = 0; i < NUM_PARTITIONS; i++) {
      dataToWrite.add(new Tuple2<>(i, i));
    }
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
    writer.write(dataToWrite.iterator());
    final Option<MapStatus> mapStatus = writer.stop(true);
    assertTrue(mapStatus.isDefined());
    assertTrue(mergedOutputFile.exists());

    long sumOfPartitionSizes = 0;
    for (long size : partitionSizesInMergedFile) {
      // All partitions should be the same size:
      assertEquals(partitionSizesInMergedFile[0], size);
      sumOfPartitionSizes += size;
    }
    assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
    assertEquals(
      HashMultiset.create(dataToWrite),
      HashMultiset.create(readRecordsFromFile()));
    assertSpillFilesWereCleanedUp();
    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics();
    assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
    assertEquals(0, taskMetrics.diskBytesSpilled());
    assertEquals(0, taskMetrics.memoryBytesSpilled());
    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten());
  }

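  /**
   * Writes a handful of records, forces a spill partway through, and verifies that the two
   * resulting spill files are merged into a single correct output file under the given
   * compression and encryption settings.
   */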
  private void testMergingSpills(
      final boolean transferToEnabled,
      String compressionCodecName,
      boolean encrypt) throws Exception {
    if (compressionCodecName != null) {
      conf.set(package$.MODULE$.SHUFFLE_COMPRESS(), true);
      conf.set("spark.io.compression.codec", compressionCodecName);
    } else {
      conf.set(package$.MODULE$.SHUFFLE_COMPRESS(), false);
    }
    conf.set(package$.MODULE$.IO_ENCRYPTION_ENABLED(), encrypt);

    SerializerManager manager;
    if (encrypt) {
      manager = new SerializerManager(serializer, conf,
        Option.apply(CryptoStreamUtils.createKey(conf)));
    } else {
      manager = new SerializerManager(serializer, conf);
    }

    when(blockManager.serializerManager()).thenReturn(manager);
    testMergingSpills(transferToEnabled, encrypt);
  }

  private void testMergingSpills(
      boolean transferToEnabled,
      boolean encrypted) throws IOException {
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
    for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
      dataToWrite.add(new Tuple2<>(i, i));
    }
    writer.insertRecordIntoSorter(dataToWrite.get(0));
    writer.insertRecordIntoSorter(dataToWrite.get(1));
    writer.insertRecordIntoSorter(dataToWrite.get(2));
    writer.insertRecordIntoSorter(dataToWrite.get(3));
    writer.forceSorterToSpill();
    writer.insertRecordIntoSorter(dataToWrite.get(4));
    writer.insertRecordIntoSorter(dataToWrite.get(5));
    writer.closeAndWriteOutput();
    final Option<MapStatus> mapStatus = writer.stop(true);
    assertTrue(mapStatus.isDefined());
    assertTrue(mergedOutputFile.exists());
    assertEquals(2, spillFilesCreated.size());

    long sumOfPartitionSizes = 0;
    for (long size : partitionSizesInMergedFile) {
      sumOfPartitionSizes += size;
    }
    assertEquals(sumOfPartitionSizes, mergedOutputFile.length());

    assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile()));
    assertSpillFilesWereCleanedUp();
    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics();
    assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten());
  }

  @Test
  public void mergeSpillsWithTransferToAndLZF() throws Exception {
    testMergingSpills(true, LZFCompressionCodec.class.getName(), false);
  }

  @Test
  public void mergeSpillsWithFileStreamAndLZF() throws Exception {
    testMergingSpills(false, LZFCompressionCodec.class.getName(), false);
  }

  @Test
  public void mergeSpillsWithTransferToAndLZ4() throws Exception {
    testMergingSpills(true, LZ4CompressionCodec.class.getName(), false);
  }

  @Test
  public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
    testMergingSpills(false, LZ4CompressionCodec.class.getName(), false);
  }

  @Test
  public void mergeSpillsWithTransferToAndSnappy() throws Exception {
    testMergingSpills(true, SnappyCompressionCodec.class.getName(), false);
  }

  @Test
  public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
    testMergingSpills(false, SnappyCompressionCodec.class.getName(), false);
  }

  @Test
  public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
    testMergingSpills(true, null, false);
  }

  @Test
  public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
    testMergingSpills(false, null, false);
  }

  @Test
  public void mergeSpillsWithCompressionAndEncryption() throws Exception {
    // Even though transferTo is requested, encryption rules out the transferTo-based fast
    // merge, so this exercises the stream-based merge path instead.
    testMergingSpills(true, LZ4CompressionCodec.class.getName(), true);
  }

  @Test
  public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Exception {
    testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
  }

  @Test
  public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception {
    conf.set(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE(), false);
    testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
  }

  @Test
  public void mergeSpillsWithEncryptionAndNoCompression() throws Exception {
    // Even though transferTo is requested, encryption rules out the transferTo-based fast
    // merge, so this exercises the stream-based merge path instead.
    testMergingSpills(true, null, true);
  }

  @Test
  public void mergeSpillsWithFileStreamAndEncryptionAndNoCompression() throws Exception {
    testMergingSpills(false, null, true);
  }

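  /**
   * Caps task memory at a single maximum-size page, then inserts eleven records of a tenth of
   * a page each, so the sorter runs out of memory and must spill before the final write.
   */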
  @Test
  public void writeEnoughDataToTriggerSpill() throws Exception {
    memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES);
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 10];
    for (int i = 0; i < 10 + 1; i++) {
      dataToWrite.add(new Tuple2<>(i, bigByteArray));
    }
    writer.write(dataToWrite.iterator());
    assertEquals(2, spillFilesCreated.size());
    writer.stop(true);
    readRecordsFromFile();
    assertSpillFilesWereCleanedUp();
    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics();
    assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten());
  }

  @Test
  public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOff() throws Exception {
    conf.set(package$.MODULE$.SHUFFLE_SORT_USE_RADIXSORT(), false);
    writeEnoughRecordsToTriggerSortBufferExpansionAndSpill();
    assertEquals(2, spillFilesCreated.size());
  }

  @Test
  public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOn() throws Exception {
    conf.set(package$.MODULE$.SHUFFLE_SORT_USE_RADIXSORT(), true);
    writeEnoughRecordsToTriggerSortBufferExpansionAndSpill();
    assertEquals(3, spillFilesCreated.size());
  }

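  /**
   * Limits memory so that growing the sort buffer past DEFAULT_INITIAL_SORT_BUFFER_SIZE
   * entries forces spilling. Radix sort reserves extra space for its pointer array, which is
   * why the radix-on variant above spills once more (3 spill files vs. 2) under the same
   * memory limit.
   */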
  private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
    memoryManager.limit(DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16);
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
    for (int i = 0; i < DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) {
      dataToWrite.add(new Tuple2<>(i, i));
    }
    writer.write(dataToWrite.iterator());
    writer.stop(true);
    readRecordsFromFile();
    assertSpillFilesWereCleanedUp();
    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics();
    assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten());
  }

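  // Covers the code path where a single record is larger than the sorter's fixed-size disk
  // write buffer and must be copied out to disk in multiple chunks.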
  @Test
  public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
    final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
    new Random(42).nextBytes(bytes);
    dataToWrite.add(new Tuple2<>(1, ByteBuffer.wrap(bytes)));
    writer.write(dataToWrite.iterator());
    writer.stop(true);
    assertEquals(
      HashMultiset.create(dataToWrite),
      HashMultiset.create(readRecordsFromFile()));
    assertSpillFilesWereCleanedUp();
  }

  @Test
  public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
    dataToWrite.add(new Tuple2<>(1, ByteBuffer.wrap(new byte[1])));

    // A record that is right at the maximum record size should be written successfully:
    final byte[] atMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes() - 4];
    new Random(42).nextBytes(atMaxRecordSize);
    dataToWrite.add(new Tuple2<>(2, ByteBuffer.wrap(atMaxRecordSize)));

    // A record that exceeds the maximum record size should also round-trip correctly:
    final byte[] exceedsMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes()];
    new Random(42).nextBytes(exceedsMaxRecordSize);
    dataToWrite.add(new Tuple2<>(3, ByteBuffer.wrap(exceedsMaxRecordSize)));
    writer.write(dataToWrite.iterator());
    writer.stop(true);
    assertEquals(
      HashMultiset.create(dataToWrite),
      HashMultiset.create(readRecordsFromFile()));
    assertSpillFilesWereCleanedUp();
  }

  @Test
  public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
    writer.insertRecordIntoSorter(new Tuple2<>(1, 1));
    writer.insertRecordIntoSorter(new Tuple2<>(2, 2));
    writer.forceSorterToSpill();
    writer.insertRecordIntoSorter(new Tuple2<>(2, 2));
    writer.stop(false);
    assertSpillFilesWereCleanedUp();
  }

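  /**
   * Verifies that peak memory tracking is monotonic: it grows by exactly one page each time a
   * new page is allocated, and is unchanged by spilling, by inserting into the freed-up
   * space, or by closing the writer.
   */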
  @Test
  public void testPeakMemoryUsed() throws Exception {
    final long recordLengthBytes = 8;
    final long pageSizeBytes = 256;
    final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
    taskMemoryManager = spy(taskMemoryManager);
    when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
    final UnsafeShuffleWriter<Object, Object> writer = new UnsafeShuffleWriter<>(
      blockManager,
      taskMemoryManager,
      new SerializedShuffleHandle<>(0, shuffleDep),
      0L,
      taskContext,
      conf,
      taskContext.taskMetrics().shuffleWriteMetrics(),
      new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver));

    // Peak memory should be monotonically increasing: every time a new page is allocated,
    // the peak should grow by exactly the size of that page.
    long previousPeakMemory = writer.getPeakMemoryUsedBytes();
    long newPeakMemory;
    try {
      for (int i = 0; i < numRecordsPerPage * 10; i++) {
        writer.insertRecordIntoSorter(new Tuple2<>(1, 1));
        newPeakMemory = writer.getPeakMemoryUsedBytes();
        if (i % numRecordsPerPage == 0) {
          // A new page is allocated once every numRecordsPerPage inserts, and each
          // allocation should raise the peak by exactly one page.
          assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
        } else {
          assertEquals(previousPeakMemory, newPeakMemory);
        }
        previousPeakMemory = newPeakMemory;
      }

      // Spilling and then inserting into the freed-up space should not change the peak.
      writer.forceSorterToSpill();
      newPeakMemory = writer.getPeakMemoryUsedBytes();
      assertEquals(previousPeakMemory, newPeakMemory);
      for (int i = 0; i < numRecordsPerPage; i++) {
        writer.insertRecordIntoSorter(new Tuple2<>(1, 1));
      }
      newPeakMemory = writer.getPeakMemoryUsedBytes();
      assertEquals(previousPeakMemory, newPeakMemory);

      // Closing the writer should not change the peak either.
      writer.closeAndWriteOutput();
      newPeakMemory = writer.getPeakMemoryUsedBytes();
      assertEquals(previousPeakMemory, newPeakMemory);
    } finally {
      writer.stop(false);
    }
  }

}