0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package test.org.apache.spark.sql.execution.sort;
0019
0020 import org.apache.spark.SparkConf;
0021 import org.apache.spark.internal.config.package$;
0022 import org.apache.spark.memory.TaskMemoryManager;
0023 import org.apache.spark.memory.TestMemoryConsumer;
0024 import org.apache.spark.memory.TestMemoryManager;
0025 import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
0026 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
0027 import org.apache.spark.sql.execution.RecordBinaryComparator;
0028 import org.apache.spark.unsafe.Platform;
0029 import org.apache.spark.unsafe.UnsafeAlignedOffset;
0030 import org.apache.spark.unsafe.array.LongArray;
0031 import org.apache.spark.unsafe.memory.MemoryBlock;
0032 import org.apache.spark.unsafe.types.UTF8String;
0033 import org.apache.spark.util.collection.unsafe.sort.*;
0034
0035 import org.junit.After;
0036 import org.junit.Assert;
0037 import org.junit.Before;
0038 import org.junit.Test;
0039
0040
0041
0042
0043 public class RecordBinaryComparatorSuite {
0044
0045 private final TaskMemoryManager memoryManager = new TaskMemoryManager(
0046 new TestMemoryManager(
0047 new SparkConf().set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false)), 0);
0048 private final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
0049
0050 private final int uaoSize = UnsafeAlignedOffset.getUaoSize();
0051
0052 private MemoryBlock dataPage;
0053 private long pageCursor;
0054
0055 private LongArray array;
0056 private int pos;
0057
0058 @Before
0059 public void beforeEach() {
0060
0061 array = consumer.allocateArray(2);
0062 pos = 0;
0063
0064 dataPage = memoryManager.allocatePage(4096, consumer);
0065 pageCursor = dataPage.getBaseOffset();
0066 }
0067
0068 @After
0069 public void afterEach() {
0070 consumer.freePage(dataPage);
0071 dataPage = null;
0072 pageCursor = 0;
0073
0074 consumer.freeArray(array);
0075 array = null;
0076 pos = 0;
0077 }
0078
0079 private void insertRow(UnsafeRow row) {
0080 Object recordBase = row.getBaseObject();
0081 long recordOffset = row.getBaseOffset();
0082 int recordLength = row.getSizeInBytes();
0083
0084 Object baseObject = dataPage.getBaseObject();
0085 Assert.assertTrue(pageCursor + recordLength <= dataPage.getBaseOffset() + dataPage.size());
0086 long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, pageCursor);
0087 UnsafeAlignedOffset.putSize(baseObject, pageCursor, recordLength);
0088 pageCursor += uaoSize;
0089 Platform.copyMemory(recordBase, recordOffset, baseObject, pageCursor, recordLength);
0090 pageCursor += recordLength;
0091
0092 Assert.assertTrue(pos < 2);
0093 array.set(pos, recordAddress);
0094 pos++;
0095 }
0096
0097 private int compare(int index1, int index2) {
0098 Object baseObject = dataPage.getBaseObject();
0099
0100 long recordAddress1 = array.get(index1);
0101 long baseOffset1 = memoryManager.getOffsetInPage(recordAddress1) + uaoSize;
0102 int recordLength1 = UnsafeAlignedOffset.getSize(baseObject, baseOffset1 - uaoSize);
0103
0104 long recordAddress2 = array.get(index2);
0105 long baseOffset2 = memoryManager.getOffsetInPage(recordAddress2) + uaoSize;
0106 int recordLength2 = UnsafeAlignedOffset.getSize(baseObject, baseOffset2 - uaoSize);
0107
0108 return binaryComparator.compare(baseObject, baseOffset1, recordLength1, baseObject,
0109 baseOffset2, recordLength2);
0110 }
0111
0112 private final RecordComparator binaryComparator = new RecordBinaryComparator();
0113
0114
0115 private int computeSizeInBytes(int originalSize) {
0116
0117
0118 return 8 + (originalSize + 7) / 8 * 8;
0119 }
0120
0121
0122 private long relativeOffset(int numFields) {
0123
0124
0125 return 8 + numFields * 8L;
0126 }
0127
0128 @Test
0129 public void testBinaryComparatorForSingleColumnRow() throws Exception {
0130 int numFields = 1;
0131
0132 UnsafeRow row1 = new UnsafeRow(numFields);
0133 byte[] data1 = new byte[100];
0134 row1.pointTo(data1, computeSizeInBytes(numFields * 8));
0135 row1.setInt(0, 11);
0136
0137 UnsafeRow row2 = new UnsafeRow(numFields);
0138 byte[] data2 = new byte[100];
0139 row2.pointTo(data2, computeSizeInBytes(numFields * 8));
0140 row2.setInt(0, 42);
0141
0142 insertRow(row1);
0143 insertRow(row2);
0144
0145 Assert.assertEquals(0, compare(0, 0));
0146 Assert.assertTrue(compare(0, 1) < 0);
0147 }
0148
0149 @Test
0150 public void testBinaryComparatorForMultipleColumnRow() throws Exception {
0151 int numFields = 5;
0152
0153 UnsafeRow row1 = new UnsafeRow(numFields);
0154 byte[] data1 = new byte[100];
0155 row1.pointTo(data1, computeSizeInBytes(numFields * 8));
0156 for (int i = 0; i < numFields; i++) {
0157 row1.setDouble(i, i * 3.14);
0158 }
0159
0160 UnsafeRow row2 = new UnsafeRow(numFields);
0161 byte[] data2 = new byte[100];
0162 row2.pointTo(data2, computeSizeInBytes(numFields * 8));
0163 for (int i = 0; i < numFields; i++) {
0164 row2.setDouble(i, 198.7 / (i + 1));
0165 }
0166
0167 insertRow(row1);
0168 insertRow(row2);
0169
0170 Assert.assertEquals(0, compare(0, 0));
0171 Assert.assertTrue(compare(0, 1) < 0);
0172 }
0173
0174 @Test
0175 public void testBinaryComparatorForArrayColumn() throws Exception {
0176 int numFields = 1;
0177
0178 UnsafeRow row1 = new UnsafeRow(numFields);
0179 byte[] data1 = new byte[100];
0180 UnsafeArrayData arrayData1 = UnsafeArrayData.fromPrimitiveArray(new int[]{11, 42, -1});
0181 row1.pointTo(data1, computeSizeInBytes(numFields * 8 + arrayData1.getSizeInBytes()));
0182 row1.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData1.getSizeInBytes());
0183 Platform.copyMemory(arrayData1.getBaseObject(), arrayData1.getBaseOffset(), data1,
0184 row1.getBaseOffset() + relativeOffset(numFields), arrayData1.getSizeInBytes());
0185
0186 UnsafeRow row2 = new UnsafeRow(numFields);
0187 byte[] data2 = new byte[100];
0188 UnsafeArrayData arrayData2 = UnsafeArrayData.fromPrimitiveArray(new int[]{22});
0189 row2.pointTo(data2, computeSizeInBytes(numFields * 8 + arrayData2.getSizeInBytes()));
0190 row2.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData2.getSizeInBytes());
0191 Platform.copyMemory(arrayData2.getBaseObject(), arrayData2.getBaseOffset(), data2,
0192 row2.getBaseOffset() + relativeOffset(numFields), arrayData2.getSizeInBytes());
0193
0194 insertRow(row1);
0195 insertRow(row2);
0196
0197 Assert.assertEquals(0, compare(0, 0));
0198 Assert.assertTrue(compare(0, 1) > 0);
0199 }
0200
0201 @Test
0202 public void testBinaryComparatorForMixedColumns() throws Exception {
0203 int numFields = 4;
0204
0205 UnsafeRow row1 = new UnsafeRow(numFields);
0206 byte[] data1 = new byte[100];
0207 UTF8String str1 = UTF8String.fromString("Milk tea");
0208 row1.pointTo(data1, computeSizeInBytes(numFields * 8 + str1.numBytes()));
0209 row1.setInt(0, 11);
0210 row1.setDouble(1, 3.14);
0211 row1.setInt(2, -1);
0212 row1.setLong(3, (relativeOffset(numFields) << 32) | (long) str1.numBytes());
0213 Platform.copyMemory(str1.getBaseObject(), str1.getBaseOffset(), data1,
0214 row1.getBaseOffset() + relativeOffset(numFields), str1.numBytes());
0215
0216 UnsafeRow row2 = new UnsafeRow(numFields);
0217 byte[] data2 = new byte[100];
0218 UTF8String str2 = UTF8String.fromString("Java");
0219 row2.pointTo(data2, computeSizeInBytes(numFields * 8 + str2.numBytes()));
0220 row2.setInt(0, 11);
0221 row2.setDouble(1, 3.14);
0222 row2.setInt(2, -1);
0223 row2.setLong(3, (relativeOffset(numFields) << 32) | (long) str2.numBytes());
0224 Platform.copyMemory(str2.getBaseObject(), str2.getBaseOffset(), data2,
0225 row2.getBaseOffset() + relativeOffset(numFields), str2.numBytes());
0226
0227 insertRow(row1);
0228 insertRow(row2);
0229
0230 Assert.assertEquals(0, compare(0, 0));
0231 Assert.assertTrue(compare(0, 1) > 0);
0232 }
0233
0234 @Test
0235 public void testBinaryComparatorForNullColumns() throws Exception {
0236 int numFields = 3;
0237
0238 UnsafeRow row1 = new UnsafeRow(numFields);
0239 byte[] data1 = new byte[100];
0240 row1.pointTo(data1, computeSizeInBytes(numFields * 8));
0241 for (int i = 0; i < numFields; i++) {
0242 row1.setNullAt(i);
0243 }
0244
0245 UnsafeRow row2 = new UnsafeRow(numFields);
0246 byte[] data2 = new byte[100];
0247 row2.pointTo(data2, computeSizeInBytes(numFields * 8));
0248 for (int i = 0; i < numFields - 1; i++) {
0249 row2.setNullAt(i);
0250 }
0251 row2.setDouble(numFields - 1, 3.14);
0252
0253 insertRow(row1);
0254 insertRow(row2);
0255
0256 Assert.assertEquals(0, compare(0, 0));
0257 Assert.assertTrue(compare(0, 1) > 0);
0258 }
0259
0260 @Test
0261 public void testBinaryComparatorWhenSubtractionIsDivisibleByMaxIntValue() throws Exception {
0262 int numFields = 1;
0263
0264 UnsafeRow row1 = new UnsafeRow(numFields);
0265 byte[] data1 = new byte[100];
0266 row1.pointTo(data1, computeSizeInBytes(numFields * 8));
0267 row1.setLong(0, 11);
0268
0269 UnsafeRow row2 = new UnsafeRow(numFields);
0270 byte[] data2 = new byte[100];
0271 row2.pointTo(data2, computeSizeInBytes(numFields * 8));
0272 row2.setLong(0, 11L + Integer.MAX_VALUE);
0273
0274 insertRow(row1);
0275 insertRow(row2);
0276
0277 Assert.assertTrue(compare(0, 1) > 0);
0278 }
0279
0280 @Test
0281 public void testBinaryComparatorWhenSubtractionCanOverflowLongValue() throws Exception {
0282 int numFields = 1;
0283
0284 UnsafeRow row1 = new UnsafeRow(numFields);
0285 byte[] data1 = new byte[100];
0286 row1.pointTo(data1, computeSizeInBytes(numFields * 8));
0287 row1.setLong(0, Long.MIN_VALUE);
0288
0289 UnsafeRow row2 = new UnsafeRow(numFields);
0290 byte[] data2 = new byte[100];
0291 row2.pointTo(data2, computeSizeInBytes(numFields * 8));
0292 row2.setLong(0, 1);
0293
0294 insertRow(row1);
0295 insertRow(row2);
0296
0297 Assert.assertTrue(compare(0, 1) < 0);
0298 }
0299
0300 @Test
0301 public void testBinaryComparatorWhenOnlyTheLastColumnDiffers() throws Exception {
0302 int numFields = 4;
0303
0304 UnsafeRow row1 = new UnsafeRow(numFields);
0305 byte[] data1 = new byte[100];
0306 row1.pointTo(data1, computeSizeInBytes(numFields * 8));
0307 row1.setInt(0, 11);
0308 row1.setDouble(1, 3.14);
0309 row1.setInt(2, -1);
0310 row1.setLong(3, 0);
0311
0312 UnsafeRow row2 = new UnsafeRow(numFields);
0313 byte[] data2 = new byte[100];
0314 row2.pointTo(data2, computeSizeInBytes(numFields * 8));
0315 row2.setInt(0, 11);
0316 row2.setDouble(1, 3.14);
0317 row2.setInt(2, -1);
0318 row2.setLong(3, 1);
0319
0320 insertRow(row1);
0321 insertRow(row2);
0322
0323 Assert.assertTrue(compare(0, 1) < 0);
0324 }
0325
0326 @Test
0327 public void testCompareLongsAsLittleEndian() {
0328 long arrayOffset = Platform.LONG_ARRAY_OFFSET + 4;
0329
0330 long[] arr1 = new long[2];
0331 Platform.putLong(arr1, arrayOffset, 0x0100000000000000L);
0332 long[] arr2 = new long[2];
0333 Platform.putLong(arr2, arrayOffset + 4, 0x0000000000000001L);
0334
0335
0336 int result1 = binaryComparator.compare(arr1, arrayOffset, 8, arr2, arrayOffset + 4, 8);
0337
0338 long[] arr3 = new long[2];
0339 Platform.putLong(arr3, arrayOffset, 0x0100000000000000L);
0340 long[] arr4 = new long[2];
0341 Platform.putLong(arr4, arrayOffset, 0x0000000000000001L);
0342
0343 int result2 = binaryComparator.compare(arr3, arrayOffset, 8, arr4, arrayOffset, 8);
0344
0345 Assert.assertEquals(result1, result2);
0346 }
0347
0348 @Test
0349 public void testCompareLongsAsUnsigned() {
0350 long arrayOffset = Platform.LONG_ARRAY_OFFSET + 4;
0351
0352 long[] arr1 = new long[2];
0353 Platform.putLong(arr1, arrayOffset + 4, 0xa000000000000000L);
0354 long[] arr2 = new long[2];
0355 Platform.putLong(arr2, arrayOffset + 4, 0x0000000000000000L);
0356
0357 int result1 = binaryComparator.compare(arr1, arrayOffset + 4, 8, arr2, arrayOffset + 4, 8);
0358
0359 long[] arr3 = new long[2];
0360 Platform.putLong(arr3, arrayOffset, 0xa000000000000000L);
0361 long[] arr4 = new long[2];
0362 Platform.putLong(arr4, arrayOffset, 0x0000000000000000L);
0363
0364
0365 int result2 = binaryComparator.compare(arr3, arrayOffset, 8, arr4, arrayOffset, 8);
0366
0367 Assert.assertEquals(result1, result2);
0368 }
0369 }