0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.unsafe.map;
0019
0020 import javax.annotation.Nullable;
0021 import java.io.File;
0022 import java.io.IOException;
0023 import java.util.Iterator;
0024 import java.util.LinkedList;
0025
0026 import com.google.common.annotations.VisibleForTesting;
0027 import com.google.common.io.Closeables;
0028 import org.slf4j.Logger;
0029 import org.slf4j.LoggerFactory;
0030
0031 import org.apache.spark.SparkEnv;
0032 import org.apache.spark.executor.ShuffleWriteMetrics;
0033 import org.apache.spark.memory.MemoryConsumer;
0034 import org.apache.spark.memory.SparkOutOfMemoryError;
0035 import org.apache.spark.memory.TaskMemoryManager;
0036 import org.apache.spark.serializer.SerializerManager;
0037 import org.apache.spark.storage.BlockManager;
0038 import org.apache.spark.unsafe.Platform;
0039 import org.apache.spark.unsafe.UnsafeAlignedOffset;
0040 import org.apache.spark.unsafe.array.ByteArrayMethods;
0041 import org.apache.spark.unsafe.array.LongArray;
0042 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
0043 import org.apache.spark.unsafe.memory.MemoryBlock;
0044 import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader;
0045 import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067 public final class BytesToBytesMap extends MemoryConsumer {
0068
0069 private static final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
0070
0071 private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
0072
0073 private final TaskMemoryManager taskMemoryManager;
0074
0075
0076
0077
0078 private final LinkedList<MemoryBlock> dataPages = new LinkedList<>();
0079
0080
0081
0082
0083
0084
0085 private MemoryBlock currentPage = null;
0086
0087
0088
0089
0090
0091 private long pageCursor = 0;
0092
0093
0094
0095
0096
0097
0098
0099 public static final int MAX_CAPACITY = (1 << 29);
0100
0101
0102
0103
0104
0105
0106
0107
0108
0109
0110 @Nullable private LongArray longArray;
0111
0112
0113
0114
0115
0116
0117
0118
0119
0120
0121
0122
0123
0124 private boolean canGrowArray = true;
0125
0126 private final double loadFactor;
0127
0128
0129
0130
0131
0132 private final long pageSizeBytes;
0133
0134
0135
0136
0137 private int numKeys;
0138
0139
0140
0141
0142 private int numValues;
0143
0144
0145
0146
0147 private int growthThreshold;
0148
0149
0150
0151
0152
0153
0154 private int mask;
0155
0156
0157
0158
0159 private final Location loc;
0160
0161 private long numProbes = 0L;
0162
0163 private long numKeyLookups = 0L;
0164
0165 private long peakMemoryUsedBytes = 0L;
0166
0167 private final int initialCapacity;
0168
0169 private final BlockManager blockManager;
0170 private final SerializerManager serializerManager;
0171 private volatile MapIterator destructiveIterator = null;
0172 private LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
0173
/**
 * Creates a map whose hash table starts with (at least) {@code initialCapacity} slots.
 *
 * @param taskMemoryManager memory manager that supplies pages and encodes record addresses
 * @param blockManager block manager handed to spill writers
 * @param serializerManager serializer manager handed to spill readers
 * @param initialCapacity requested number of slots; must be in {@code (0, MAX_CAPACITY]}
 * @param loadFactor fraction of the capacity that may fill up before the table grows
 * @param pageSizeBytes page size; must not exceed
 *        {@code TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES}
 * @throws IllegalArgumentException if {@code initialCapacity} or {@code pageSizeBytes} is out
 *         of range
 */
public BytesToBytesMap(
    TaskMemoryManager taskMemoryManager,
    BlockManager blockManager,
    SerializerManager serializerManager,
    int initialCapacity,
    double loadFactor,
    long pageSizeBytes) {
  super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode());

  // Reject bad arguments before touching any state; the checks run in the same order as
  // before, so the first violated constraint still determines the message.
  if (initialCapacity <= 0) {
    throw new IllegalArgumentException("Initial capacity must be greater than 0");
  } else if (initialCapacity > MAX_CAPACITY) {
    throw new IllegalArgumentException(
      "Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY);
  } else if (pageSizeBytes > TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES) {
    throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " +
      TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
  }

  this.taskMemoryManager = taskMemoryManager;
  this.blockManager = blockManager;
  this.serializerManager = serializerManager;
  this.initialCapacity = initialCapacity;
  this.loadFactor = loadFactor;
  this.pageSizeBytes = pageSizeBytes;
  this.loc = new Location();

  // Allocates and zeroes the hash table and derives growthThreshold and mask.
  allocate(initialCapacity);
}
0202
/**
 * Convenience constructor: takes the block manager and serializer manager from the current
 * {@link SparkEnv} (both are {@code null} when {@code SparkEnv.get()} returns {@code null})
 * and uses a load factor of 0.5.
 *
 * @param taskMemoryManager memory manager that supplies pages and encodes record addresses
 * @param initialCapacity requested number of slots; must be in {@code (0, MAX_CAPACITY]}
 * @param pageSizeBytes page size; must not exceed
 *        {@code TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES}
 */
public BytesToBytesMap(
    TaskMemoryManager taskMemoryManager,
    int initialCapacity,
    long pageSizeBytes) {
  this(
    taskMemoryManager,
    /* blockManager = */ SparkEnv.get() != null ? SparkEnv.get().blockManager() : null,
    /* serializerManager = */ SparkEnv.get() != null ? SparkEnv.get().serializerManager() : null,
    initialCapacity,
    /* loadFactor = */ 0.5,
    pageSizeBytes);
}
0216
0217
0218
0219
0220 public int numKeys() { return numKeys; }
0221
0222
0223
0224
0225 public int numValues() { return numValues; }
0226
0227 public final class MapIterator implements Iterator<Location> {
0228
0229 private int numRecords;
0230 private final Location loc;
0231
0232 private MemoryBlock currentPage = null;
0233 private int recordsInPage = 0;
0234 private Object pageBaseObject;
0235 private long offsetInPage;
0236
0237
0238
0239 private boolean destructive = false;
0240 private UnsafeSorterSpillReader reader = null;
0241
0242 private MapIterator(int numRecords, Location loc, boolean destructive) {
0243 this.numRecords = numRecords;
0244 this.loc = loc;
0245 this.destructive = destructive;
0246 if (destructive) {
0247 destructiveIterator = this;
0248
0249 if (longArray != null) {
0250 freeArray(longArray);
0251 longArray = null;
0252 }
0253 }
0254 }
0255
0256 private void advanceToNextPage() {
0257
0258
0259
0260
0261
0262 MemoryBlock pageToFree = null;
0263
0264 try {
0265 synchronized (this) {
0266 int nextIdx = dataPages.indexOf(currentPage) + 1;
0267 if (destructive && currentPage != null) {
0268 dataPages.remove(currentPage);
0269 pageToFree = currentPage;
0270 nextIdx--;
0271 }
0272 if (dataPages.size() > nextIdx) {
0273 currentPage = dataPages.get(nextIdx);
0274 pageBaseObject = currentPage.getBaseObject();
0275 offsetInPage = currentPage.getBaseOffset();
0276 recordsInPage = UnsafeAlignedOffset.getSize(pageBaseObject, offsetInPage);
0277 offsetInPage += UnsafeAlignedOffset.getUaoSize();
0278 } else {
0279 currentPage = null;
0280 if (reader != null) {
0281 handleFailedDelete();
0282 }
0283 try {
0284 Closeables.close(reader, false);
0285 reader = spillWriters.getFirst().getReader(serializerManager);
0286 recordsInPage = -1;
0287 } catch (IOException e) {
0288
0289 Platform.throwException(e);
0290 }
0291 }
0292 }
0293 } finally {
0294 if (pageToFree != null) {
0295 freePage(pageToFree);
0296 }
0297 }
0298 }
0299
0300 @Override
0301 public boolean hasNext() {
0302 if (numRecords == 0) {
0303 if (reader != null) {
0304 handleFailedDelete();
0305 }
0306 }
0307 return numRecords > 0;
0308 }
0309
0310 @Override
0311 public Location next() {
0312 if (recordsInPage == 0) {
0313 advanceToNextPage();
0314 }
0315 numRecords--;
0316 if (currentPage != null) {
0317 int totalLength = UnsafeAlignedOffset.getSize(pageBaseObject, offsetInPage);
0318 loc.with(currentPage, offsetInPage);
0319
0320 offsetInPage += UnsafeAlignedOffset.getUaoSize() + totalLength + 8;
0321 recordsInPage --;
0322 return loc;
0323 } else {
0324 assert(reader != null);
0325 if (!reader.hasNext()) {
0326 advanceToNextPage();
0327 }
0328 try {
0329 reader.loadNext();
0330 } catch (IOException e) {
0331 try {
0332 reader.close();
0333 } catch(IOException e2) {
0334 logger.error("Error while closing spill reader", e2);
0335 }
0336
0337 Platform.throwException(e);
0338 }
0339 loc.with(reader.getBaseObject(), reader.getBaseOffset(), reader.getRecordLength());
0340 return loc;
0341 }
0342 }
0343
0344 public synchronized long spill(long numBytes) throws IOException {
0345 if (!destructive || dataPages.size() == 1) {
0346 return 0L;
0347 }
0348
0349 updatePeakMemoryUsed();
0350
0351
0352 ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
0353
0354 long released = 0L;
0355 while (dataPages.size() > 0) {
0356 MemoryBlock block = dataPages.getLast();
0357
0358 if (block == currentPage) {
0359 break;
0360 }
0361
0362 Object base = block.getBaseObject();
0363 long offset = block.getBaseOffset();
0364 int numRecords = UnsafeAlignedOffset.getSize(base, offset);
0365 int uaoSize = UnsafeAlignedOffset.getUaoSize();
0366 offset += uaoSize;
0367 final UnsafeSorterSpillWriter writer =
0368 new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords);
0369 while (numRecords > 0) {
0370 int length = UnsafeAlignedOffset.getSize(base, offset);
0371 writer.write(base, offset + uaoSize, length, 0);
0372 offset += uaoSize + length + 8;
0373 numRecords--;
0374 }
0375 writer.close();
0376 spillWriters.add(writer);
0377
0378 dataPages.removeLast();
0379 released += block.size();
0380 freePage(block);
0381
0382 if (released >= numBytes) {
0383 break;
0384 }
0385 }
0386
0387 return released;
0388 }
0389
0390 @Override
0391 public void remove() {
0392 throw new UnsupportedOperationException();
0393 }
0394
0395 private void handleFailedDelete() {
0396
0397 File file = spillWriters.removeFirst().getFile();
0398 if (file != null && file.exists() && !file.delete()) {
0399 logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
0400 }
0401 }
0402 }
0403
0404
0405
0406
0407
0408
0409
0410
0411
0412 public MapIterator iterator() {
0413 return new MapIterator(numValues, loc, false);
0414 }
0415
0416
0417
0418
0419 public MapIterator safeIterator() {
0420 return new MapIterator(numValues, new Location(), false);
0421 }
0422
0423
0424
0425
0426
0427
0428
0429
0430
0431
0432
0433 public MapIterator destructiveIterator() {
0434 updatePeakMemoryUsed();
0435 return new MapIterator(numValues, loc, true);
0436 }
0437
0438
0439
0440
0441
0442
0443
0444 public Location lookup(Object keyBase, long keyOffset, int keyLength) {
0445 safeLookup(keyBase, keyOffset, keyLength, loc,
0446 Murmur3_x86_32.hashUnsafeWords(keyBase, keyOffset, keyLength, 42));
0447 return loc;
0448 }
0449
0450
0451
0452
0453
0454
0455
0456 public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) {
0457 safeLookup(keyBase, keyOffset, keyLength, loc, hash);
0458 return loc;
0459 }
0460
0461
0462
0463
0464
0465
0466 public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc, int hash) {
0467 assert(longArray != null);
0468
0469 numKeyLookups++;
0470
0471 int pos = hash & mask;
0472 int step = 1;
0473 while (true) {
0474 numProbes++;
0475 if (longArray.get(pos * 2) == 0) {
0476
0477 loc.with(pos, hash, false);
0478 return;
0479 } else {
0480 long stored = longArray.get(pos * 2 + 1);
0481 if ((int) (stored) == hash) {
0482
0483 loc.with(pos, hash, true);
0484 if (loc.getKeyLength() == keyLength) {
0485 final boolean areEqual = ByteArrayMethods.arrayEquals(
0486 keyBase,
0487 keyOffset,
0488 loc.getKeyBase(),
0489 loc.getKeyOffset(),
0490 keyLength
0491 );
0492 if (areEqual) {
0493 return;
0494 }
0495 }
0496 }
0497 }
0498 pos = (pos + step) & mask;
0499 step++;
0500 }
0501 }
0502
0503
0504
0505
0506 public final class Location {
0507
0508 private int pos;
0509
0510 private boolean isDefined;
0511
0512
0513
0514
0515
0516 private int keyHashcode;
0517 private Object baseObject;
0518 private long keyOffset;
0519 private int keyLength;
0520 private long valueOffset;
0521 private int valueLength;
0522
0523
0524
0525
0526 @Nullable private MemoryBlock memoryPage;
0527
0528 private void updateAddressesAndSizes(long fullKeyAddress) {
0529 updateAddressesAndSizes(
0530 taskMemoryManager.getPage(fullKeyAddress),
0531 taskMemoryManager.getOffsetInPage(fullKeyAddress));
0532 }
0533
0534 private void updateAddressesAndSizes(final Object base, long offset) {
0535 baseObject = base;
0536 final int totalLength = UnsafeAlignedOffset.getSize(base, offset);
0537 int uaoSize = UnsafeAlignedOffset.getUaoSize();
0538 offset += uaoSize;
0539 keyLength = UnsafeAlignedOffset.getSize(base, offset);
0540 offset += uaoSize;
0541 keyOffset = offset;
0542 valueOffset = offset + keyLength;
0543 valueLength = totalLength - keyLength - uaoSize;
0544 }
0545
0546 private Location with(int pos, int keyHashcode, boolean isDefined) {
0547 assert(longArray != null);
0548 this.pos = pos;
0549 this.isDefined = isDefined;
0550 this.keyHashcode = keyHashcode;
0551 if (isDefined) {
0552 final long fullKeyAddress = longArray.get(pos * 2);
0553 updateAddressesAndSizes(fullKeyAddress);
0554 }
0555 return this;
0556 }
0557
0558 private Location with(MemoryBlock page, long offsetInPage) {
0559 this.isDefined = true;
0560 this.memoryPage = page;
0561 updateAddressesAndSizes(page.getBaseObject(), offsetInPage);
0562 return this;
0563 }
0564
0565
0566
0567
0568 private Location with(Object base, long offset, int length) {
0569 this.isDefined = true;
0570 this.memoryPage = null;
0571 baseObject = base;
0572 int uaoSize = UnsafeAlignedOffset.getUaoSize();
0573 keyOffset = offset + uaoSize;
0574 keyLength = UnsafeAlignedOffset.getSize(base, offset);
0575 valueOffset = offset + uaoSize + keyLength;
0576 valueLength = length - uaoSize - keyLength;
0577 return this;
0578 }
0579
0580
0581
0582
0583 public boolean nextValue() {
0584 assert isDefined;
0585 long nextAddr = Platform.getLong(baseObject, valueOffset + valueLength);
0586 if (nextAddr == 0) {
0587 return false;
0588 } else {
0589 updateAddressesAndSizes(nextAddr);
0590 return true;
0591 }
0592 }
0593
0594
0595
0596
0597
0598 public MemoryBlock getMemoryPage() {
0599 return this.memoryPage;
0600 }
0601
0602
0603
0604
0605 public boolean isDefined() {
0606 return isDefined;
0607 }
0608
0609
0610
0611
0612 public Object getKeyBase() {
0613 assert (isDefined);
0614 return baseObject;
0615 }
0616
0617
0618
0619
0620 public long getKeyOffset() {
0621 assert (isDefined);
0622 return keyOffset;
0623 }
0624
0625
0626
0627
0628 public Object getValueBase() {
0629 assert (isDefined);
0630 return baseObject;
0631 }
0632
0633
0634
0635
0636 public long getValueOffset() {
0637 assert (isDefined);
0638 return valueOffset;
0639 }
0640
0641
0642
0643
0644
0645 public int getKeyLength() {
0646 assert (isDefined);
0647 return keyLength;
0648 }
0649
0650
0651
0652
0653
0654 public int getValueLength() {
0655 assert (isDefined);
0656 return valueLength;
0657 }
0658
0659
0660
0661
0662
0663
0664
0665
0666
0667
0668
0669
0670
0671
0672
0673
0674
0675
0676
0677
0678
0679
0680
0681
0682
0683
0684
0685
0686
0687
0688
0689
0690
0691 public boolean append(Object kbase, long koff, int klen, Object vbase, long voff, int vlen) {
0692 assert (klen % 8 == 0);
0693 assert (vlen % 8 == 0);
0694 assert (longArray != null);
0695
0696
0697
0698
0699 if (numKeys == MAX_CAPACITY - 1
0700
0701
0702 || !canGrowArray && numKeys >= growthThreshold) {
0703 return false;
0704 }
0705
0706
0707
0708
0709
0710 int uaoSize = UnsafeAlignedOffset.getUaoSize();
0711 final long recordLength = (2L * uaoSize) + klen + vlen + 8;
0712 if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
0713 if (!acquireNewPage(recordLength + uaoSize)) {
0714 return false;
0715 }
0716 }
0717
0718
0719 final Object base = currentPage.getBaseObject();
0720 long offset = currentPage.getBaseOffset() + pageCursor;
0721 final long recordOffset = offset;
0722 UnsafeAlignedOffset.putSize(base, offset, klen + vlen + uaoSize);
0723 UnsafeAlignedOffset.putSize(base, offset + uaoSize, klen);
0724 offset += (2L * uaoSize);
0725 Platform.copyMemory(kbase, koff, base, offset, klen);
0726 offset += klen;
0727 Platform.copyMemory(vbase, voff, base, offset, vlen);
0728 offset += vlen;
0729
0730 Platform.putLong(base, offset, isDefined ? longArray.get(pos * 2) : 0);
0731
0732
0733 offset = currentPage.getBaseOffset();
0734 UnsafeAlignedOffset.putSize(base, offset, UnsafeAlignedOffset.getSize(base, offset) + 1);
0735 pageCursor += recordLength;
0736 final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
0737 currentPage, recordOffset);
0738 longArray.set(pos * 2, storedKeyAddress);
0739 updateAddressesAndSizes(storedKeyAddress);
0740 numValues++;
0741 if (!isDefined) {
0742 numKeys++;
0743 longArray.set(pos * 2 + 1, keyHashcode);
0744 isDefined = true;
0745
0746
0747
0748 if (numKeys >= growthThreshold && longArray.size() / 2 < MAX_CAPACITY) {
0749 try {
0750 growAndRehash();
0751 } catch (SparkOutOfMemoryError oom) {
0752 canGrowArray = false;
0753 }
0754 }
0755 }
0756 return true;
0757 }
0758 }
0759
0760
0761
0762
0763
0764 private boolean acquireNewPage(long required) {
0765 try {
0766 currentPage = allocatePage(required);
0767 } catch (SparkOutOfMemoryError e) {
0768 return false;
0769 }
0770 dataPages.add(currentPage);
0771 UnsafeAlignedOffset.putSize(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0);
0772 pageCursor = UnsafeAlignedOffset.getUaoSize();
0773 return true;
0774 }
0775
0776 @Override
0777 public long spill(long size, MemoryConsumer trigger) throws IOException {
0778 if (trigger != this && destructiveIterator != null) {
0779 return destructiveIterator.spill(size);
0780 }
0781 return 0L;
0782 }
0783
0784
0785
0786
0787
0788
0789
0790 private void allocate(int capacity) {
0791 assert (capacity >= 0);
0792 capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64);
0793 assert (capacity <= MAX_CAPACITY);
0794 longArray = allocateArray(capacity * 2L);
0795 longArray.zeroOut();
0796
0797 this.growthThreshold = (int) (capacity * loadFactor);
0798 this.mask = capacity - 1;
0799 }
0800
0801
0802
0803
0804
0805
0806
0807 public void free() {
0808 updatePeakMemoryUsed();
0809 if (longArray != null) {
0810 freeArray(longArray);
0811 longArray = null;
0812 }
0813 Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
0814 while (dataPagesIterator.hasNext()) {
0815 MemoryBlock dataPage = dataPagesIterator.next();
0816 dataPagesIterator.remove();
0817 freePage(dataPage);
0818 }
0819 assert(dataPages.isEmpty());
0820
0821 while (!spillWriters.isEmpty()) {
0822 File file = spillWriters.removeFirst().getFile();
0823 if (file != null && file.exists()) {
0824 if (!file.delete()) {
0825 logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
0826 }
0827 }
0828 }
0829 }
0830
0831 public TaskMemoryManager getTaskMemoryManager() {
0832 return taskMemoryManager;
0833 }
0834
0835 public long getPageSizeBytes() {
0836 return pageSizeBytes;
0837 }
0838
0839
0840
0841
0842 public long getTotalMemoryConsumption() {
0843 long totalDataPagesSize = 0L;
0844 for (MemoryBlock dataPage : dataPages) {
0845 totalDataPagesSize += dataPage.size();
0846 }
0847 return totalDataPagesSize + ((longArray != null) ? longArray.memoryBlock().size() : 0L);
0848 }
0849
0850 private void updatePeakMemoryUsed() {
0851 long mem = getTotalMemoryConsumption();
0852 if (mem > peakMemoryUsedBytes) {
0853 peakMemoryUsedBytes = mem;
0854 }
0855 }
0856
0857
0858
0859
0860 public long getPeakMemoryUsedBytes() {
0861 updatePeakMemoryUsed();
0862 return peakMemoryUsedBytes;
0863 }
0864
0865
0866
0867
0868 public double getAvgHashProbeBucketListIterations() {
0869 return (1.0 * numProbes) / numKeyLookups;
0870 }
0871
0872 @VisibleForTesting
0873 public int getNumDataPages() {
0874 return dataPages.size();
0875 }
0876
0877
0878
0879
0880 public LongArray getArray() {
0881 assert(longArray != null);
0882 return longArray;
0883 }
0884
0885
0886
0887
0888 public void reset() {
0889 updatePeakMemoryUsed();
0890 numKeys = 0;
0891 numValues = 0;
0892 freeArray(longArray);
0893 longArray = null;
0894 while (dataPages.size() > 0) {
0895 MemoryBlock dataPage = dataPages.removeLast();
0896 freePage(dataPage);
0897 }
0898 allocate(initialCapacity);
0899 canGrowArray = true;
0900 currentPage = null;
0901 pageCursor = 0;
0902 }
0903
0904
0905
0906
0907 @VisibleForTesting
0908 void growAndRehash() {
0909 assert(longArray != null);
0910
0911
0912 final LongArray oldLongArray = longArray;
0913 final int oldCapacity = (int) oldLongArray.size() / 2;
0914
0915
0916 allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
0917
0918
0919 for (int i = 0; i < oldLongArray.size(); i += 2) {
0920 final long keyPointer = oldLongArray.get(i);
0921 if (keyPointer == 0) {
0922 continue;
0923 }
0924 final int hashcode = (int) oldLongArray.get(i + 1);
0925 int newPos = hashcode & mask;
0926 int step = 1;
0927 while (longArray.get(newPos * 2) != 0) {
0928 newPos = (newPos + step) & mask;
0929 step++;
0930 }
0931 longArray.set(newPos * 2, keyPointer);
0932 longArray.set(newPos * 2 + 1, hashcode);
0933 }
0934 freeArray(oldLongArray);
0935 }
0936 }