0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.sql.catalyst.expressions;
0019
0020 import org.junit.After;
0021 import org.junit.Assert;
0022 import org.junit.Before;
0023 import org.junit.Test;
0024
0025 import org.apache.spark.SparkConf;
0026 import org.apache.spark.memory.TaskMemoryManager;
0027 import org.apache.spark.memory.TestMemoryManager;
0028 import org.apache.spark.sql.types.StructType;
0029 import org.apache.spark.sql.types.DataTypes;
0030 import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
0031 import org.apache.spark.unsafe.types.UTF8String;
0032 import org.apache.spark.internal.config.package$;
0033
0034 import java.util.Random;
0035
0036 public class RowBasedKeyValueBatchSuite {
0037
0038 private final Random rand = new Random(42);
0039
0040 private TestMemoryManager memoryManager;
0041 private TaskMemoryManager taskMemoryManager;
0042 private StructType keySchema = new StructType().add("k1", DataTypes.LongType)
0043 .add("k2", DataTypes.StringType);
0044 private StructType fixedKeySchema = new StructType().add("k1", DataTypes.LongType)
0045 .add("k2", DataTypes.LongType);
0046 private StructType valueSchema = new StructType().add("count", DataTypes.LongType)
0047 .add("sum", DataTypes.LongType);
0048 private int DEFAULT_CAPACITY = 1 << 16;
0049
0050 private String getRandomString(int length) {
0051 Assert.assertTrue(length >= 0);
0052 final byte[] bytes = new byte[length];
0053 rand.nextBytes(bytes);
0054 return new String(bytes);
0055 }
0056
0057 private UnsafeRow makeKeyRow(long k1, String k2) {
0058 UnsafeRowWriter writer = new UnsafeRowWriter(2);
0059 writer.reset();
0060 writer.write(0, k1);
0061 writer.write(1, UTF8String.fromString(k2));
0062 return writer.getRow();
0063 }
0064
0065 private UnsafeRow makeKeyRow(long k1, long k2) {
0066 UnsafeRowWriter writer = new UnsafeRowWriter(2);
0067 writer.reset();
0068 writer.write(0, k1);
0069 writer.write(1, k2);
0070 return writer.getRow();
0071 }
0072
0073 private UnsafeRow makeValueRow(long v1, long v2) {
0074 UnsafeRowWriter writer = new UnsafeRowWriter(2);
0075 writer.reset();
0076 writer.write(0, v1);
0077 writer.write(1, v2);
0078 return writer.getRow();
0079 }
0080
0081 private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) {
0082 return batch.appendRow(key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(),
0083 value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes());
0084 }
0085
0086 private void updateValueRow(UnsafeRow row, long v1, long v2) {
0087 row.setLong(0, v1);
0088 row.setLong(1, v2);
0089 }
0090
0091 private boolean checkKey(UnsafeRow row, long k1, String k2) {
0092 return (row.getLong(0) == k1)
0093 && (row.getUTF8String(1).equals(UTF8String.fromString(k2)));
0094 }
0095
0096 private boolean checkKey(UnsafeRow row, long k1, long k2) {
0097 return (row.getLong(0) == k1)
0098 && (row.getLong(1) == k2);
0099 }
0100
0101 private boolean checkValue(UnsafeRow row, long v1, long v2) {
0102 return (row.getLong(0) == v1) && (row.getLong(1) == v2);
0103 }
0104
0105 @Before
0106 public void setup() {
0107 memoryManager = new TestMemoryManager(new SparkConf()
0108 .set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false)
0109 .set(package$.MODULE$.SHUFFLE_SPILL_COMPRESS(), false)
0110 .set(package$.MODULE$.SHUFFLE_COMPRESS(), false));
0111 taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
0112 }
0113
0114 @After
0115 public void tearDown() {
0116 if (taskMemoryManager != null) {
0117 Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
0118 long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask();
0119 taskMemoryManager = null;
0120 Assert.assertEquals(0L, leakedMemory);
0121 }
0122 }
0123
0124
0125 @Test
0126 public void emptyBatch() throws Exception {
0127 try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0128 valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0129 Assert.assertEquals(0, batch.numRows());
0130 try {
0131 batch.getKeyRow(-1);
0132 Assert.fail("Should not be able to get row -1");
0133 } catch (AssertionError e) {
0134
0135 }
0136 try {
0137 batch.getValueRow(-1);
0138 Assert.fail("Should not be able to get row -1");
0139 } catch (AssertionError e) {
0140
0141 }
0142 try {
0143 batch.getKeyRow(0);
0144 Assert.fail("Should not be able to get row 0 when batch is empty");
0145 } catch (AssertionError e) {
0146
0147 }
0148 try {
0149 batch.getValueRow(0);
0150 Assert.fail("Should not be able to get row 0 when batch is empty");
0151 } catch (AssertionError e) {
0152
0153 }
0154 Assert.assertFalse(batch.rowIterator().next());
0155 }
0156 }
0157
0158 @Test
0159 public void batchType() {
0160 try (RowBasedKeyValueBatch batch1 = RowBasedKeyValueBatch.allocate(keySchema,
0161 valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
0162 RowBasedKeyValueBatch batch2 = RowBasedKeyValueBatch.allocate(fixedKeySchema,
0163 valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0164 Assert.assertEquals(VariableLengthRowBasedKeyValueBatch.class, batch1.getClass());
0165 Assert.assertEquals(FixedLengthRowBasedKeyValueBatch.class, batch2.getClass());
0166 }
0167 }
0168
0169 @Test
0170 public void setAndRetrieve() {
0171 try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0172 valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0173 UnsafeRow ret1 = appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1));
0174 Assert.assertTrue(checkValue(ret1, 1, 1));
0175 UnsafeRow ret2 = appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2));
0176 Assert.assertTrue(checkValue(ret2, 2, 2));
0177 UnsafeRow ret3 = appendRow(batch, makeKeyRow(3, "C"), makeValueRow(3, 3));
0178 Assert.assertTrue(checkValue(ret3, 3, 3));
0179 Assert.assertEquals(3, batch.numRows());
0180 UnsafeRow retrievedKey1 = batch.getKeyRow(0);
0181 Assert.assertTrue(checkKey(retrievedKey1, 1, "A"));
0182 UnsafeRow retrievedKey2 = batch.getKeyRow(1);
0183 Assert.assertTrue(checkKey(retrievedKey2, 2, "B"));
0184 UnsafeRow retrievedValue1 = batch.getValueRow(1);
0185 Assert.assertTrue(checkValue(retrievedValue1, 2, 2));
0186 UnsafeRow retrievedValue2 = batch.getValueRow(2);
0187 Assert.assertTrue(checkValue(retrievedValue2, 3, 3));
0188 try {
0189 batch.getKeyRow(3);
0190 Assert.fail("Should not be able to get row 3");
0191 } catch (AssertionError e) {
0192
0193 }
0194 try {
0195 batch.getValueRow(3);
0196 Assert.fail("Should not be able to get row 3");
0197 } catch (AssertionError e) {
0198
0199 }
0200 }
0201 }
0202
0203 @Test
0204 public void setUpdateAndRetrieve() {
0205 try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0206 valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0207 appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1));
0208 Assert.assertEquals(1, batch.numRows());
0209 UnsafeRow retrievedValue = batch.getValueRow(0);
0210 updateValueRow(retrievedValue, 2, 2);
0211 UnsafeRow retrievedValue2 = batch.getValueRow(0);
0212 Assert.assertTrue(checkValue(retrievedValue2, 2, 2));
0213 }
0214 }
0215
0216
0217 @Test
0218 public void iteratorTest() throws Exception {
0219 try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0220 valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0221 appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1));
0222 appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2));
0223 appendRow(batch, makeKeyRow(3, "C"), makeValueRow(3, 3));
0224 Assert.assertEquals(3, batch.numRows());
0225 org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator
0226 = batch.rowIterator();
0227 Assert.assertTrue(iterator.next());
0228 UnsafeRow key1 = iterator.getKey();
0229 UnsafeRow value1 = iterator.getValue();
0230 Assert.assertTrue(checkKey(key1, 1, "A"));
0231 Assert.assertTrue(checkValue(value1, 1, 1));
0232 Assert.assertTrue(iterator.next());
0233 UnsafeRow key2 = iterator.getKey();
0234 UnsafeRow value2 = iterator.getValue();
0235 Assert.assertTrue(checkKey(key2, 2, "B"));
0236 Assert.assertTrue(checkValue(value2, 2, 2));
0237 Assert.assertTrue(iterator.next());
0238 UnsafeRow key3 = iterator.getKey();
0239 UnsafeRow value3 = iterator.getValue();
0240 Assert.assertTrue(checkKey(key3, 3, "C"));
0241 Assert.assertTrue(checkValue(value3, 3, 3));
0242 Assert.assertFalse(iterator.next());
0243 }
0244 }
0245
0246 @Test
0247 public void fixedLengthTest() throws Exception {
0248 try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(fixedKeySchema,
0249 valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0250 appendRow(batch, makeKeyRow(11, 11), makeValueRow(1, 1));
0251 appendRow(batch, makeKeyRow(22, 22), makeValueRow(2, 2));
0252 appendRow(batch, makeKeyRow(33, 33), makeValueRow(3, 3));
0253 UnsafeRow retrievedKey1 = batch.getKeyRow(0);
0254 Assert.assertTrue(checkKey(retrievedKey1, 11, 11));
0255 UnsafeRow retrievedKey2 = batch.getKeyRow(1);
0256 Assert.assertTrue(checkKey(retrievedKey2, 22, 22));
0257 UnsafeRow retrievedValue1 = batch.getValueRow(1);
0258 Assert.assertTrue(checkValue(retrievedValue1, 2, 2));
0259 UnsafeRow retrievedValue2 = batch.getValueRow(2);
0260 Assert.assertTrue(checkValue(retrievedValue2, 3, 3));
0261 Assert.assertEquals(3, batch.numRows());
0262 org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator
0263 = batch.rowIterator();
0264 Assert.assertTrue(iterator.next());
0265 UnsafeRow key1 = iterator.getKey();
0266 UnsafeRow value1 = iterator.getValue();
0267 Assert.assertTrue(checkKey(key1, 11, 11));
0268 Assert.assertTrue(checkValue(value1, 1, 1));
0269 Assert.assertTrue(iterator.next());
0270 UnsafeRow key2 = iterator.getKey();
0271 UnsafeRow value2 = iterator.getValue();
0272 Assert.assertTrue(checkKey(key2, 22, 22));
0273 Assert.assertTrue(checkValue(value2, 2, 2));
0274 Assert.assertTrue(iterator.next());
0275 UnsafeRow key3 = iterator.getKey();
0276 UnsafeRow value3 = iterator.getValue();
0277 Assert.assertTrue(checkKey(key3, 33, 33));
0278 Assert.assertTrue(checkValue(value3, 3, 3));
0279 Assert.assertFalse(iterator.next());
0280 }
0281 }
0282
0283 @Test
0284 public void appendRowUntilExceedingCapacity() throws Exception {
0285 try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0286 valueSchema, taskMemoryManager, 10)) {
0287 UnsafeRow key = makeKeyRow(1, "A");
0288 UnsafeRow value = makeValueRow(1, 1);
0289 for (int i = 0; i < 10; i++) {
0290 appendRow(batch, key, value);
0291 }
0292 UnsafeRow ret = appendRow(batch, key, value);
0293 Assert.assertEquals(10, batch.numRows());
0294 Assert.assertNull(ret);
0295 org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator
0296 = batch.rowIterator();
0297 for (int i = 0; i < 10; i++) {
0298 Assert.assertTrue(iterator.next());
0299 UnsafeRow key1 = iterator.getKey();
0300 UnsafeRow value1 = iterator.getValue();
0301 Assert.assertTrue(checkKey(key1, 1, "A"));
0302 Assert.assertTrue(checkValue(value1, 1, 1));
0303 }
0304 Assert.assertFalse(iterator.next());
0305 }
0306 }
0307
0308 @Test
0309 public void appendRowUntilExceedingPageSize() throws Exception {
0310
0311 int pageSizeToUse = (int) memoryManager.pageSizeBytes();
0312 try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0313 valueSchema, taskMemoryManager, pageSizeToUse)) {
0314 UnsafeRow key = makeKeyRow(1, "A");
0315 UnsafeRow value = makeValueRow(1, 1);
0316 int recordLength = 8 + key.getSizeInBytes() + value.getSizeInBytes() + 8;
0317 int totalSize = 4;
0318 int numRows = 0;
0319 while (totalSize + recordLength < pageSizeToUse) {
0320 appendRow(batch, key, value);
0321 totalSize += recordLength;
0322 numRows++;
0323 }
0324 UnsafeRow ret = appendRow(batch, key, value);
0325 Assert.assertEquals(numRows, batch.numRows());
0326 Assert.assertNull(ret);
0327 org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator
0328 = batch.rowIterator();
0329 for (int i = 0; i < numRows; i++) {
0330 Assert.assertTrue(iterator.next());
0331 UnsafeRow key1 = iterator.getKey();
0332 UnsafeRow value1 = iterator.getValue();
0333 Assert.assertTrue(checkKey(key1, 1, "A"));
0334 Assert.assertTrue(checkValue(value1, 1, 1));
0335 }
0336 Assert.assertFalse(iterator.next());
0337 }
0338 }
0339
0340 @Test
0341 public void failureToAllocateFirstPage() throws Exception {
0342 memoryManager.limit(1024);
0343 try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0344 valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0345 UnsafeRow key = makeKeyRow(1, "A");
0346 UnsafeRow value = makeValueRow(11, 11);
0347 UnsafeRow ret = appendRow(batch, key, value);
0348 Assert.assertNull(ret);
0349 Assert.assertFalse(batch.rowIterator().next());
0350 }
0351 }
0352
0353 @Test
0354 public void randomizedTest() {
0355 try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0356 valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0357 int numEntry = 100;
0358 long[] expectedK1 = new long[numEntry];
0359 String[] expectedK2 = new String[numEntry];
0360 long[] expectedV1 = new long[numEntry];
0361 long[] expectedV2 = new long[numEntry];
0362
0363 for (int i = 0; i < numEntry; i++) {
0364 long k1 = rand.nextLong();
0365 String k2 = getRandomString(rand.nextInt(256));
0366 long v1 = rand.nextLong();
0367 long v2 = rand.nextLong();
0368 appendRow(batch, makeKeyRow(k1, k2), makeValueRow(v1, v2));
0369 expectedK1[i] = k1;
0370 expectedK2[i] = k2;
0371 expectedV1[i] = v1;
0372 expectedV2[i] = v2;
0373 }
0374
0375 for (int j = 0; j < 10000; j++) {
0376 int rowId = rand.nextInt(numEntry);
0377 if (rand.nextBoolean()) {
0378 UnsafeRow key = batch.getKeyRow(rowId);
0379 Assert.assertTrue(checkKey(key, expectedK1[rowId], expectedK2[rowId]));
0380 }
0381 if (rand.nextBoolean()) {
0382 UnsafeRow value = batch.getValueRow(rowId);
0383 Assert.assertTrue(checkValue(value, expectedV1[rowId], expectedV2[rowId]));
0384 }
0385 }
0386 }
0387 }
0388 }