Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package test.org.apache.spark.sql.execution.sort;
0019 
0020 import org.apache.spark.SparkConf;
0021 import org.apache.spark.internal.config.package$;
0022 import org.apache.spark.memory.TaskMemoryManager;
0023 import org.apache.spark.memory.TestMemoryConsumer;
0024 import org.apache.spark.memory.TestMemoryManager;
0025 import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
0026 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
0027 import org.apache.spark.sql.execution.RecordBinaryComparator;
0028 import org.apache.spark.unsafe.Platform;
0029 import org.apache.spark.unsafe.UnsafeAlignedOffset;
0030 import org.apache.spark.unsafe.array.LongArray;
0031 import org.apache.spark.unsafe.memory.MemoryBlock;
0032 import org.apache.spark.unsafe.types.UTF8String;
0033 import org.apache.spark.util.collection.unsafe.sort.*;
0034 
0035 import org.junit.After;
0036 import org.junit.Assert;
0037 import org.junit.Before;
0038 import org.junit.Test;
0039 
0040 /**
0041  * Test the RecordBinaryComparator, which compares two UnsafeRows by their binary form.
0042  */
0043 public class RecordBinaryComparatorSuite {
0044 
0045   private final TaskMemoryManager memoryManager = new TaskMemoryManager(
0046       new TestMemoryManager(
0047         new SparkConf().set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false)), 0);
0048   private final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
0049 
0050   private final int uaoSize = UnsafeAlignedOffset.getUaoSize();
0051 
0052   private MemoryBlock dataPage;
0053   private long pageCursor;
0054 
0055   private LongArray array;
0056   private int pos;
0057 
0058   @Before
0059   public void beforeEach() {
0060     // Only compare between two input rows.
0061     array = consumer.allocateArray(2);
0062     pos = 0;
0063 
0064     dataPage = memoryManager.allocatePage(4096, consumer);
0065     pageCursor = dataPage.getBaseOffset();
0066   }
0067 
0068   @After
0069   public void afterEach() {
0070     consumer.freePage(dataPage);
0071     dataPage = null;
0072     pageCursor = 0;
0073 
0074     consumer.freeArray(array);
0075     array = null;
0076     pos = 0;
0077   }
0078 
0079   private void insertRow(UnsafeRow row) {
0080     Object recordBase = row.getBaseObject();
0081     long recordOffset = row.getBaseOffset();
0082     int recordLength = row.getSizeInBytes();
0083 
0084     Object baseObject = dataPage.getBaseObject();
0085     Assert.assertTrue(pageCursor + recordLength <= dataPage.getBaseOffset() + dataPage.size());
0086     long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, pageCursor);
0087     UnsafeAlignedOffset.putSize(baseObject, pageCursor, recordLength);
0088     pageCursor += uaoSize;
0089     Platform.copyMemory(recordBase, recordOffset, baseObject, pageCursor, recordLength);
0090     pageCursor += recordLength;
0091 
0092     Assert.assertTrue(pos < 2);
0093     array.set(pos, recordAddress);
0094     pos++;
0095   }
0096 
0097   private int compare(int index1, int index2) {
0098     Object baseObject = dataPage.getBaseObject();
0099 
0100     long recordAddress1 = array.get(index1);
0101     long baseOffset1 = memoryManager.getOffsetInPage(recordAddress1) + uaoSize;
0102     int recordLength1 = UnsafeAlignedOffset.getSize(baseObject, baseOffset1 - uaoSize);
0103 
0104     long recordAddress2 = array.get(index2);
0105     long baseOffset2 = memoryManager.getOffsetInPage(recordAddress2) + uaoSize;
0106     int recordLength2 = UnsafeAlignedOffset.getSize(baseObject, baseOffset2 - uaoSize);
0107 
0108     return binaryComparator.compare(baseObject, baseOffset1, recordLength1, baseObject,
0109         baseOffset2, recordLength2);
0110   }
0111 
0112   private final RecordComparator binaryComparator = new RecordBinaryComparator();
0113 
0114   // Compute the most compact size for UnsafeRow's backing data.
0115   private int computeSizeInBytes(int originalSize) {
0116     // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall
0117     // always be 8.
0118     return 8 + (originalSize + 7) / 8 * 8;
0119   }
0120 
0121   // Compute the relative offset of variable-length values.
0122   private long relativeOffset(int numFields) {
0123     // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall
0124     // always be 8.
0125     return 8 + numFields * 8L;
0126   }
0127 
0128   @Test
0129   public void testBinaryComparatorForSingleColumnRow() throws Exception {
0130     int numFields = 1;
0131 
0132     UnsafeRow row1 = new UnsafeRow(numFields);
0133     byte[] data1 = new byte[100];
0134     row1.pointTo(data1, computeSizeInBytes(numFields * 8));
0135     row1.setInt(0, 11);
0136 
0137     UnsafeRow row2 = new UnsafeRow(numFields);
0138     byte[] data2 = new byte[100];
0139     row2.pointTo(data2, computeSizeInBytes(numFields * 8));
0140     row2.setInt(0, 42);
0141 
0142     insertRow(row1);
0143     insertRow(row2);
0144 
0145     Assert.assertEquals(0, compare(0, 0));
0146     Assert.assertTrue(compare(0, 1) < 0);
0147   }
0148 
0149   @Test
0150   public void testBinaryComparatorForMultipleColumnRow() throws Exception {
0151     int numFields = 5;
0152 
0153     UnsafeRow row1 = new UnsafeRow(numFields);
0154     byte[] data1 = new byte[100];
0155     row1.pointTo(data1, computeSizeInBytes(numFields * 8));
0156     for (int i = 0; i < numFields; i++) {
0157       row1.setDouble(i, i * 3.14);
0158     }
0159 
0160     UnsafeRow row2 = new UnsafeRow(numFields);
0161     byte[] data2 = new byte[100];
0162     row2.pointTo(data2, computeSizeInBytes(numFields * 8));
0163     for (int i = 0; i < numFields; i++) {
0164       row2.setDouble(i, 198.7 / (i + 1));
0165     }
0166 
0167     insertRow(row1);
0168     insertRow(row2);
0169 
0170     Assert.assertEquals(0, compare(0, 0));
0171     Assert.assertTrue(compare(0, 1) < 0);
0172   }
0173 
0174   @Test
0175   public void testBinaryComparatorForArrayColumn() throws Exception {
0176     int numFields = 1;
0177 
0178     UnsafeRow row1 = new UnsafeRow(numFields);
0179     byte[] data1 = new byte[100];
0180     UnsafeArrayData arrayData1 = UnsafeArrayData.fromPrimitiveArray(new int[]{11, 42, -1});
0181     row1.pointTo(data1, computeSizeInBytes(numFields * 8 + arrayData1.getSizeInBytes()));
0182     row1.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData1.getSizeInBytes());
0183     Platform.copyMemory(arrayData1.getBaseObject(), arrayData1.getBaseOffset(), data1,
0184         row1.getBaseOffset() + relativeOffset(numFields), arrayData1.getSizeInBytes());
0185 
0186     UnsafeRow row2 = new UnsafeRow(numFields);
0187     byte[] data2 = new byte[100];
0188     UnsafeArrayData arrayData2 = UnsafeArrayData.fromPrimitiveArray(new int[]{22});
0189     row2.pointTo(data2, computeSizeInBytes(numFields * 8 + arrayData2.getSizeInBytes()));
0190     row2.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData2.getSizeInBytes());
0191     Platform.copyMemory(arrayData2.getBaseObject(), arrayData2.getBaseOffset(), data2,
0192         row2.getBaseOffset() + relativeOffset(numFields), arrayData2.getSizeInBytes());
0193 
0194     insertRow(row1);
0195     insertRow(row2);
0196 
0197     Assert.assertEquals(0, compare(0, 0));
0198     Assert.assertTrue(compare(0, 1) > 0);
0199   }
0200 
0201   @Test
0202   public void testBinaryComparatorForMixedColumns() throws Exception {
0203     int numFields = 4;
0204 
0205     UnsafeRow row1 = new UnsafeRow(numFields);
0206     byte[] data1 = new byte[100];
0207     UTF8String str1 = UTF8String.fromString("Milk tea");
0208     row1.pointTo(data1, computeSizeInBytes(numFields * 8 + str1.numBytes()));
0209     row1.setInt(0, 11);
0210     row1.setDouble(1, 3.14);
0211     row1.setInt(2, -1);
0212     row1.setLong(3, (relativeOffset(numFields) << 32) | (long) str1.numBytes());
0213     Platform.copyMemory(str1.getBaseObject(), str1.getBaseOffset(), data1,
0214         row1.getBaseOffset() + relativeOffset(numFields), str1.numBytes());
0215 
0216     UnsafeRow row2 = new UnsafeRow(numFields);
0217     byte[] data2 = new byte[100];
0218     UTF8String str2 = UTF8String.fromString("Java");
0219     row2.pointTo(data2, computeSizeInBytes(numFields * 8 + str2.numBytes()));
0220     row2.setInt(0, 11);
0221     row2.setDouble(1, 3.14);
0222     row2.setInt(2, -1);
0223     row2.setLong(3, (relativeOffset(numFields) << 32) | (long) str2.numBytes());
0224     Platform.copyMemory(str2.getBaseObject(), str2.getBaseOffset(), data2,
0225         row2.getBaseOffset() + relativeOffset(numFields), str2.numBytes());
0226 
0227     insertRow(row1);
0228     insertRow(row2);
0229 
0230     Assert.assertEquals(0, compare(0, 0));
0231     Assert.assertTrue(compare(0, 1) > 0);
0232   }
0233 
0234   @Test
0235   public void testBinaryComparatorForNullColumns() throws Exception {
0236     int numFields = 3;
0237 
0238     UnsafeRow row1 = new UnsafeRow(numFields);
0239     byte[] data1 = new byte[100];
0240     row1.pointTo(data1, computeSizeInBytes(numFields * 8));
0241     for (int i = 0; i < numFields; i++) {
0242       row1.setNullAt(i);
0243     }
0244 
0245     UnsafeRow row2 = new UnsafeRow(numFields);
0246     byte[] data2 = new byte[100];
0247     row2.pointTo(data2, computeSizeInBytes(numFields * 8));
0248     for (int i = 0; i < numFields - 1; i++) {
0249       row2.setNullAt(i);
0250     }
0251     row2.setDouble(numFields - 1, 3.14);
0252 
0253     insertRow(row1);
0254     insertRow(row2);
0255 
0256     Assert.assertEquals(0, compare(0, 0));
0257     Assert.assertTrue(compare(0, 1) > 0);
0258   }
0259 
0260   @Test
0261   public void testBinaryComparatorWhenSubtractionIsDivisibleByMaxIntValue() throws Exception {
0262     int numFields = 1;
0263 
0264     UnsafeRow row1 = new UnsafeRow(numFields);
0265     byte[] data1 = new byte[100];
0266     row1.pointTo(data1, computeSizeInBytes(numFields * 8));
0267     row1.setLong(0, 11);
0268 
0269     UnsafeRow row2 = new UnsafeRow(numFields);
0270     byte[] data2 = new byte[100];
0271     row2.pointTo(data2, computeSizeInBytes(numFields * 8));
0272     row2.setLong(0, 11L + Integer.MAX_VALUE);
0273 
0274     insertRow(row1);
0275     insertRow(row2);
0276 
0277     Assert.assertTrue(compare(0, 1) > 0);
0278   }
0279 
0280   @Test
0281   public void testBinaryComparatorWhenSubtractionCanOverflowLongValue() throws Exception {
0282     int numFields = 1;
0283 
0284     UnsafeRow row1 = new UnsafeRow(numFields);
0285     byte[] data1 = new byte[100];
0286     row1.pointTo(data1, computeSizeInBytes(numFields * 8));
0287     row1.setLong(0, Long.MIN_VALUE);
0288 
0289     UnsafeRow row2 = new UnsafeRow(numFields);
0290     byte[] data2 = new byte[100];
0291     row2.pointTo(data2, computeSizeInBytes(numFields * 8));
0292     row2.setLong(0, 1);
0293 
0294     insertRow(row1);
0295     insertRow(row2);
0296 
0297     Assert.assertTrue(compare(0, 1) < 0);
0298   }
0299 
0300   @Test
0301   public void testBinaryComparatorWhenOnlyTheLastColumnDiffers() throws Exception {
0302     int numFields = 4;
0303 
0304     UnsafeRow row1 = new UnsafeRow(numFields);
0305     byte[] data1 = new byte[100];
0306     row1.pointTo(data1, computeSizeInBytes(numFields * 8));
0307     row1.setInt(0, 11);
0308     row1.setDouble(1, 3.14);
0309     row1.setInt(2, -1);
0310     row1.setLong(3, 0);
0311 
0312     UnsafeRow row2 = new UnsafeRow(numFields);
0313     byte[] data2 = new byte[100];
0314     row2.pointTo(data2, computeSizeInBytes(numFields * 8));
0315     row2.setInt(0, 11);
0316     row2.setDouble(1, 3.14);
0317     row2.setInt(2, -1);
0318     row2.setLong(3, 1);
0319 
0320     insertRow(row1);
0321     insertRow(row2);
0322 
0323     Assert.assertTrue(compare(0, 1) < 0);
0324   }
0325 
0326   @Test
0327   public void testCompareLongsAsLittleEndian() {
0328     long arrayOffset = Platform.LONG_ARRAY_OFFSET + 4;
0329 
0330     long[] arr1 = new long[2];
0331     Platform.putLong(arr1, arrayOffset, 0x0100000000000000L);
0332     long[] arr2 = new long[2];
0333     Platform.putLong(arr2, arrayOffset + 4, 0x0000000000000001L);
0334     // leftBaseOffset is not aligned while rightBaseOffset is aligned,
0335     // it will start by comparing long
0336     int result1 = binaryComparator.compare(arr1, arrayOffset, 8, arr2, arrayOffset + 4, 8);
0337 
0338     long[] arr3 = new long[2];
0339     Platform.putLong(arr3, arrayOffset, 0x0100000000000000L);
0340     long[] arr4 = new long[2];
0341     Platform.putLong(arr4, arrayOffset, 0x0000000000000001L);
0342     // both left and right offset is not aligned, it will start with byte-by-byte comparison
0343     int result2 = binaryComparator.compare(arr3, arrayOffset, 8, arr4, arrayOffset, 8);
0344 
0345     Assert.assertEquals(result1, result2);
0346   }
0347 
0348   @Test
0349   public void testCompareLongsAsUnsigned() {
0350     long arrayOffset = Platform.LONG_ARRAY_OFFSET + 4;
0351 
0352     long[] arr1 = new long[2];
0353     Platform.putLong(arr1, arrayOffset + 4, 0xa000000000000000L);
0354     long[] arr2 = new long[2];
0355     Platform.putLong(arr2, arrayOffset + 4, 0x0000000000000000L);
0356     // both leftBaseOffset and rightBaseOffset are aligned, so it will start by comparing long
0357     int result1 = binaryComparator.compare(arr1, arrayOffset + 4, 8, arr2, arrayOffset + 4, 8);
0358 
0359     long[] arr3 = new long[2];
0360     Platform.putLong(arr3, arrayOffset, 0xa000000000000000L);
0361     long[] arr4 = new long[2];
0362     Platform.putLong(arr4, arrayOffset, 0x0000000000000000L);
0363     // both leftBaseOffset and rightBaseOffset are not aligned,
0364     // so it will start with byte-by-byte comparison
0365     int result2 = binaryComparator.compare(arr3, arrayOffset, 8, arr4, arrayOffset, 8);
0366 
0367     Assert.assertEquals(result1, result2);
0368   }
0369 }