Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.shuffle.sort;
0019 
0020 import java.nio.charset.StandardCharsets;
0021 import java.util.Arrays;
0022 import java.util.Random;
0023 
0024 import org.junit.Assert;
0025 import org.junit.Test;
0026 
0027 import org.apache.spark.HashPartitioner;
0028 import org.apache.spark.SparkConf;
0029 import org.apache.spark.internal.config.package$;
0030 import org.apache.spark.memory.MemoryConsumer;
0031 import org.apache.spark.memory.TaskMemoryManager;
0032 import org.apache.spark.memory.TestMemoryConsumer;
0033 import org.apache.spark.memory.TestMemoryManager;
0034 import org.apache.spark.unsafe.Platform;
0035 import org.apache.spark.unsafe.memory.MemoryBlock;
0036 
0037 public class ShuffleInMemorySorterSuite {
0038 
0039   protected boolean shouldUseRadixSort() { return false; }
0040 
0041   final TestMemoryManager memoryManager =
0042     new TestMemoryManager(new SparkConf().set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false));
0043   final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
0044   final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager);
0045 
0046   private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
0047     final byte[] strBytes = new byte[strLength];
0048     Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength);
0049     return new String(strBytes, StandardCharsets.UTF_8);
0050   }
0051 
0052   @Test
0053   public void testSortingEmptyInput() {
0054     final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(
0055       consumer, 100, shouldUseRadixSort());
0056     final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
0057     Assert.assertFalse(iter.hasNext());
0058   }
0059 
0060   @Test
0061   public void testBasicSorting() throws Exception {
0062     final String[] dataToSort = new String[] {
0063       "Boba",
0064       "Pearls",
0065       "Tapioca",
0066       "Taho",
0067       "Condensed Milk",
0068       "Jasmine",
0069       "Milk Tea",
0070       "Lychee",
0071       "Mango"
0072     };
0073     final SparkConf conf = new SparkConf().set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false);
0074     final TaskMemoryManager memoryManager =
0075       new TaskMemoryManager(new TestMemoryManager(conf), 0);
0076     final MemoryConsumer c = new TestMemoryConsumer(memoryManager);
0077     final MemoryBlock dataPage = memoryManager.allocatePage(2048, c);
0078     final Object baseObject = dataPage.getBaseObject();
0079     final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(
0080       consumer, 4, shouldUseRadixSort());
0081     final HashPartitioner hashPartitioner = new HashPartitioner(4);
0082 
0083     // Write the records into the data page and store pointers into the sorter
0084     long position = dataPage.getBaseOffset();
0085     for (String str : dataToSort) {
0086       if (!sorter.hasSpaceForAnotherRecord()) {
0087         sorter.expandPointerArray(
0088           consumer.allocateArray(sorter.getMemoryUsage() / 8 * 2));
0089       }
0090       final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
0091       final byte[] strBytes = str.getBytes(StandardCharsets.UTF_8);
0092       Platform.putInt(baseObject, position, strBytes.length);
0093       position += 4;
0094       Platform.copyMemory(
0095         strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length);
0096       position += strBytes.length;
0097       sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str));
0098     }
0099 
0100     // Sort the records
0101     final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
0102     int prevPartitionId = -1;
0103     Arrays.sort(dataToSort);
0104     for (int i = 0; i < dataToSort.length; i++) {
0105       Assert.assertTrue(iter.hasNext());
0106       iter.loadNext();
0107       final int partitionId = iter.packedRecordPointer.getPartitionId();
0108       Assert.assertTrue(partitionId >= 0 && partitionId <= 3);
0109       Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId,
0110         partitionId >= prevPartitionId);
0111       final long recordAddress = iter.packedRecordPointer.getRecordPointer();
0112       final int recordLength = Platform.getInt(
0113         memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress));
0114       final String str = getStringFromDataPage(
0115         memoryManager.getPage(recordAddress),
0116         memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length
0117         recordLength);
0118       Assert.assertTrue(Arrays.binarySearch(dataToSort, str) != -1);
0119     }
0120     Assert.assertFalse(iter.hasNext());
0121   }
0122 
0123   @Test
0124   public void testSortingManyNumbers() throws Exception {
0125     ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4, shouldUseRadixSort());
0126     int[] numbersToSort = new int[128000];
0127     Random random = new Random(16);
0128     for (int i = 0; i < numbersToSort.length; i++) {
0129       if (!sorter.hasSpaceForAnotherRecord()) {
0130         sorter.expandPointerArray(consumer.allocateArray(sorter.getMemoryUsage() / 8 * 2));
0131       }
0132       numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
0133       sorter.insertRecord(0, numbersToSort[i]);
0134     }
0135     Arrays.sort(numbersToSort);
0136     int[] sorterResult = new int[numbersToSort.length];
0137     ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
0138     int j = 0;
0139     while (iter.hasNext()) {
0140       iter.loadNext();
0141       sorterResult[j] = iter.packedRecordPointer.getPartitionId();
0142       j += 1;
0143     }
0144     Assert.assertArrayEquals(numbersToSort, sorterResult);
0145   }
0146 }