Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.unsafe.map;
0019 
0020 import java.io.File;
0021 import java.io.IOException;
0022 import java.nio.ByteBuffer;
0023 import java.util.*;
0024 
0025 import scala.Tuple2$;
0026 
0027 import org.junit.After;
0028 import org.junit.Assert;
0029 import org.junit.Before;
0030 import org.junit.Test;
0031 import org.mockito.Mock;
0032 import org.mockito.MockitoAnnotations;
0033 
0034 import org.apache.spark.SparkConf;
0035 import org.apache.spark.executor.ShuffleWriteMetrics;
0036 import org.apache.spark.memory.MemoryMode;
0037 import org.apache.spark.memory.SparkOutOfMemoryError;
0038 import org.apache.spark.memory.TestMemoryConsumer;
0039 import org.apache.spark.memory.TaskMemoryManager;
0040 import org.apache.spark.memory.TestMemoryManager;
0041 import org.apache.spark.network.util.JavaUtils;
0042 import org.apache.spark.serializer.JavaSerializer;
0043 import org.apache.spark.serializer.SerializerInstance;
0044 import org.apache.spark.serializer.SerializerManager;
0045 import org.apache.spark.storage.*;
0046 import org.apache.spark.unsafe.Platform;
0047 import org.apache.spark.unsafe.array.ByteArrayMethods;
0048 import org.apache.spark.util.Utils;
0049 import org.apache.spark.internal.config.package$;
0050 
0051 import static org.hamcrest.Matchers.greaterThan;
0052 import static org.junit.Assert.assertEquals;
0053 import static org.junit.Assert.assertFalse;
0054 import static org.mockito.Answers.RETURNS_SMART_NULLS;
0055 import static org.mockito.ArgumentMatchers.any;
0056 import static org.mockito.ArgumentMatchers.anyInt;
0057 import static org.mockito.Mockito.when;
0058 
0059 
0060 public abstract class AbstractBytesToBytesMapSuite {
0061 
0062   private final Random rand = new Random(42);
0063 
0064   private TestMemoryManager memoryManager;
0065   private TaskMemoryManager taskMemoryManager;
0066   private SerializerManager serializerManager = new SerializerManager(
0067       new JavaSerializer(new SparkConf()),
0068       new SparkConf().set(package$.MODULE$.SHUFFLE_SPILL_COMPRESS(), false));
0069   private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
0070 
0071   final LinkedList<File> spillFilesCreated = new LinkedList<>();
0072   File tempDir;
0073 
0074   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
0075   @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
0076 
0077   @Before
0078   public void setup() {
0079     memoryManager =
0080       new TestMemoryManager(
0081         new SparkConf()
0082           .set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), useOffHeapMemoryAllocator())
0083           .set(package$.MODULE$.MEMORY_OFFHEAP_SIZE(), 256 * 1024 * 1024L)
0084           .set(package$.MODULE$.SHUFFLE_SPILL_COMPRESS(), false)
0085           .set(package$.MODULE$.SHUFFLE_COMPRESS(), false));
0086     taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
0087 
0088     tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test");
0089     spillFilesCreated.clear();
0090     MockitoAnnotations.initMocks(this);
0091     when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
0092     when(diskBlockManager.createTempLocalBlock()).thenAnswer(invocationOnMock -> {
0093       TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
0094       File file = File.createTempFile("spillFile", ".spill", tempDir);
0095       spillFilesCreated.add(file);
0096       return Tuple2$.MODULE$.apply(blockId, file);
0097     });
0098     when(blockManager.getDiskWriter(
0099       any(BlockId.class),
0100       any(File.class),
0101       any(SerializerInstance.class),
0102       anyInt(),
0103       any(ShuffleWriteMetrics.class))).thenAnswer(invocationOnMock -> {
0104         Object[] args = invocationOnMock.getArguments();
0105 
0106         return new DiskBlockObjectWriter(
0107           (File) args[1],
0108           serializerManager,
0109           (SerializerInstance) args[2],
0110           (Integer) args[3],
0111           false,
0112           (ShuffleWriteMetrics) args[4],
0113           (BlockId) args[0]
0114         );
0115       });
0116   }
0117 
0118   @After
0119   public void tearDown() {
0120     Utils.deleteRecursively(tempDir);
0121     tempDir = null;
0122 
0123     if (taskMemoryManager != null) {
0124       Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
0125       long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask();
0126       taskMemoryManager = null;
0127       Assert.assertEquals(0L, leakedMemory);
0128     }
0129   }
0130 
0131   protected abstract boolean useOffHeapMemoryAllocator();
0132 
0133   private static byte[] getByteArray(Object base, long offset, int size) {
0134     final byte[] arr = new byte[size];
0135     Platform.copyMemory(base, offset, arr, Platform.BYTE_ARRAY_OFFSET, size);
0136     return arr;
0137   }
0138 
0139   private byte[] getRandomByteArray(int numWords) {
0140     Assert.assertTrue(numWords >= 0);
0141     final int lengthInBytes = numWords * 8;
0142     final byte[] bytes = new byte[lengthInBytes];
0143     rand.nextBytes(bytes);
0144     return bytes;
0145   }
0146 
0147   /**
0148    * Fast equality checking for byte arrays, since these comparisons are a bottleneck
0149    * in our stress tests.
0150    */
0151   private static boolean arrayEquals(
0152       byte[] expected,
0153       Object base,
0154       long offset,
0155       long actualLengthBytes) {
0156     return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals(
0157       expected,
0158       Platform.BYTE_ARRAY_OFFSET,
0159       base,
0160       offset,
0161       expected.length
0162     );
0163   }
0164 
0165   @Test
0166   public void emptyMap() {
0167     BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES);
0168     try {
0169       Assert.assertEquals(0, map.numKeys());
0170       final int keyLengthInWords = 10;
0171       final int keyLengthInBytes = keyLengthInWords * 8;
0172       final byte[] key = getRandomByteArray(keyLengthInWords);
0173       Assert.assertFalse(map.lookup(key, Platform.BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined());
0174       Assert.assertFalse(map.iterator().hasNext());
0175     } finally {
0176       map.free();
0177     }
0178   }
0179 
0180   @Test
0181   public void setAndRetrieveAKey() {
0182     BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES);
0183     final int recordLengthWords = 10;
0184     final int recordLengthBytes = recordLengthWords * 8;
0185     final byte[] keyData = getRandomByteArray(recordLengthWords);
0186     final byte[] valueData = getRandomByteArray(recordLengthWords);
0187     try {
0188       final BytesToBytesMap.Location loc =
0189         map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes);
0190       Assert.assertFalse(loc.isDefined());
0191       Assert.assertTrue(loc.append(
0192         keyData,
0193         Platform.BYTE_ARRAY_OFFSET,
0194         recordLengthBytes,
0195         valueData,
0196         Platform.BYTE_ARRAY_OFFSET,
0197         recordLengthBytes
0198       ));
0199       // After storing the key and value, the other location methods should return results that
0200       // reflect the result of this store without us having to call lookup() again on the same key.
0201       Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
0202       Assert.assertEquals(recordLengthBytes, loc.getValueLength());
0203       Assert.assertArrayEquals(keyData,
0204         getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes));
0205       Assert.assertArrayEquals(valueData,
0206         getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
0207 
0208       // After calling lookup() the location should still point to the correct data.
0209       Assert.assertTrue(
0210         map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined());
0211       Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
0212       Assert.assertEquals(recordLengthBytes, loc.getValueLength());
0213       Assert.assertArrayEquals(keyData,
0214         getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes));
0215       Assert.assertArrayEquals(valueData,
0216         getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
0217 
0218       try {
0219         Assert.assertTrue(loc.append(
0220           keyData,
0221           Platform.BYTE_ARRAY_OFFSET,
0222           recordLengthBytes,
0223           valueData,
0224           Platform.BYTE_ARRAY_OFFSET,
0225           recordLengthBytes
0226         ));
0227         Assert.fail("Should not be able to set a new value for a key");
0228       } catch (AssertionError e) {
0229         // Expected exception; do nothing.
0230       }
0231     } finally {
0232       map.free();
0233     }
0234   }
0235 
0236   private void iteratorTestBase(boolean destructive) throws Exception {
0237     final int size = 4096;
0238     BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size / 2, PAGE_SIZE_BYTES);
0239     try {
0240       for (long i = 0; i < size; i++) {
0241         final long[] value = new long[] { i };
0242         final BytesToBytesMap.Location loc =
0243           map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8);
0244         Assert.assertFalse(loc.isDefined());
0245         // Ensure that we store some zero-length keys
0246         if (i % 5 == 0) {
0247           Assert.assertTrue(loc.append(
0248             null,
0249             Platform.LONG_ARRAY_OFFSET,
0250             0,
0251             value,
0252             Platform.LONG_ARRAY_OFFSET,
0253             8
0254           ));
0255         } else {
0256           Assert.assertTrue(loc.append(
0257             value,
0258             Platform.LONG_ARRAY_OFFSET,
0259             8,
0260             value,
0261             Platform.LONG_ARRAY_OFFSET,
0262             8
0263           ));
0264         }
0265       }
0266       final java.util.BitSet valuesSeen = new java.util.BitSet(size);
0267       final Iterator<BytesToBytesMap.Location> iter;
0268       if (destructive) {
0269         iter = map.destructiveIterator();
0270       } else {
0271         iter = map.iterator();
0272       }
0273       int numPages = map.getNumDataPages();
0274       int countFreedPages = 0;
0275       while (iter.hasNext()) {
0276         final BytesToBytesMap.Location loc = iter.next();
0277         Assert.assertTrue(loc.isDefined());
0278         final long value = Platform.getLong(loc.getValueBase(), loc.getValueOffset());
0279         final long keyLength = loc.getKeyLength();
0280         if (keyLength == 0) {
0281           Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0);
0282         } else {
0283           final long key = Platform.getLong(loc.getKeyBase(), loc.getKeyOffset());
0284           Assert.assertEquals(value, key);
0285         }
0286         valuesSeen.set((int) value);
0287         if (destructive) {
0288           // The iterator moves onto next page and frees previous page
0289           if (map.getNumDataPages() < numPages) {
0290             numPages = map.getNumDataPages();
0291             countFreedPages++;
0292           }
0293         }
0294       }
0295       if (destructive) {
0296         // Latest page is not freed by iterator but by map itself
0297         Assert.assertEquals(countFreedPages, numPages - 1);
0298       }
0299       Assert.assertEquals(size, valuesSeen.cardinality());
0300     } finally {
0301       map.free();
0302     }
0303   }
0304 
0305   @Test
0306   public void iteratorTest() throws Exception {
0307     iteratorTestBase(false);
0308   }
0309 
0310   @Test
0311   public void destructiveIteratorTest() throws Exception {
0312     iteratorTestBase(true);
0313   }
0314 
0315   @Test
0316   public void iteratingOverDataPagesWithWastedSpace() throws Exception {
0317     final int NUM_ENTRIES = 1000 * 1000;
0318     final int KEY_LENGTH = 24;
0319     final int VALUE_LENGTH = 40;
0320     final BytesToBytesMap map =
0321       new BytesToBytesMap(taskMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
0322     // Each record will take 8 + 24 + 40 = 72 bytes of space in the data page. Our 64-megabyte
0323     // pages won't be evenly-divisible by records of this size, which will cause us to waste some
0324     // space at the end of the page. This is necessary in order for us to take the end-of-record
0325     // handling branch in iterator().
0326     try {
0327       for (int i = 0; i < NUM_ENTRIES; i++) {
0328         final long[] key = new long[] { i, i, i };  // 3 * 8 = 24 bytes
0329         final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes
0330         final BytesToBytesMap.Location loc = map.lookup(
0331           key,
0332           Platform.LONG_ARRAY_OFFSET,
0333           KEY_LENGTH
0334         );
0335         Assert.assertFalse(loc.isDefined());
0336         Assert.assertTrue(loc.append(
0337           key,
0338           Platform.LONG_ARRAY_OFFSET,
0339           KEY_LENGTH,
0340           value,
0341           Platform.LONG_ARRAY_OFFSET,
0342           VALUE_LENGTH
0343         ));
0344       }
0345       Assert.assertEquals(2, map.getNumDataPages());
0346 
0347       final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES);
0348       final Iterator<BytesToBytesMap.Location> iter = map.iterator();
0349       final long[] key = new long[KEY_LENGTH / 8];
0350       final long[] value = new long[VALUE_LENGTH / 8];
0351       while (iter.hasNext()) {
0352         final BytesToBytesMap.Location loc = iter.next();
0353         Assert.assertTrue(loc.isDefined());
0354         Assert.assertEquals(KEY_LENGTH, loc.getKeyLength());
0355         Assert.assertEquals(VALUE_LENGTH, loc.getValueLength());
0356         Platform.copyMemory(
0357           loc.getKeyBase(),
0358           loc.getKeyOffset(),
0359           key,
0360           Platform.LONG_ARRAY_OFFSET,
0361           KEY_LENGTH
0362         );
0363         Platform.copyMemory(
0364           loc.getValueBase(),
0365           loc.getValueOffset(),
0366           value,
0367           Platform.LONG_ARRAY_OFFSET,
0368           VALUE_LENGTH
0369         );
0370         for (long j : key) {
0371           Assert.assertEquals(key[0], j);
0372         }
0373         for (long j : value) {
0374           Assert.assertEquals(key[0], j);
0375         }
0376         valuesSeen.set((int) key[0]);
0377       }
0378       Assert.assertEquals(NUM_ENTRIES, valuesSeen.cardinality());
0379     } finally {
0380       map.free();
0381     }
0382   }
0383 
0384   @Test
0385   public void randomizedStressTest() {
0386     final int size = 32768;
0387     // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
0388     // into ByteBuffers in order to use them as keys here.
0389     final Map<ByteBuffer, byte[]> expected = new HashMap<>();
0390     final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size, PAGE_SIZE_BYTES);
0391     try {
0392       // Fill the map to 90% full so that we can trigger probing
0393       for (int i = 0; i < size * 0.9; i++) {
0394         final byte[] key = getRandomByteArray(rand.nextInt(256) + 1);
0395         final byte[] value = getRandomByteArray(rand.nextInt(256) + 1);
0396         if (!expected.containsKey(ByteBuffer.wrap(key))) {
0397           expected.put(ByteBuffer.wrap(key), value);
0398           final BytesToBytesMap.Location loc = map.lookup(
0399             key,
0400             Platform.BYTE_ARRAY_OFFSET,
0401             key.length
0402           );
0403           Assert.assertFalse(loc.isDefined());
0404           Assert.assertTrue(loc.append(
0405             key,
0406             Platform.BYTE_ARRAY_OFFSET,
0407             key.length,
0408             value,
0409             Platform.BYTE_ARRAY_OFFSET,
0410             value.length
0411           ));
0412           // After calling putNewKey, the following should be true, even before calling
0413           // lookup():
0414           Assert.assertTrue(loc.isDefined());
0415           Assert.assertEquals(key.length, loc.getKeyLength());
0416           Assert.assertEquals(value.length, loc.getValueLength());
0417           Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length));
0418           Assert.assertTrue(
0419             arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length));
0420         }
0421       }
0422 
0423       for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
0424         final byte[] key = JavaUtils.bufferToArray(entry.getKey());
0425         final byte[] value = entry.getValue();
0426         final BytesToBytesMap.Location loc =
0427           map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
0428         Assert.assertTrue(loc.isDefined());
0429         Assert.assertTrue(
0430           arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()));
0431         Assert.assertTrue(
0432           arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength()));
0433       }
0434     } finally {
0435       map.free();
0436     }
0437   }
0438 
0439   @Test
0440   public void randomizedTestWithRecordsLargerThanPageSize() {
0441     final long pageSizeBytes = 128;
0442     final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, pageSizeBytes);
0443     // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
0444     // into ByteBuffers in order to use them as keys here.
0445     final Map<ByteBuffer, byte[]> expected = new HashMap<>();
0446     try {
0447       for (int i = 0; i < 1000; i++) {
0448         final byte[] key = getRandomByteArray(rand.nextInt(128));
0449         final byte[] value = getRandomByteArray(rand.nextInt(128));
0450         if (!expected.containsKey(ByteBuffer.wrap(key))) {
0451           expected.put(ByteBuffer.wrap(key), value);
0452           final BytesToBytesMap.Location loc = map.lookup(
0453             key,
0454             Platform.BYTE_ARRAY_OFFSET,
0455             key.length
0456           );
0457           Assert.assertFalse(loc.isDefined());
0458           Assert.assertTrue(loc.append(
0459             key,
0460             Platform.BYTE_ARRAY_OFFSET,
0461             key.length,
0462             value,
0463             Platform.BYTE_ARRAY_OFFSET,
0464             value.length
0465           ));
0466           // After calling putNewKey, the following should be true, even before calling
0467           // lookup():
0468           Assert.assertTrue(loc.isDefined());
0469           Assert.assertEquals(key.length, loc.getKeyLength());
0470           Assert.assertEquals(value.length, loc.getValueLength());
0471           Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length));
0472           Assert.assertTrue(
0473             arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length));
0474         }
0475       }
0476       for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
0477         final byte[] key = JavaUtils.bufferToArray(entry.getKey());
0478         final byte[] value = entry.getValue();
0479         final BytesToBytesMap.Location loc =
0480           map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
0481         Assert.assertTrue(loc.isDefined());
0482         Assert.assertTrue(
0483           arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()));
0484         Assert.assertTrue(
0485           arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength()));
0486       }
0487     } finally {
0488       map.free();
0489     }
0490   }
0491 
0492   @Test
0493   public void failureToAllocateFirstPage() {
0494     memoryManager.limit(1024);  // longArray
0495     BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
0496     try {
0497       final long[] emptyArray = new long[0];
0498       final BytesToBytesMap.Location loc =
0499         map.lookup(emptyArray, Platform.LONG_ARRAY_OFFSET, 0);
0500       Assert.assertFalse(loc.isDefined());
0501       Assert.assertFalse(loc.append(
0502         emptyArray, Platform.LONG_ARRAY_OFFSET, 0, emptyArray, Platform.LONG_ARRAY_OFFSET, 0));
0503     } finally {
0504       map.free();
0505     }
0506   }
0507 
0508 
0509   @Test
0510   public void failureToGrow() {
0511     BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, 1024);
0512     try {
0513       boolean success = true;
0514       int i;
0515       for (i = 0; i < 127; i++) {
0516         if (i > 0) {
0517           memoryManager.limit(0);
0518         }
0519         final long[] arr = new long[]{i};
0520         final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
0521         success =
0522           loc.append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
0523         if (!success) {
0524           break;
0525         }
0526       }
0527       Assert.assertThat(i, greaterThan(0));
0528       Assert.assertFalse(success);
0529     } finally {
0530       map.free();
0531     }
0532   }
0533 
0534   @Test
0535   public void spillInIterator() throws IOException {
0536     BytesToBytesMap map = new BytesToBytesMap(
0537       taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024);
0538     try {
0539       int i;
0540       for (i = 0; i < 1024; i++) {
0541         final long[] arr = new long[]{i};
0542         final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
0543         loc.append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
0544       }
0545       BytesToBytesMap.MapIterator iter = map.iterator();
0546       for (i = 0; i < 100; i++) {
0547         iter.next();
0548       }
0549       // Non-destructive iterator is not spillable
0550       Assert.assertEquals(0, iter.spill(1024L * 10));
0551       for (i = 100; i < 1024; i++) {
0552         iter.next();
0553       }
0554 
0555       BytesToBytesMap.MapIterator iter2 = map.destructiveIterator();
0556       for (i = 0; i < 100; i++) {
0557         iter2.next();
0558       }
0559       Assert.assertTrue(iter2.spill(1024) >= 1024);
0560       for (i = 100; i < 1024; i++) {
0561         iter2.next();
0562       }
0563       assertFalse(iter2.hasNext());
0564     } finally {
0565       map.free();
0566       for (File spillFile : spillFilesCreated) {
0567         assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
0568           spillFile.exists());
0569       }
0570     }
0571   }
0572 
0573   @Test
0574   public void multipleValuesForSameKey() {
0575     BytesToBytesMap map =
0576       new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.5, 1024);
0577     try {
0578       int i;
0579       for (i = 0; i < 1024; i++) {
0580         final long[] arr = new long[]{i};
0581         map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8)
0582           .append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
0583       }
0584       assert map.numKeys() == 1024;
0585       assert map.numValues() == 1024;
0586       for (i = 0; i < 1024; i++) {
0587         final long[] arr = new long[]{i};
0588         map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8)
0589           .append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
0590       }
0591       assert map.numKeys() == 1024;
0592       assert map.numValues() == 2048;
0593       for (i = 0; i < 1024; i++) {
0594         final long[] arr = new long[]{i};
0595         final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
0596         assert loc.isDefined();
0597         assert loc.nextValue();
0598         assert !loc.nextValue();
0599       }
0600       BytesToBytesMap.MapIterator iter = map.iterator();
0601       for (i = 0; i < 2048; i++) {
0602         assert iter.hasNext();
0603         final BytesToBytesMap.Location loc = iter.next();
0604         assert loc.isDefined();
0605       }
0606     } finally {
0607       map.free();
0608     }
0609   }
0610 
0611   @Test
0612   public void initialCapacityBoundsChecking() {
0613     try {
0614       new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES);
0615       Assert.fail("Expected IllegalArgumentException to be thrown");
0616     } catch (IllegalArgumentException e) {
0617       // expected exception
0618     }
0619 
0620     try {
0621       new BytesToBytesMap(
0622         taskMemoryManager,
0623         BytesToBytesMap.MAX_CAPACITY + 1,
0624         PAGE_SIZE_BYTES);
0625       Assert.fail("Expected IllegalArgumentException to be thrown");
0626     } catch (IllegalArgumentException e) {
0627       // expected exception
0628     }
0629 
0630     try {
0631       new BytesToBytesMap(
0632         taskMemoryManager,
0633         1,
0634         TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES + 1);
0635       Assert.fail("Expected IllegalArgumentException to be thrown");
0636     } catch (IllegalArgumentException e) {
0637       // expected exception
0638     }
0639 
0640   }
0641 
0642   @Test
0643   public void testPeakMemoryUsed() {
0644     final long recordLengthBytes = 32;
0645     final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker
0646     final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
0647     final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1024, pageSizeBytes);
0648 
0649     // Since BytesToBytesMap is append-only, we expect the total memory consumption to be
0650     // monotonically increasing. More specifically, every time we allocate a new page it
0651     // should increase by exactly the size of the page. In this regard, the memory usage
0652     // at any given time is also the peak memory used.
0653     long previousPeakMemory = map.getPeakMemoryUsedBytes();
0654     long newPeakMemory;
0655     try {
0656       for (long i = 0; i < numRecordsPerPage * 10; i++) {
0657         final long[] value = new long[]{i};
0658         map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).append(
0659           value,
0660           Platform.LONG_ARRAY_OFFSET,
0661           8,
0662           value,
0663           Platform.LONG_ARRAY_OFFSET,
0664           8);
0665         newPeakMemory = map.getPeakMemoryUsedBytes();
0666         if (i % numRecordsPerPage == 0) {
0667           // We allocated a new page for this record, so peak memory should change
0668           assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
0669         } else {
0670           assertEquals(previousPeakMemory, newPeakMemory);
0671         }
0672         previousPeakMemory = newPeakMemory;
0673       }
0674 
0675       // Freeing the map should not change the peak memory
0676       map.free();
0677       newPeakMemory = map.getPeakMemoryUsedBytes();
0678       assertEquals(previousPeakMemory, newPeakMemory);
0679 
0680     } finally {
0681       map.free();
0682     }
0683   }
0684 
0685   @Test
0686   public void avoidDeadlock() throws InterruptedException {
0687     memoryManager.limit(PAGE_SIZE_BYTES);
0688     MemoryMode mode = useOffHeapMemoryAllocator() ? MemoryMode.OFF_HEAP: MemoryMode.ON_HEAP;
0689     TestMemoryConsumer c1 = new TestMemoryConsumer(taskMemoryManager, mode);
0690     BytesToBytesMap map =
0691       new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.5, 1024);
0692 
0693     Thread thread = new Thread(() -> {
0694       int i = 0;
0695       while (i < 10) {
0696         c1.use(10000000);
0697         i++;
0698       }
0699       c1.free(c1.getUsed());
0700     });
0701 
0702     try {
0703       int i;
0704       for (i = 0; i < 1024; i++) {
0705         final long[] arr = new long[]{i};
0706         final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
0707         loc.append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
0708       }
0709 
0710       // Starts to require memory at another memory consumer.
0711       thread.start();
0712 
0713       BytesToBytesMap.MapIterator iter = map.destructiveIterator();
0714       for (i = 0; i < 1024; i++) {
0715         iter.next();
0716       }
0717       assertFalse(iter.hasNext());
0718     } finally {
0719       map.free();
0720       thread.join();
0721       for (File spillFile : spillFilesCreated) {
0722         assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
0723           spillFile.exists());
0724       }
0725     }
0726   }
0727 
0728   @Test
0729   public void freeAfterFailedReset() {
0730     // SPARK-29244: BytesToBytesMap.free after a OOM reset operation should not cause failure.
0731     memoryManager.limit(5000);
0732     BytesToBytesMap map =
0733       new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 256, 0.5, 4000);
0734     // Force OOM on next memory allocation.
0735     memoryManager.markExecutionAsOutOfMemoryOnce();
0736     try {
0737       map.reset();
0738       Assert.fail("Expected SparkOutOfMemoryError to be thrown");
0739     } catch (SparkOutOfMemoryError e) {
0740       // Expected exception; do nothing.
0741     } finally {
0742       map.free();
0743     }
0744   }
0745 
0746 }