Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.sql.catalyst.expressions;
0019 
0020 import org.junit.After;
0021 import org.junit.Assert;
0022 import org.junit.Before;
0023 import org.junit.Test;
0024 
0025 import org.apache.spark.SparkConf;
0026 import org.apache.spark.memory.TaskMemoryManager;
0027 import org.apache.spark.memory.TestMemoryManager;
0028 import org.apache.spark.sql.types.StructType;
0029 import org.apache.spark.sql.types.DataTypes;
0030 import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
0031 import org.apache.spark.unsafe.types.UTF8String;
0032 import org.apache.spark.internal.config.package$;
0033 
0034 import java.util.Random;
0035 
0036 public class RowBasedKeyValueBatchSuite {
0037 
0038   private final Random rand = new Random(42);
0039 
0040   private TestMemoryManager memoryManager;
0041   private TaskMemoryManager taskMemoryManager;
0042   private StructType keySchema = new StructType().add("k1", DataTypes.LongType)
0043           .add("k2", DataTypes.StringType);
0044   private StructType fixedKeySchema = new StructType().add("k1", DataTypes.LongType)
0045           .add("k2", DataTypes.LongType);
0046   private StructType valueSchema = new StructType().add("count", DataTypes.LongType)
0047           .add("sum", DataTypes.LongType);
0048   private int DEFAULT_CAPACITY = 1 << 16;
0049 
0050   private String getRandomString(int length) {
0051     Assert.assertTrue(length >= 0);
0052     final byte[] bytes = new byte[length];
0053     rand.nextBytes(bytes);
0054     return new String(bytes);
0055   }
0056 
0057   private UnsafeRow makeKeyRow(long k1, String k2) {
0058     UnsafeRowWriter writer = new UnsafeRowWriter(2);
0059     writer.reset();
0060     writer.write(0, k1);
0061     writer.write(1, UTF8String.fromString(k2));
0062     return writer.getRow();
0063   }
0064 
0065   private UnsafeRow makeKeyRow(long k1, long k2) {
0066     UnsafeRowWriter writer = new UnsafeRowWriter(2);
0067     writer.reset();
0068     writer.write(0, k1);
0069     writer.write(1, k2);
0070     return writer.getRow();
0071   }
0072 
0073   private UnsafeRow makeValueRow(long v1, long v2) {
0074     UnsafeRowWriter writer = new UnsafeRowWriter(2);
0075     writer.reset();
0076     writer.write(0, v1);
0077     writer.write(1, v2);
0078     return writer.getRow();
0079   }
0080 
0081   private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) {
0082     return batch.appendRow(key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(),
0083             value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes());
0084   }
0085 
0086   private void updateValueRow(UnsafeRow row, long v1, long v2) {
0087     row.setLong(0, v1);
0088     row.setLong(1, v2);
0089   }
0090 
0091   private boolean checkKey(UnsafeRow row, long k1, String k2) {
0092     return (row.getLong(0) == k1)
0093             && (row.getUTF8String(1).equals(UTF8String.fromString(k2)));
0094   }
0095 
0096   private boolean checkKey(UnsafeRow row, long k1, long k2) {
0097     return (row.getLong(0) == k1)
0098             && (row.getLong(1) == k2);
0099   }
0100 
0101   private boolean checkValue(UnsafeRow row, long v1, long v2) {
0102     return (row.getLong(0) == v1) && (row.getLong(1) == v2);
0103   }
0104 
0105   @Before
0106   public void setup() {
0107     memoryManager = new TestMemoryManager(new SparkConf()
0108             .set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false)
0109             .set(package$.MODULE$.SHUFFLE_SPILL_COMPRESS(), false)
0110             .set(package$.MODULE$.SHUFFLE_COMPRESS(), false));
0111     taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
0112   }
0113 
0114   @After
0115   public void tearDown() {
0116     if (taskMemoryManager != null) {
0117       Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
0118       long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask();
0119       taskMemoryManager = null;
0120       Assert.assertEquals(0L, leakedMemory);
0121     }
0122   }
0123 
0124 
0125   @Test
0126   public void emptyBatch() throws Exception {
0127     try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0128         valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0129       Assert.assertEquals(0, batch.numRows());
0130       try {
0131         batch.getKeyRow(-1);
0132         Assert.fail("Should not be able to get row -1");
0133       } catch (AssertionError e) {
0134         // Expected exception; do nothing.
0135       }
0136       try {
0137         batch.getValueRow(-1);
0138         Assert.fail("Should not be able to get row -1");
0139       } catch (AssertionError e) {
0140         // Expected exception; do nothing.
0141       }
0142       try {
0143         batch.getKeyRow(0);
0144         Assert.fail("Should not be able to get row 0 when batch is empty");
0145       } catch (AssertionError e) {
0146         // Expected exception; do nothing.
0147       }
0148       try {
0149         batch.getValueRow(0);
0150         Assert.fail("Should not be able to get row 0 when batch is empty");
0151       } catch (AssertionError e) {
0152         // Expected exception; do nothing.
0153       }
0154       Assert.assertFalse(batch.rowIterator().next());
0155     }
0156   }
0157 
0158   @Test
0159   public void batchType() {
0160     try (RowBasedKeyValueBatch batch1 = RowBasedKeyValueBatch.allocate(keySchema,
0161         valueSchema, taskMemoryManager, DEFAULT_CAPACITY);
0162          RowBasedKeyValueBatch batch2 = RowBasedKeyValueBatch.allocate(fixedKeySchema,
0163         valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0164       Assert.assertEquals(VariableLengthRowBasedKeyValueBatch.class, batch1.getClass());
0165       Assert.assertEquals(FixedLengthRowBasedKeyValueBatch.class, batch2.getClass());
0166     }
0167   }
0168 
0169   @Test
0170   public void setAndRetrieve() {
0171     try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0172         valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0173       UnsafeRow ret1 = appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1));
0174       Assert.assertTrue(checkValue(ret1, 1, 1));
0175       UnsafeRow ret2 = appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2));
0176       Assert.assertTrue(checkValue(ret2, 2, 2));
0177       UnsafeRow ret3 = appendRow(batch, makeKeyRow(3, "C"), makeValueRow(3, 3));
0178       Assert.assertTrue(checkValue(ret3, 3, 3));
0179       Assert.assertEquals(3, batch.numRows());
0180       UnsafeRow retrievedKey1 = batch.getKeyRow(0);
0181       Assert.assertTrue(checkKey(retrievedKey1, 1, "A"));
0182       UnsafeRow retrievedKey2 = batch.getKeyRow(1);
0183       Assert.assertTrue(checkKey(retrievedKey2, 2, "B"));
0184       UnsafeRow retrievedValue1 = batch.getValueRow(1);
0185       Assert.assertTrue(checkValue(retrievedValue1, 2, 2));
0186       UnsafeRow retrievedValue2 = batch.getValueRow(2);
0187       Assert.assertTrue(checkValue(retrievedValue2, 3, 3));
0188       try {
0189         batch.getKeyRow(3);
0190         Assert.fail("Should not be able to get row 3");
0191       } catch (AssertionError e) {
0192         // Expected exception; do nothing.
0193       }
0194       try {
0195         batch.getValueRow(3);
0196         Assert.fail("Should not be able to get row 3");
0197       } catch (AssertionError e) {
0198         // Expected exception; do nothing.
0199       }
0200     }
0201   }
0202 
0203   @Test
0204   public void setUpdateAndRetrieve() {
0205     try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0206         valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0207       appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1));
0208       Assert.assertEquals(1, batch.numRows());
0209       UnsafeRow retrievedValue = batch.getValueRow(0);
0210       updateValueRow(retrievedValue, 2, 2);
0211       UnsafeRow retrievedValue2 = batch.getValueRow(0);
0212       Assert.assertTrue(checkValue(retrievedValue2, 2, 2));
0213     }
0214   }
0215 
0216 
0217   @Test
0218   public void iteratorTest() throws Exception {
0219     try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0220         valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0221       appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1));
0222       appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2));
0223       appendRow(batch, makeKeyRow(3, "C"), makeValueRow(3, 3));
0224       Assert.assertEquals(3, batch.numRows());
0225       org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator
0226               = batch.rowIterator();
0227       Assert.assertTrue(iterator.next());
0228       UnsafeRow key1 = iterator.getKey();
0229       UnsafeRow value1 = iterator.getValue();
0230       Assert.assertTrue(checkKey(key1, 1, "A"));
0231       Assert.assertTrue(checkValue(value1, 1, 1));
0232       Assert.assertTrue(iterator.next());
0233       UnsafeRow key2 = iterator.getKey();
0234       UnsafeRow value2 = iterator.getValue();
0235       Assert.assertTrue(checkKey(key2, 2, "B"));
0236       Assert.assertTrue(checkValue(value2, 2, 2));
0237       Assert.assertTrue(iterator.next());
0238       UnsafeRow key3 = iterator.getKey();
0239       UnsafeRow value3 = iterator.getValue();
0240       Assert.assertTrue(checkKey(key3, 3, "C"));
0241       Assert.assertTrue(checkValue(value3, 3, 3));
0242       Assert.assertFalse(iterator.next());
0243     }
0244   }
0245 
0246   @Test
0247   public void fixedLengthTest() throws Exception {
0248     try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(fixedKeySchema,
0249         valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0250       appendRow(batch, makeKeyRow(11, 11), makeValueRow(1, 1));
0251       appendRow(batch, makeKeyRow(22, 22), makeValueRow(2, 2));
0252       appendRow(batch, makeKeyRow(33, 33), makeValueRow(3, 3));
0253       UnsafeRow retrievedKey1 = batch.getKeyRow(0);
0254       Assert.assertTrue(checkKey(retrievedKey1, 11, 11));
0255       UnsafeRow retrievedKey2 = batch.getKeyRow(1);
0256       Assert.assertTrue(checkKey(retrievedKey2, 22, 22));
0257       UnsafeRow retrievedValue1 = batch.getValueRow(1);
0258       Assert.assertTrue(checkValue(retrievedValue1, 2, 2));
0259       UnsafeRow retrievedValue2 = batch.getValueRow(2);
0260       Assert.assertTrue(checkValue(retrievedValue2, 3, 3));
0261       Assert.assertEquals(3, batch.numRows());
0262       org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator
0263               = batch.rowIterator();
0264       Assert.assertTrue(iterator.next());
0265       UnsafeRow key1 = iterator.getKey();
0266       UnsafeRow value1 = iterator.getValue();
0267       Assert.assertTrue(checkKey(key1, 11, 11));
0268       Assert.assertTrue(checkValue(value1, 1, 1));
0269       Assert.assertTrue(iterator.next());
0270       UnsafeRow key2 = iterator.getKey();
0271       UnsafeRow value2 = iterator.getValue();
0272       Assert.assertTrue(checkKey(key2, 22, 22));
0273       Assert.assertTrue(checkValue(value2, 2, 2));
0274       Assert.assertTrue(iterator.next());
0275       UnsafeRow key3 = iterator.getKey();
0276       UnsafeRow value3 = iterator.getValue();
0277       Assert.assertTrue(checkKey(key3, 33, 33));
0278       Assert.assertTrue(checkValue(value3, 3, 3));
0279       Assert.assertFalse(iterator.next());
0280     }
0281   }
0282 
0283   @Test
0284   public void appendRowUntilExceedingCapacity() throws Exception {
0285     try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0286         valueSchema, taskMemoryManager, 10)) {
0287       UnsafeRow key = makeKeyRow(1, "A");
0288       UnsafeRow value = makeValueRow(1, 1);
0289       for (int i = 0; i < 10; i++) {
0290         appendRow(batch, key, value);
0291       }
0292       UnsafeRow ret = appendRow(batch, key, value);
0293       Assert.assertEquals(10, batch.numRows());
0294       Assert.assertNull(ret);
0295       org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator
0296               = batch.rowIterator();
0297       for (int i = 0; i < 10; i++) {
0298         Assert.assertTrue(iterator.next());
0299         UnsafeRow key1 = iterator.getKey();
0300         UnsafeRow value1 = iterator.getValue();
0301         Assert.assertTrue(checkKey(key1, 1, "A"));
0302         Assert.assertTrue(checkValue(value1, 1, 1));
0303       }
0304       Assert.assertFalse(iterator.next());
0305     }
0306   }
0307 
0308   @Test
0309   public void appendRowUntilExceedingPageSize() throws Exception {
0310     // Use default size or spark.buffer.pageSize if specified
0311     int pageSizeToUse = (int) memoryManager.pageSizeBytes();
0312     try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0313         valueSchema, taskMemoryManager, pageSizeToUse)) {
0314       UnsafeRow key = makeKeyRow(1, "A");
0315       UnsafeRow value = makeValueRow(1, 1);
0316       int recordLength = 8 + key.getSizeInBytes() + value.getSizeInBytes() + 8;
0317       int totalSize = 4;
0318       int numRows = 0;
0319       while (totalSize + recordLength < pageSizeToUse) {
0320         appendRow(batch, key, value);
0321         totalSize += recordLength;
0322         numRows++;
0323       }
0324       UnsafeRow ret = appendRow(batch, key, value);
0325       Assert.assertEquals(numRows, batch.numRows());
0326       Assert.assertNull(ret);
0327       org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> iterator
0328               = batch.rowIterator();
0329       for (int i = 0; i < numRows; i++) {
0330         Assert.assertTrue(iterator.next());
0331         UnsafeRow key1 = iterator.getKey();
0332         UnsafeRow value1 = iterator.getValue();
0333         Assert.assertTrue(checkKey(key1, 1, "A"));
0334         Assert.assertTrue(checkValue(value1, 1, 1));
0335       }
0336       Assert.assertFalse(iterator.next());
0337     }
0338   }
0339 
0340   @Test
0341   public void failureToAllocateFirstPage() throws Exception {
0342     memoryManager.limit(1024);
0343     try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0344         valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0345       UnsafeRow key = makeKeyRow(1, "A");
0346       UnsafeRow value = makeValueRow(11, 11);
0347       UnsafeRow ret = appendRow(batch, key, value);
0348       Assert.assertNull(ret);
0349       Assert.assertFalse(batch.rowIterator().next());
0350     }
0351   }
0352 
0353   @Test
0354   public void randomizedTest() {
0355     try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
0356         valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
0357       int numEntry = 100;
0358       long[] expectedK1 = new long[numEntry];
0359       String[] expectedK2 = new String[numEntry];
0360       long[] expectedV1 = new long[numEntry];
0361       long[] expectedV2 = new long[numEntry];
0362 
0363       for (int i = 0; i < numEntry; i++) {
0364         long k1 = rand.nextLong();
0365         String k2 = getRandomString(rand.nextInt(256));
0366         long v1 = rand.nextLong();
0367         long v2 = rand.nextLong();
0368         appendRow(batch, makeKeyRow(k1, k2), makeValueRow(v1, v2));
0369         expectedK1[i] = k1;
0370         expectedK2[i] = k2;
0371         expectedV1[i] = v1;
0372         expectedV2[i] = v2;
0373       }
0374 
0375       for (int j = 0; j < 10000; j++) {
0376         int rowId = rand.nextInt(numEntry);
0377         if (rand.nextBoolean()) {
0378           UnsafeRow key = batch.getKeyRow(rowId);
0379           Assert.assertTrue(checkKey(key, expectedK1[rowId], expectedK2[rowId]));
0380         }
0381         if (rand.nextBoolean()) {
0382           UnsafeRow value = batch.getValueRow(rowId);
0383           Assert.assertTrue(checkValue(value, expectedV1[rowId], expectedV2[rowId]));
0384         }
0385       }
0386     }
0387   }
0388 }