0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.memory;
0019
0020 import javax.annotation.concurrent.GuardedBy;
0021 import java.io.IOException;
0022 import java.nio.channels.ClosedByInterruptException;
0023 import java.util.Arrays;
0024 import java.util.ArrayList;
0025 import java.util.BitSet;
0026 import java.util.HashSet;
0027 import java.util.List;
0028 import java.util.Map;
0029 import java.util.TreeMap;
0030
0031 import com.google.common.annotations.VisibleForTesting;
0032 import org.slf4j.Logger;
0033 import org.slf4j.LoggerFactory;
0034
0035 import org.apache.spark.unsafe.memory.MemoryBlock;
0036 import org.apache.spark.util.Utils;
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059 public class TaskMemoryManager {
0060
0061 private static final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
0062
0063
0064 private static final int PAGE_NUMBER_BITS = 13;
0065
0066
0067 @VisibleForTesting
0068 static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS;
0069
0070
0071 private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
0072
0073
0074
0075
0076
0077
0078
0079
0080 public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L;
0081
0082
0083 private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
0084
0085
0086
0087
0088
0089
0090
0091
0092
0093 private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE];
0094
0095
0096
0097
0098 private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE);
0099
0100 private final MemoryManager memoryManager;
0101
0102 private final long taskAttemptId;
0103
0104
0105
0106
0107
0108
0109 final MemoryMode tungstenMemoryMode;
0110
0111
0112
0113
0114 @GuardedBy("this")
0115 private final HashSet<MemoryConsumer> consumers;
0116
0117
0118
0119
0120 private volatile long acquiredButNotUsed = 0L;
0121
0122
0123
0124
0125 public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) {
0126 this.tungstenMemoryMode = memoryManager.tungstenMemoryMode();
0127 this.memoryManager = memoryManager;
0128 this.taskAttemptId = taskAttemptId;
0129 this.consumers = new HashSet<>();
0130 }
0131
0132
0133
0134
0135
0136
0137
0138 public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
0139 assert(required >= 0);
0140 assert(consumer != null);
0141 MemoryMode mode = consumer.getMode();
0142
0143
0144
0145
0146 synchronized (this) {
0147 long got = memoryManager.acquireExecutionMemory(required, taskAttemptId, mode);
0148
0149
0150
0151 if (got < required) {
0152
0153
0154
0155
0156 TreeMap<Long, List<MemoryConsumer>> sortedConsumers = new TreeMap<>();
0157 for (MemoryConsumer c: consumers) {
0158 if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) {
0159 long key = c.getUsed();
0160 List<MemoryConsumer> list =
0161 sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1));
0162 list.add(c);
0163 }
0164 }
0165 while (!sortedConsumers.isEmpty()) {
0166
0167 Map.Entry<Long, List<MemoryConsumer>> currentEntry =
0168 sortedConsumers.ceilingEntry(required - got);
0169
0170
0171 if (currentEntry == null) {
0172 currentEntry = sortedConsumers.lastEntry();
0173 }
0174 List<MemoryConsumer> cList = currentEntry.getValue();
0175 MemoryConsumer c = cList.get(cList.size() - 1);
0176 try {
0177 long released = c.spill(required - got, consumer);
0178 if (released > 0) {
0179 logger.debug("Task {} released {} from {} for {}", taskAttemptId,
0180 Utils.bytesToString(released), c, consumer);
0181 got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode);
0182 if (got >= required) {
0183 break;
0184 }
0185 } else {
0186 cList.remove(cList.size() - 1);
0187 if (cList.isEmpty()) {
0188 sortedConsumers.remove(currentEntry.getKey());
0189 }
0190 }
0191 } catch (ClosedByInterruptException e) {
0192
0193 logger.error("error while calling spill() on " + c, e);
0194 throw new RuntimeException(e.getMessage());
0195 } catch (IOException e) {
0196 logger.error("error while calling spill() on " + c, e);
0197
0198 throw new SparkOutOfMemoryError("error while calling spill() on " + c + " : "
0199 + e.getMessage());
0200
0201 }
0202 }
0203 }
0204
0205
0206 if (got < required) {
0207 try {
0208 long released = consumer.spill(required - got, consumer);
0209 if (released > 0) {
0210 logger.debug("Task {} released {} from itself ({})", taskAttemptId,
0211 Utils.bytesToString(released), consumer);
0212 got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode);
0213 }
0214 } catch (ClosedByInterruptException e) {
0215
0216 logger.error("error while calling spill() on " + consumer, e);
0217 throw new RuntimeException(e.getMessage());
0218 } catch (IOException e) {
0219 logger.error("error while calling spill() on " + consumer, e);
0220
0221 throw new SparkOutOfMemoryError("error while calling spill() on " + consumer + " : "
0222 + e.getMessage());
0223
0224 }
0225 }
0226
0227 consumers.add(consumer);
0228 logger.debug("Task {} acquired {} for {}", taskAttemptId, Utils.bytesToString(got), consumer);
0229 return got;
0230 }
0231 }
0232
0233
0234
0235
0236 public void releaseExecutionMemory(long size, MemoryConsumer consumer) {
0237 logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer);
0238 memoryManager.releaseExecutionMemory(size, taskAttemptId, consumer.getMode());
0239 }
0240
0241
0242
0243
0244 public void showMemoryUsage() {
0245 logger.info("Memory used in task " + taskAttemptId);
0246 synchronized (this) {
0247 long memoryAccountedForByConsumers = 0;
0248 for (MemoryConsumer c: consumers) {
0249 long totalMemUsage = c.getUsed();
0250 memoryAccountedForByConsumers += totalMemUsage;
0251 if (totalMemUsage > 0) {
0252 logger.info("Acquired by " + c + ": " + Utils.bytesToString(totalMemUsage));
0253 }
0254 }
0255 long memoryNotAccountedFor =
0256 memoryManager.getExecutionMemoryUsageForTask(taskAttemptId) - memoryAccountedForByConsumers;
0257 logger.info(
0258 "{} bytes of memory were used by task {} but are not associated with specific consumers",
0259 memoryNotAccountedFor, taskAttemptId);
0260 logger.info(
0261 "{} bytes of memory are used for execution and {} bytes of memory are used for storage",
0262 memoryManager.executionMemoryUsed(), memoryManager.storageMemoryUsed());
0263 }
0264 }
0265
0266
0267
0268
0269 public long pageSizeBytes() {
0270 return memoryManager.pageSizeBytes();
0271 }
0272
0273
0274
0275
0276
0277
0278
0279
0280
0281
0282 public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
0283 assert(consumer != null);
0284 assert(consumer.getMode() == tungstenMemoryMode);
0285 if (size > MAXIMUM_PAGE_SIZE_BYTES) {
0286 throw new TooLargePageException(size);
0287 }
0288
0289 long acquired = acquireExecutionMemory(size, consumer);
0290 if (acquired <= 0) {
0291 return null;
0292 }
0293
0294 final int pageNumber;
0295 synchronized (this) {
0296 pageNumber = allocatedPages.nextClearBit(0);
0297 if (pageNumber >= PAGE_TABLE_SIZE) {
0298 releaseExecutionMemory(acquired, consumer);
0299 throw new IllegalStateException(
0300 "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
0301 }
0302 allocatedPages.set(pageNumber);
0303 }
0304 MemoryBlock page = null;
0305 try {
0306 page = memoryManager.tungstenMemoryAllocator().allocate(acquired);
0307 } catch (OutOfMemoryError e) {
0308 logger.warn("Failed to allocate a page ({} bytes), try again.", acquired);
0309
0310
0311 synchronized (this) {
0312 acquiredButNotUsed += acquired;
0313 allocatedPages.clear(pageNumber);
0314 }
0315
0316 return allocatePage(size, consumer);
0317 }
0318 page.pageNumber = pageNumber;
0319 pageTable[pageNumber] = page;
0320 if (logger.isTraceEnabled()) {
0321 logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired);
0322 }
0323 return page;
0324 }
0325
0326
0327
0328
0329 public void freePage(MemoryBlock page, MemoryConsumer consumer) {
0330 assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) :
0331 "Called freePage() on memory that wasn't allocated with allocatePage()";
0332 assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
0333 "Called freePage() on a memory block that has already been freed";
0334 assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) :
0335 "Called freePage() on a memory block that has already been freed";
0336 assert(allocatedPages.get(page.pageNumber));
0337 pageTable[page.pageNumber] = null;
0338 synchronized (this) {
0339 allocatedPages.clear(page.pageNumber);
0340 }
0341 if (logger.isTraceEnabled()) {
0342 logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
0343 }
0344 long pageSize = page.size();
0345
0346
0347
0348 page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
0349 memoryManager.tungstenMemoryAllocator().free(page);
0350 releaseExecutionMemory(pageSize, consumer);
0351 }
0352
0353
0354
0355
0356
0357
0358
0359
0360
0361
0362
0363 public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
0364 if (tungstenMemoryMode == MemoryMode.OFF_HEAP) {
0365
0366
0367
0368 offsetInPage -= page.getBaseOffset();
0369 }
0370 return encodePageNumberAndOffset(page.pageNumber, offsetInPage);
0371 }
0372
0373 @VisibleForTesting
0374 public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
0375 assert (pageNumber >= 0) : "encodePageNumberAndOffset called with invalid page";
0376 return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
0377 }
0378
0379 @VisibleForTesting
0380 public static int decodePageNumber(long pagePlusOffsetAddress) {
0381 return (int) (pagePlusOffsetAddress >>> OFFSET_BITS);
0382 }
0383
0384 private static long decodeOffset(long pagePlusOffsetAddress) {
0385 return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
0386 }
0387
0388
0389
0390
0391
0392 public Object getPage(long pagePlusOffsetAddress) {
0393 if (tungstenMemoryMode == MemoryMode.ON_HEAP) {
0394 final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
0395 assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
0396 final MemoryBlock page = pageTable[pageNumber];
0397 assert (page != null);
0398 assert (page.getBaseObject() != null);
0399 return page.getBaseObject();
0400 } else {
0401 return null;
0402 }
0403 }
0404
0405
0406
0407
0408
0409 public long getOffsetInPage(long pagePlusOffsetAddress) {
0410 final long offsetInPage = decodeOffset(pagePlusOffsetAddress);
0411 if (tungstenMemoryMode == MemoryMode.ON_HEAP) {
0412 return offsetInPage;
0413 } else {
0414
0415
0416 final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
0417 assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
0418 final MemoryBlock page = pageTable[pageNumber];
0419 assert (page != null);
0420 return page.getBaseOffset() + offsetInPage;
0421 }
0422 }
0423
0424
0425
0426
0427
0428 public long cleanUpAllAllocatedMemory() {
0429 synchronized (this) {
0430 for (MemoryConsumer c: consumers) {
0431 if (c != null && c.getUsed() > 0) {
0432
0433 logger.debug("unreleased " + Utils.bytesToString(c.getUsed()) + " memory from " + c);
0434 }
0435 }
0436 consumers.clear();
0437
0438 for (MemoryBlock page : pageTable) {
0439 if (page != null) {
0440 logger.debug("unreleased page: " + page + " in task " + taskAttemptId);
0441 page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
0442 memoryManager.tungstenMemoryAllocator().free(page);
0443 }
0444 }
0445 Arrays.fill(pageTable, null);
0446 }
0447
0448
0449 memoryManager.releaseExecutionMemory(acquiredButNotUsed, taskAttemptId, tungstenMemoryMode);
0450
0451 return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
0452 }
0453
0454
0455
0456
0457 public long getMemoryConsumptionForThisTask() {
0458 return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId);
0459 }
0460
0461
0462
0463
0464 public MemoryMode getTungstenMemoryMode() {
0465 return tungstenMemoryMode;
0466 }
0467 }