/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle.sort;

import java.io.*;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.util.*;

import org.mockito.stubbing.Answer;
import scala.Option;
import scala.Product2;
import scala.Tuple2;
import scala.Tuple2$;
import scala.collection.Iterator;

import com.google.common.collect.HashMultiset;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import org.apache.spark.HashPartitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.LZ4CompressionCodec;
import org.apache.spark.io.LZFCompressionCodec;
import org.apache.spark.io.SnappyCompressionCodec;
import org.apache.spark.internal.config.package$;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.serializer.*;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThan;
import static org.junit.Assert.*;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;

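/**
 * Unit tests for {@link UnsafeShuffleWriter}. The block manager, shuffle block resolver, and
 * disk block manager are mocked so that all shuffle output is redirected to temporary files,
 * letting the tests inspect merged output, spill behavior, and write metrics in isolation.
 */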
public class UnsafeShuffleWriterSuite {

  static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
  static final int NUM_PARTITIONS = 4;
  TestMemoryManager memoryManager;
  TaskMemoryManager taskMemoryManager;
  final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITIONS);
  File mergedOutputFile;
  File tempDir;
  long[] partitionSizesInMergedFile;
  final LinkedList<File> spillFilesCreated = new LinkedList<>();
  SparkConf conf;
  final Serializer serializer = new KryoSerializer(new SparkConf());
  TaskMetrics taskMetrics;

  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
  @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
  @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
  @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
  @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;

  @After
  public void tearDown() {
    Utils.deleteRecursively(tempDir);
    final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
    if (leakedMemory != 0) {
      fail("Test leaked " + leakedMemory + " bytes of managed memory");
    }
  }

  @Before
  @SuppressWarnings("unchecked")
  public void setUp() throws IOException {
    MockitoAnnotations.initMocks(this);
    tempDir = Utils.createTempDir(null, "test");
    mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
    partitionSizesInMergedFile = null;
    spillFilesCreated.clear();
    conf = new SparkConf()
      .set(package$.MODULE$.BUFFER_PAGESIZE().key(), "1m")
      .set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false);
    taskMetrics = new TaskMetrics();
    memoryManager = new TestMemoryManager(conf);
    taskMemoryManager = new TaskMemoryManager(memoryManager, 0);

    // Some tests will override this manager because they change the configuration. This is a
    // default for tests that don't need a specific one.
    SerializerManager manager = new SerializerManager(serializer, conf);
    when(blockManager.serializerManager()).thenReturn(manager);

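    // Hand out real DiskBlockObjectWriters backed by files in the test's temp directory so that
    // the writer's disk output can be read back and verified.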
    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
    when(blockManager.getDiskWriter(
      any(BlockId.class),
      any(File.class),
      any(SerializerInstance.class),
      anyInt(),
      any(ShuffleWriteMetrics.class))).thenAnswer(invocationOnMock -> {
        Object[] args = invocationOnMock.getArguments();
        return new DiskBlockObjectWriter(
          (File) args[1],
          blockManager.serializerManager(),
          (SerializerInstance) args[2],
          (Integer) args[3],
          false,
          (ShuffleWriteMetrics) args[4],
          (BlockId) args[0]
        );
      });

    when(shuffleBlockResolver.getDataFile(anyInt(), anyLong())).thenReturn(mergedOutputFile);

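    // Mimic the resolver's commit: capture the per-partition lengths and move the temporary data
    // file over mergedOutputFile (or create an empty file when no temporary file was produced).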
    Answer<?> renameTempAnswer = invocationOnMock -> {
      partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
      File tmp = (File) invocationOnMock.getArguments()[3];
      if (!mergedOutputFile.delete()) {
        throw new RuntimeException("Failed to delete old merged output file.");
      }
      if (tmp != null) {
        Files.move(tmp.toPath(), mergedOutputFile.toPath());
      } else if (!mergedOutputFile.createNewFile()) {
        throw new RuntimeException("Failed to create empty merged output file.");
      }
      return null;
    };

    doAnswer(renameTempAnswer)
        .when(shuffleBlockResolver)
        .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), any(File.class));

    doAnswer(renameTempAnswer)
        .when(shuffleBlockResolver)
        .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), eq(null));

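    // Track every temp shuffle block created so that assertSpillFilesWereCleanedUp() can verify
    // spill files are deleted.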
    when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> {
      TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
      File file = File.createTempFile("spillFile", ".spill", tempDir);
      spillFilesCreated.add(file);
      return Tuple2$.MODULE$.apply(blockId, file);
    });

    when(taskContext.taskMetrics()).thenReturn(taskMetrics);
    when(shuffleDep.serializer()).thenReturn(serializer);
    when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
    when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager);
  }

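  /**
   * Creates an UnsafeShuffleWriter wired to the mocked block manager and shuffle block resolver,
   * with the transferTo-based merge path enabled or disabled as requested.
   */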
  private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabled) {
    conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
    return new UnsafeShuffleWriter<>(
      blockManager,
      taskMemoryManager,
      new SerializedShuffleHandle<>(0, shuffleDep),
      0L, // map id
      taskContext,
      conf,
      taskContext.taskMetrics().shuffleWriteMetrics(),
      new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver));
  }

  private void assertSpillFilesWereCleanedUp() {
    for (File spillFile : spillFilesCreated) {
      assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
        spillFile.exists());
    }
  }

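  /**
   * Reads the merged output file back partition by partition, undoing encryption and compression
   * as configured, and checks that every record landed in the partition assigned by the hash
   * partitioner. Returns all records found.
   */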
  private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
    final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<>();
    long startOffset = 0;
    for (int i = 0; i < NUM_PARTITIONS; i++) {
      final long partitionSize = partitionSizesInMergedFile[i];
      if (partitionSize > 0) {
        FileInputStream fin = new FileInputStream(mergedOutputFile);
        fin.getChannel().position(startOffset);
        InputStream in = new LimitedInputStream(fin, partitionSize);
        in = blockManager.serializerManager().wrapForEncryption(in);
        if ((boolean) conf.get(package$.MODULE$.SHUFFLE_COMPRESS())) {
          in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
        }
        try (DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in)) {
          Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
          while (records.hasNext()) {
            Tuple2<Object, Object> record = records.next();
            assertEquals(i, hashPartitioner.getPartition(record._1()));
            recordsList.add(record);
          }
        }
        startOffset += partitionSize;
      }
    }
    return recordsList;
  }

  @Test(expected=IllegalStateException.class)
  public void mustCallWriteBeforeSuccessfulStop() throws IOException {
    createWriter(false).stop(true);
  }

  @Test
  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
    createWriter(false).stop(false);
  }

  static class PandaException extends RuntimeException {
  }

  @Test(expected=PandaException.class)
  public void writeFailurePropagates() throws Exception {
    class BadRecords extends scala.collection.AbstractIterator<Product2<Object, Object>> {
      @Override public boolean hasNext() {
        throw new PandaException();
      }
      @Override public Product2<Object, Object> next() {
        return null;
      }
    }
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
    writer.write(new BadRecords());
  }

  @Test
  public void writeEmptyIterator() throws Exception {
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
    writer.write(new ArrayList<Product2<Object, Object>>().iterator());
    final Option<MapStatus> mapStatus = writer.stop(true);
    assertTrue(mapStatus.isDefined());
    assertTrue(mergedOutputFile.exists());
    assertEquals(0, spillFilesCreated.size());
    assertArrayEquals(new long[NUM_PARTITIONS], partitionSizesInMergedFile);
    assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten());
    assertEquals(0, taskMetrics.shuffleWriteMetrics().bytesWritten());
    assertEquals(0, taskMetrics.diskBytesSpilled());
    assertEquals(0, taskMetrics.memoryBytesSpilled());
  }

  @Test
  public void writeWithoutSpilling() throws Exception {
    // In this example, each partition should have exactly one record:
    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
    for (int i = 0; i < NUM_PARTITIONS; i++) {
      dataToWrite.add(new Tuple2<>(i, i));
    }
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
    writer.write(dataToWrite.iterator());
    final Option<MapStatus> mapStatus = writer.stop(true);
    assertTrue(mapStatus.isDefined());
    assertTrue(mergedOutputFile.exists());

    long sumOfPartitionSizes = 0;
    for (long size: partitionSizesInMergedFile) {
      // All partitions should be the same size:
      assertEquals(partitionSizesInMergedFile[0], size);
      sumOfPartitionSizes += size;
    }
    assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
    assertEquals(
      HashMultiset.create(dataToWrite),
      HashMultiset.create(readRecordsFromFile()));
    assertSpillFilesWereCleanedUp();
    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics();
    assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
    assertEquals(0, taskMetrics.diskBytesSpilled());
    assertEquals(0, taskMetrics.memoryBytesSpilled());
    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten());
  }

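  /**
   * Shared scenario for the mergeSpills* tests: configures compression and encryption, inserts
   * six records with a forced spill in the middle, and verifies the merged output file and the
   * write/spill metrics.
   */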
  private void testMergingSpills(
      final boolean transferToEnabled,
      String compressionCodecName,
      boolean encrypt) throws Exception {
    if (compressionCodecName != null) {
      conf.set(package$.MODULE$.SHUFFLE_COMPRESS(), true);
      conf.set("spark.io.compression.codec", compressionCodecName);
    } else {
      conf.set(package$.MODULE$.SHUFFLE_COMPRESS(), false);
    }
    conf.set(package$.MODULE$.IO_ENCRYPTION_ENABLED(), encrypt);

    SerializerManager manager;
    if (encrypt) {
      manager = new SerializerManager(serializer, conf,
        Option.apply(CryptoStreamUtils.createKey(conf)));
    } else {
      manager = new SerializerManager(serializer, conf);
    }

    when(blockManager.serializerManager()).thenReturn(manager);
    testMergingSpills(transferToEnabled, encrypt);
  }

  private void testMergingSpills(
      boolean transferToEnabled,
      boolean encrypted) throws IOException {
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
    for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
      dataToWrite.add(new Tuple2<>(i, i));
    }
    writer.insertRecordIntoSorter(dataToWrite.get(0));
    writer.insertRecordIntoSorter(dataToWrite.get(1));
    writer.insertRecordIntoSorter(dataToWrite.get(2));
    writer.insertRecordIntoSorter(dataToWrite.get(3));
    writer.forceSorterToSpill();
    writer.insertRecordIntoSorter(dataToWrite.get(4));
    writer.insertRecordIntoSorter(dataToWrite.get(5));
    writer.closeAndWriteOutput();
    final Option<MapStatus> mapStatus = writer.stop(true);
    assertTrue(mapStatus.isDefined());
    assertTrue(mergedOutputFile.exists());
    assertEquals(2, spillFilesCreated.size());

    long sumOfPartitionSizes = 0;
    for (long size: partitionSizesInMergedFile) {
      sumOfPartitionSizes += size;
    }

    assertEquals(sumOfPartitionSizes, mergedOutputFile.length());

    assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile()));
    assertSpillFilesWereCleanedUp();
    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics();
    assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten());
  }

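  // The tests below cover each combination of merge strategy (transferTo vs. file streams),
  // compression codec (LZF, LZ4, Snappy, none), and encryption.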
  @Test
  public void mergeSpillsWithTransferToAndLZF() throws Exception {
    testMergingSpills(true, LZFCompressionCodec.class.getName(), false);
  }

  @Test
  public void mergeSpillsWithFileStreamAndLZF() throws Exception {
    testMergingSpills(false, LZFCompressionCodec.class.getName(), false);
  }

  @Test
  public void mergeSpillsWithTransferToAndLZ4() throws Exception {
    testMergingSpills(true, LZ4CompressionCodec.class.getName(), false);
  }

  @Test
  public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
    testMergingSpills(false, LZ4CompressionCodec.class.getName(), false);
  }

  @Test
  public void mergeSpillsWithTransferToAndSnappy() throws Exception {
    testMergingSpills(true, SnappyCompressionCodec.class.getName(), false);
  }

  @Test
  public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
    testMergingSpills(false, SnappyCompressionCodec.class.getName(), false);
  }

  @Test
  public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
    testMergingSpills(true, null, false);
  }

  @Test
  public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
    testMergingSpills(false, null, false);
  }

  @Test
  public void mergeSpillsWithCompressionAndEncryption() throws Exception {
    // Even though transferTo is requested, encryption should make this fall back to a
    // "file stream merge" internally; this test just makes sure that case works.
    testMergingSpills(true, LZ4CompressionCodec.class.getName(), true);
  }

  @Test
  public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Exception {
    testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
  }

  @Test
  public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception {
    conf.set(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE(), false);
    testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
  }

  @Test
  public void mergeSpillsWithEncryptionAndNoCompression() throws Exception {
    // Even though transferTo is requested, encryption should make this fall back to a
    // "file stream merge" internally; this test just makes sure that case works.
    testMergingSpills(true, null, true);
  }

  @Test
  public void mergeSpillsWithFileStreamAndEncryptionAndNoCompression() throws Exception {
    testMergingSpills(false, null, true);
  }

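  // Caps the memory available to the task so that inserting the records below forces the sorter
  // to spill before the final merge.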
  @Test
  public void writeEnoughDataToTriggerSpill() throws Exception {
    memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES);
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 10];
    for (int i = 0; i < 10 + 1; i++) {
      dataToWrite.add(new Tuple2<>(i, bigByteArray));
    }
    writer.write(dataToWrite.iterator());
    assertEquals(2, spillFilesCreated.size());
    writer.stop(true);
    readRecordsFromFile();
    assertSpillFilesWereCleanedUp();
    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics();
    assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten());
  }

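  // Under the same memory limit, the radix-sort path is expected to produce one more spill file
  // than the non-radix path (3 vs. 2), since radix sort reserves additional pointer-array space.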
  @Test
  public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOff() throws Exception {
    conf.set(package$.MODULE$.SHUFFLE_SORT_USE_RADIXSORT(), false);
    writeEnoughRecordsToTriggerSortBufferExpansionAndSpill();
    assertEquals(2, spillFilesCreated.size());
  }

  @Test
  public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOn() throws Exception {
    conf.set(package$.MODULE$.SHUFFLE_SORT_USE_RADIXSORT(), true);
    writeEnoughRecordsToTriggerSortBufferExpansionAndSpill();
    assertEquals(3, spillFilesCreated.size());
  }

  private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
    memoryManager.limit(DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16);
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
    for (int i = 0; i < DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) {
      dataToWrite.add(new Tuple2<>(i, i));
    }
    writer.write(dataToWrite.iterator());
    writer.stop(true);
    readRecordsFromFile();
    assertSpillFilesWereCleanedUp();
    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics();
    assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten());
  }

  @Test
  public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
    final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
    new Random(42).nextBytes(bytes);
    dataToWrite.add(new Tuple2<>(1, ByteBuffer.wrap(bytes)));
    writer.write(dataToWrite.iterator());
    writer.stop(true);
    assertEquals(
      HashMultiset.create(dataToWrite),
      HashMultiset.create(readRecordsFromFile()));
    assertSpillFilesWereCleanedUp();
  }

  @Test
  public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
    dataToWrite.add(new Tuple2<>(1, ByteBuffer.wrap(new byte[1])));
    // We should be able to write a record that's right _at_ the max record size
    final byte[] atMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes() - 4];
    new Random(42).nextBytes(atMaxRecordSize);
    dataToWrite.add(new Tuple2<>(2, ByteBuffer.wrap(atMaxRecordSize)));
    // Inserting a record that's larger than the max record size should also succeed:
    final byte[] exceedsMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes()];
    new Random(42).nextBytes(exceedsMaxRecordSize);
    dataToWrite.add(new Tuple2<>(3, ByteBuffer.wrap(exceedsMaxRecordSize)));
    writer.write(dataToWrite.iterator());
    writer.stop(true);
    assertEquals(
      HashMultiset.create(dataToWrite),
      HashMultiset.create(readRecordsFromFile()));
    assertSpillFilesWereCleanedUp();
  }

  @Test
  public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
    writer.insertRecordIntoSorter(new Tuple2<>(1, 1));
    writer.insertRecordIntoSorter(new Tuple2<>(2, 2));
    writer.forceSorterToSpill();
    writer.insertRecordIntoSorter(new Tuple2<>(2, 2));
    writer.stop(false);
    assertSpillFilesWereCleanedUp();
  }

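  // Uses a small, fixed page size so that page allocations, and therefore jumps in reported peak
  // memory, happen at a predictable record interval.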
  @Test
  public void testPeakMemoryUsed() throws Exception {
    final long recordLengthBytes = 8;
    final long pageSizeBytes = 256;
    final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
    taskMemoryManager = spy(taskMemoryManager);
    when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
    final UnsafeShuffleWriter<Object, Object> writer = new UnsafeShuffleWriter<>(
        blockManager,
        taskMemoryManager,
        new SerializedShuffleHandle<>(0, shuffleDep),
        0L, // map id
        taskContext,
        conf,
        taskContext.taskMetrics().shuffleWriteMetrics(),
        new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver));

    // Peak memory should be monotonically increasing. More specifically, every time
    // we allocate a new page it should increase by exactly the size of the page.
    long previousPeakMemory = writer.getPeakMemoryUsedBytes();
    long newPeakMemory;
    try {
      for (int i = 0; i < numRecordsPerPage * 10; i++) {
        writer.insertRecordIntoSorter(new Tuple2<>(1, 1));
        newPeakMemory = writer.getPeakMemoryUsedBytes();
        if (i % numRecordsPerPage == 0) {
          // The first page is allocated in the constructor; another page will be allocated after
          // every numRecordsPerPage records, so peak memory should change.
          assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
        } else {
          assertEquals(previousPeakMemory, newPeakMemory);
        }
        previousPeakMemory = newPeakMemory;
      }

      // Spilling should not change peak memory
      writer.forceSorterToSpill();
      newPeakMemory = writer.getPeakMemoryUsedBytes();
      assertEquals(previousPeakMemory, newPeakMemory);
      for (int i = 0; i < numRecordsPerPage; i++) {
        writer.insertRecordIntoSorter(new Tuple2<>(1, 1));
      }
      newPeakMemory = writer.getPeakMemoryUsedBytes();
      assertEquals(previousPeakMemory, newPeakMemory);

      // Closing the writer should not change peak memory
      writer.closeAndWriteOutput();
      newPeakMemory = writer.getPeakMemoryUsedBytes();
      assertEquals(previousPeakMemory, newPeakMemory);
    } finally {
      writer.stop(false);
    }
  }

}