Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.util.collection.unsafe.sort;
0019 
0020 import java.nio.charset.StandardCharsets;
0021 import java.util.Arrays;
0022 
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025 
0026 import org.apache.spark.HashPartitioner;
0027 import org.apache.spark.SparkConf;
0028 import org.apache.spark.memory.TestMemoryConsumer;
0029 import org.apache.spark.memory.TestMemoryManager;
0030 import org.apache.spark.memory.SparkOutOfMemoryError;
0031 import org.apache.spark.memory.TaskMemoryManager;
0032 import org.apache.spark.unsafe.Platform;
0033 import org.apache.spark.unsafe.memory.MemoryBlock;
0034 import org.apache.spark.internal.config.package$;
0035 
0036 import static org.hamcrest.MatcherAssert.assertThat;
0037 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
0038 import static org.hamcrest.Matchers.isIn;
0039 import static org.junit.Assert.assertEquals;
0040 import static org.junit.Assert.fail;
0041 import static org.mockito.Mockito.mock;
0042 
0043 public class UnsafeInMemorySorterSuite {
0044 
0045   protected boolean shouldUseRadixSort() { return false; }
0046 
0047   private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) {
0048     final byte[] strBytes = new byte[length];
0049     Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, length);
0050     return new String(strBytes, StandardCharsets.UTF_8);
0051   }
0052 
0053   @Test
0054   public void testSortingEmptyInput() {
0055     final TaskMemoryManager memoryManager = new TaskMemoryManager(
0056       new TestMemoryManager(
0057         new SparkConf().set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false)), 0);
0058     final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
0059     final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer,
0060       memoryManager,
0061       mock(RecordComparator.class),
0062       mock(PrefixComparator.class),
0063       100,
0064       shouldUseRadixSort());
0065     final UnsafeSorterIterator iter = sorter.getSortedIterator();
0066     Assert.assertFalse(iter.hasNext());
0067   }
0068 
0069   @Test
0070   public void testSortingOnlyByIntegerPrefix() throws Exception {
0071     final String[] dataToSort = new String[] {
0072       "Boba",
0073       "Pearls",
0074       "Tapioca",
0075       "Taho",
0076       "Condensed Milk",
0077       "Jasmine",
0078       "Milk Tea",
0079       "Lychee",
0080       "Mango"
0081     };
0082     final TaskMemoryManager memoryManager = new TaskMemoryManager(
0083       new TestMemoryManager(
0084         new SparkConf().set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false)), 0);
0085     final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
0086     final MemoryBlock dataPage = memoryManager.allocatePage(2048, consumer);
0087     final Object baseObject = dataPage.getBaseObject();
0088     // Write the records into the data page:
0089     long position = dataPage.getBaseOffset();
0090     for (String str : dataToSort) {
0091       final byte[] strBytes = str.getBytes(StandardCharsets.UTF_8);
0092       Platform.putInt(baseObject, position, strBytes.length);
0093       position += 4;
0094       Platform.copyMemory(
0095         strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length);
0096       position += strBytes.length;
0097     }
0098     // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
0099     // use a dummy comparator
0100     final RecordComparator recordComparator = new RecordComparator() {
0101       @Override
0102       public int compare(
0103         Object leftBaseObject,
0104         long leftBaseOffset,
0105         int leftBaseLength,
0106         Object rightBaseObject,
0107         long rightBaseOffset,
0108         int rightBaseLength) {
0109         return 0;
0110       }
0111     };
0112     // Compute key prefixes based on the records' partition ids
0113     final HashPartitioner hashPartitioner = new HashPartitioner(4);
0114     // Use integer comparison for comparing prefixes (which are partition ids, in this case)
0115     final PrefixComparator prefixComparator = PrefixComparators.LONG;
0116     UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager,
0117       recordComparator, prefixComparator, dataToSort.length, shouldUseRadixSort());
0118     // Given a page of records, insert those records into the sorter one-by-one:
0119     position = dataPage.getBaseOffset();
0120     for (int i = 0; i < dataToSort.length; i++) {
0121       if (!sorter.hasSpaceForAnotherRecord()) {
0122         sorter.expandPointerArray(
0123           consumer.allocateArray(sorter.getMemoryUsage() / 8 * 2));
0124       }
0125       // position now points to the start of a record (which holds its length).
0126       final int recordLength = Platform.getInt(baseObject, position);
0127       final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
0128       final String str = getStringFromDataPage(baseObject, position + 4, recordLength);
0129       final int partitionId = hashPartitioner.getPartition(str);
0130       sorter.insertRecord(address, partitionId, false);
0131       position += 4 + recordLength;
0132     }
0133     final UnsafeSorterIterator iter = sorter.getSortedIterator();
0134     int iterLength = 0;
0135     long prevPrefix = -1;
0136     while (iter.hasNext()) {
0137       iter.loadNext();
0138       final String str =
0139         getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset(), iter.getRecordLength());
0140       final long keyPrefix = iter.getKeyPrefix();
0141       assertThat(str, isIn(Arrays.asList(dataToSort)));
0142       assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix));
0143       prevPrefix = keyPrefix;
0144       iterLength++;
0145     }
0146     assertEquals(dataToSort.length, iterLength);
0147   }
0148 
0149   @Test
0150   public void freeAfterOOM() {
0151     final SparkConf sparkConf = new SparkConf();
0152     sparkConf.set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false);
0153 
0154     final TestMemoryManager testMemoryManager =
0155             new TestMemoryManager(sparkConf);
0156     final TaskMemoryManager memoryManager = new TaskMemoryManager(
0157             testMemoryManager, 0);
0158     final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
0159     final MemoryBlock dataPage = memoryManager.allocatePage(2048, consumer);
0160     final Object baseObject = dataPage.getBaseObject();
0161     // Write the records into the data page:
0162     long position = dataPage.getBaseOffset();
0163 
0164     final HashPartitioner hashPartitioner = new HashPartitioner(4);
0165     // Use integer comparison for comparing prefixes (which are partition ids, in this case)
0166     final PrefixComparator prefixComparator = PrefixComparators.LONG;
0167     final RecordComparator recordComparator = new RecordComparator() {
0168       @Override
0169       public int compare(
0170               Object leftBaseObject,
0171               long leftBaseOffset,
0172               int leftBaseLength,
0173               Object rightBaseObject,
0174               long rightBaseOffset,
0175               int rightBaseLength) {
0176         return 0;
0177       }
0178     };
0179     UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager,
0180             recordComparator, prefixComparator, 100, shouldUseRadixSort());
0181 
0182     testMemoryManager.markExecutionAsOutOfMemoryOnce();
0183     try {
0184       sorter.reset();
0185       fail("expected SparkOutOfMemoryError but it seems operation surprisingly succeeded");
0186     } catch (SparkOutOfMemoryError oom) {
0187       // as expected
0188     }
0189     // [SPARK-21907] this failed on NPE at
0190     // org.apache.spark.memory.MemoryConsumer.freeArray(MemoryConsumer.java:108)
0191     sorter.free();
0192     // simulate a 'back to back' free.
0193     sorter.free();
0194   }
0195 
0196 }