0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.shuffle.sort;
0019
0020 import java.nio.charset.StandardCharsets;
0021 import java.util.Arrays;
0022 import java.util.Random;
0023
0024 import org.junit.Assert;
0025 import org.junit.Test;
0026
0027 import org.apache.spark.HashPartitioner;
0028 import org.apache.spark.SparkConf;
0029 import org.apache.spark.internal.config.package$;
0030 import org.apache.spark.memory.MemoryConsumer;
0031 import org.apache.spark.memory.TaskMemoryManager;
0032 import org.apache.spark.memory.TestMemoryConsumer;
0033 import org.apache.spark.memory.TestMemoryManager;
0034 import org.apache.spark.unsafe.Platform;
0035 import org.apache.spark.unsafe.memory.MemoryBlock;
0036
0037 public class ShuffleInMemorySorterSuite {
0038
0039 protected boolean shouldUseRadixSort() { return false; }
0040
0041 final TestMemoryManager memoryManager =
0042 new TestMemoryManager(new SparkConf().set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false));
0043 final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
0044 final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager);
0045
0046 private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
0047 final byte[] strBytes = new byte[strLength];
0048 Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength);
0049 return new String(strBytes, StandardCharsets.UTF_8);
0050 }
0051
0052 @Test
0053 public void testSortingEmptyInput() {
0054 final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(
0055 consumer, 100, shouldUseRadixSort());
0056 final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
0057 Assert.assertFalse(iter.hasNext());
0058 }
0059
0060 @Test
0061 public void testBasicSorting() throws Exception {
0062 final String[] dataToSort = new String[] {
0063 "Boba",
0064 "Pearls",
0065 "Tapioca",
0066 "Taho",
0067 "Condensed Milk",
0068 "Jasmine",
0069 "Milk Tea",
0070 "Lychee",
0071 "Mango"
0072 };
0073 final SparkConf conf = new SparkConf().set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false);
0074 final TaskMemoryManager memoryManager =
0075 new TaskMemoryManager(new TestMemoryManager(conf), 0);
0076 final MemoryConsumer c = new TestMemoryConsumer(memoryManager);
0077 final MemoryBlock dataPage = memoryManager.allocatePage(2048, c);
0078 final Object baseObject = dataPage.getBaseObject();
0079 final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(
0080 consumer, 4, shouldUseRadixSort());
0081 final HashPartitioner hashPartitioner = new HashPartitioner(4);
0082
0083
0084 long position = dataPage.getBaseOffset();
0085 for (String str : dataToSort) {
0086 if (!sorter.hasSpaceForAnotherRecord()) {
0087 sorter.expandPointerArray(
0088 consumer.allocateArray(sorter.getMemoryUsage() / 8 * 2));
0089 }
0090 final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
0091 final byte[] strBytes = str.getBytes(StandardCharsets.UTF_8);
0092 Platform.putInt(baseObject, position, strBytes.length);
0093 position += 4;
0094 Platform.copyMemory(
0095 strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length);
0096 position += strBytes.length;
0097 sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str));
0098 }
0099
0100
0101 final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
0102 int prevPartitionId = -1;
0103 Arrays.sort(dataToSort);
0104 for (int i = 0; i < dataToSort.length; i++) {
0105 Assert.assertTrue(iter.hasNext());
0106 iter.loadNext();
0107 final int partitionId = iter.packedRecordPointer.getPartitionId();
0108 Assert.assertTrue(partitionId >= 0 && partitionId <= 3);
0109 Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId,
0110 partitionId >= prevPartitionId);
0111 final long recordAddress = iter.packedRecordPointer.getRecordPointer();
0112 final int recordLength = Platform.getInt(
0113 memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress));
0114 final String str = getStringFromDataPage(
0115 memoryManager.getPage(recordAddress),
0116 memoryManager.getOffsetInPage(recordAddress) + 4,
0117 recordLength);
0118 Assert.assertTrue(Arrays.binarySearch(dataToSort, str) != -1);
0119 }
0120 Assert.assertFalse(iter.hasNext());
0121 }
0122
0123 @Test
0124 public void testSortingManyNumbers() throws Exception {
0125 ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4, shouldUseRadixSort());
0126 int[] numbersToSort = new int[128000];
0127 Random random = new Random(16);
0128 for (int i = 0; i < numbersToSort.length; i++) {
0129 if (!sorter.hasSpaceForAnotherRecord()) {
0130 sorter.expandPointerArray(consumer.allocateArray(sorter.getMemoryUsage() / 8 * 2));
0131 }
0132 numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
0133 sorter.insertRecord(0, numbersToSort[i]);
0134 }
0135 Arrays.sort(numbersToSort);
0136 int[] sorterResult = new int[numbersToSort.length];
0137 ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
0138 int j = 0;
0139 while (iter.hasNext()) {
0140 iter.loadNext();
0141 sorterResult[j] = iter.packedRecordPointer.getPartitionId();
0142 j += 1;
0143 }
0144 Assert.assertArrayEquals(numbersToSort, sorterResult);
0145 }
0146 }