0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.shuffle;
0019
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.function.Function;

import com.codahale.metrics.Counter;
import com.codahale.metrics.Gauge;
import com.codahale.metrics.Meter;
import com.codahale.metrics.Metric;
import com.codahale.metrics.MetricSet;
import com.codahale.metrics.Timer;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
import org.apache.spark.network.shuffle.protocol.*;
import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
import org.apache.spark.network.util.TransportConf;
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057 public class ExternalBlockHandler extends RpcHandler {
0058 private static final Logger logger = LoggerFactory.getLogger(ExternalBlockHandler.class);
0059
0060 @VisibleForTesting
0061 final ExternalShuffleBlockResolver blockManager;
0062 private final OneForOneStreamManager streamManager;
0063 private final ShuffleMetrics metrics;
0064
0065 public ExternalBlockHandler(TransportConf conf, File registeredExecutorFile)
0066 throws IOException {
0067 this(new OneForOneStreamManager(),
0068 new ExternalShuffleBlockResolver(conf, registeredExecutorFile));
0069 }
0070
0071 @VisibleForTesting
0072 public ExternalShuffleBlockResolver getBlockResolver() {
0073 return blockManager;
0074 }
0075
0076
0077 @VisibleForTesting
0078 public ExternalBlockHandler(
0079 OneForOneStreamManager streamManager,
0080 ExternalShuffleBlockResolver blockManager) {
0081 this.metrics = new ShuffleMetrics();
0082 this.streamManager = streamManager;
0083 this.blockManager = blockManager;
0084 }
0085
0086 @Override
0087 public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
0088 BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message);
0089 handleMessage(msgObj, client, callback);
0090 }
0091
0092 protected void handleMessage(
0093 BlockTransferMessage msgObj,
0094 TransportClient client,
0095 RpcResponseCallback callback) {
0096 if (msgObj instanceof FetchShuffleBlocks || msgObj instanceof OpenBlocks) {
0097 final Timer.Context responseDelayContext = metrics.openBlockRequestLatencyMillis.time();
0098 try {
0099 int numBlockIds;
0100 long streamId;
0101 if (msgObj instanceof FetchShuffleBlocks) {
0102 FetchShuffleBlocks msg = (FetchShuffleBlocks) msgObj;
0103 checkAuth(client, msg.appId);
0104 numBlockIds = 0;
0105 if (msg.batchFetchEnabled) {
0106 numBlockIds = msg.mapIds.length;
0107 } else {
0108 for (int[] ids: msg.reduceIds) {
0109 numBlockIds += ids.length;
0110 }
0111 }
0112 streamId = streamManager.registerStream(client.getClientId(),
0113 new ShuffleManagedBufferIterator(msg), client.getChannel());
0114 } else {
0115
0116 OpenBlocks msg = (OpenBlocks) msgObj;
0117 numBlockIds = msg.blockIds.length;
0118 checkAuth(client, msg.appId);
0119 streamId = streamManager.registerStream(client.getClientId(),
0120 new ManagedBufferIterator(msg), client.getChannel());
0121 }
0122 if (logger.isTraceEnabled()) {
0123 logger.trace(
0124 "Registered streamId {} with {} buffers for client {} from host {}",
0125 streamId,
0126 numBlockIds,
0127 client.getClientId(),
0128 getRemoteAddress(client.getChannel()));
0129 }
0130 callback.onSuccess(new StreamHandle(streamId, numBlockIds).toByteBuffer());
0131 } finally {
0132 responseDelayContext.stop();
0133 }
0134
0135 } else if (msgObj instanceof RegisterExecutor) {
0136 final Timer.Context responseDelayContext =
0137 metrics.registerExecutorRequestLatencyMillis.time();
0138 try {
0139 RegisterExecutor msg = (RegisterExecutor) msgObj;
0140 checkAuth(client, msg.appId);
0141 blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
0142 callback.onSuccess(ByteBuffer.wrap(new byte[0]));
0143 } finally {
0144 responseDelayContext.stop();
0145 }
0146
0147 } else if (msgObj instanceof RemoveBlocks) {
0148 RemoveBlocks msg = (RemoveBlocks) msgObj;
0149 checkAuth(client, msg.appId);
0150 int numRemovedBlocks = blockManager.removeBlocks(msg.appId, msg.execId, msg.blockIds);
0151 callback.onSuccess(new BlocksRemoved(numRemovedBlocks).toByteBuffer());
0152
0153 } else if (msgObj instanceof GetLocalDirsForExecutors) {
0154 GetLocalDirsForExecutors msg = (GetLocalDirsForExecutors) msgObj;
0155 checkAuth(client, msg.appId);
0156 Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, msg.execIds);
0157 callback.onSuccess(new LocalDirsForExecutors(localDirs).toByteBuffer());
0158
0159 } else {
0160 throw new UnsupportedOperationException("Unexpected message: " + msgObj);
0161 }
0162 }
0163
0164 @Override
0165 public void exceptionCaught(Throwable cause, TransportClient client) {
0166 metrics.caughtExceptions.inc();
0167 }
0168
0169 public MetricSet getAllMetrics() {
0170 return metrics;
0171 }
0172
0173 @Override
0174 public StreamManager getStreamManager() {
0175 return streamManager;
0176 }
0177
0178
0179
0180
0181
0182 public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
0183 blockManager.applicationRemoved(appId, cleanupLocalDirs);
0184 }
0185
0186
0187
0188
0189 public void executorRemoved(String executorId, String appId) {
0190 blockManager.executorRemoved(executorId, appId);
0191 }
0192
0193
0194
0195
0196
0197
0198
0199
0200
0201
0202
0203 public void reregisterExecutor(AppExecId appExecId, ExecutorShuffleInfo executorInfo) {
0204 blockManager.registerExecutor(appExecId.appId, appExecId.execId, executorInfo);
0205 }
0206
0207 public void close() {
0208 blockManager.close();
0209 }
0210
0211 private void checkAuth(TransportClient client, String appId) {
0212 if (client.getClientId() != null && !client.getClientId().equals(appId)) {
0213 throw new SecurityException(String.format(
0214 "Client for %s not authorized for application %s.", client.getClientId(), appId));
0215 }
0216 }
0217
0218
0219
0220
0221 @VisibleForTesting
0222 public class ShuffleMetrics implements MetricSet {
0223 private final Map<String, Metric> allMetrics;
0224
0225 private final Timer openBlockRequestLatencyMillis = new Timer();
0226
0227 private final Timer registerExecutorRequestLatencyMillis = new Timer();
0228
0229 private final Meter blockTransferRateBytes = new Meter();
0230
0231 private Counter activeConnections = new Counter();
0232
0233 private Counter caughtExceptions = new Counter();
0234
0235 public ShuffleMetrics() {
0236 allMetrics = new HashMap<>();
0237 allMetrics.put("openBlockRequestLatencyMillis", openBlockRequestLatencyMillis);
0238 allMetrics.put("registerExecutorRequestLatencyMillis", registerExecutorRequestLatencyMillis);
0239 allMetrics.put("blockTransferRateBytes", blockTransferRateBytes);
0240 allMetrics.put("registeredExecutorsSize",
0241 (Gauge<Integer>) () -> blockManager.getRegisteredExecutorsSize());
0242 allMetrics.put("numActiveConnections", activeConnections);
0243 allMetrics.put("numCaughtExceptions", caughtExceptions);
0244 }
0245
0246 @Override
0247 public Map<String, Metric> getMetrics() {
0248 return allMetrics;
0249 }
0250 }
0251
0252 private class ManagedBufferIterator implements Iterator<ManagedBuffer> {
0253
0254 private int index = 0;
0255 private final Function<Integer, ManagedBuffer> blockDataForIndexFn;
0256 private final int size;
0257
0258 ManagedBufferIterator(OpenBlocks msg) {
0259 String appId = msg.appId;
0260 String execId = msg.execId;
0261 String[] blockIds = msg.blockIds;
0262 String[] blockId0Parts = blockIds[0].split("_");
0263 if (blockId0Parts.length == 4 && blockId0Parts[0].equals("shuffle")) {
0264 final int shuffleId = Integer.parseInt(blockId0Parts[1]);
0265 final int[] mapIdAndReduceIds = shuffleMapIdAndReduceIds(blockIds, shuffleId);
0266 size = mapIdAndReduceIds.length;
0267 blockDataForIndexFn = index -> blockManager.getBlockData(appId, execId, shuffleId,
0268 mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
0269 } else if (blockId0Parts.length == 3 && blockId0Parts[0].equals("rdd")) {
0270 final int[] rddAndSplitIds = rddAndSplitIds(blockIds);
0271 size = rddAndSplitIds.length;
0272 blockDataForIndexFn = index -> blockManager.getRddBlockData(appId, execId,
0273 rddAndSplitIds[index], rddAndSplitIds[index + 1]);
0274 } else {
0275 throw new IllegalArgumentException("Unexpected block id format: " + blockIds[0]);
0276 }
0277 }
0278
0279 private int[] rddAndSplitIds(String[] blockIds) {
0280 final int[] rddAndSplitIds = new int[2 * blockIds.length];
0281 for (int i = 0; i < blockIds.length; i++) {
0282 String[] blockIdParts = blockIds[i].split("_");
0283 if (blockIdParts.length != 3 || !blockIdParts[0].equals("rdd")) {
0284 throw new IllegalArgumentException("Unexpected RDD block id format: " + blockIds[i]);
0285 }
0286 rddAndSplitIds[2 * i] = Integer.parseInt(blockIdParts[1]);
0287 rddAndSplitIds[2 * i + 1] = Integer.parseInt(blockIdParts[2]);
0288 }
0289 return rddAndSplitIds;
0290 }
0291
0292 private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) {
0293 final int[] mapIdAndReduceIds = new int[2 * blockIds.length];
0294 for (int i = 0; i < blockIds.length; i++) {
0295 String[] blockIdParts = blockIds[i].split("_");
0296 if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
0297 throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
0298 }
0299 if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
0300 throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
0301 ", got:" + blockIds[i]);
0302 }
0303 mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
0304 mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
0305 }
0306 return mapIdAndReduceIds;
0307 }
0308
0309 @Override
0310 public boolean hasNext() {
0311 return index < size;
0312 }
0313
0314 @Override
0315 public ManagedBuffer next() {
0316 final ManagedBuffer block = blockDataForIndexFn.apply(index);
0317 index += 2;
0318 metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);
0319 return block;
0320 }
0321 }
0322
0323 private class ShuffleManagedBufferIterator implements Iterator<ManagedBuffer> {
0324
0325 private int mapIdx = 0;
0326 private int reduceIdx = 0;
0327
0328 private final String appId;
0329 private final String execId;
0330 private final int shuffleId;
0331 private final long[] mapIds;
0332 private final int[][] reduceIds;
0333 private final boolean batchFetchEnabled;
0334
0335 ShuffleManagedBufferIterator(FetchShuffleBlocks msg) {
0336 appId = msg.appId;
0337 execId = msg.execId;
0338 shuffleId = msg.shuffleId;
0339 mapIds = msg.mapIds;
0340 reduceIds = msg.reduceIds;
0341 batchFetchEnabled = msg.batchFetchEnabled;
0342 }
0343
0344 @Override
0345 public boolean hasNext() {
0346
0347
0348
0349 assert(mapIds.length != 0 && mapIds.length == reduceIds.length);
0350 return mapIdx < mapIds.length && reduceIdx < reduceIds[mapIdx].length;
0351 }
0352
0353 @Override
0354 public ManagedBuffer next() {
0355 ManagedBuffer block;
0356 if (!batchFetchEnabled) {
0357 block = blockManager.getBlockData(
0358 appId, execId, shuffleId, mapIds[mapIdx], reduceIds[mapIdx][reduceIdx]);
0359 if (reduceIdx < reduceIds[mapIdx].length - 1) {
0360 reduceIdx += 1;
0361 } else {
0362 reduceIdx = 0;
0363 mapIdx += 1;
0364 }
0365 } else {
0366 assert(reduceIds[mapIdx].length == 2);
0367 block = blockManager.getContinuousBlocksData(appId, execId, shuffleId, mapIds[mapIdx],
0368 reduceIds[mapIdx][0], reduceIds[mapIdx][1]);
0369 mapIdx += 1;
0370 }
0371 metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);
0372 return block;
0373 }
0374 }
0375
0376 @Override
0377 public void channelActive(TransportClient client) {
0378 metrics.activeConnections.inc();
0379 super.channelActive(client);
0380 }
0381
0382 @Override
0383 public void channelInactive(TransportClient client) {
0384 metrics.activeConnections.dec();
0385 super.channelInactive(client);
0386 }
0387
0388 }