Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network.shuffle;
0019 
0020 import java.io.File;
0021 import java.io.IOException;
0022 import java.nio.ByteBuffer;
0023 import java.util.HashMap;
0024 import java.util.Iterator;
0025 import java.util.Map;
0026 import java.util.function.Function;
0027 
0028 import com.codahale.metrics.Gauge;
0029 import com.codahale.metrics.Meter;
0030 import com.codahale.metrics.Metric;
0031 import com.codahale.metrics.MetricSet;
0032 import com.codahale.metrics.Timer;
0033 import com.codahale.metrics.Counter;
0034 import com.google.common.annotations.VisibleForTesting;
0035 import org.slf4j.Logger;
0036 import org.slf4j.LoggerFactory;
0037 
0038 import org.apache.spark.network.buffer.ManagedBuffer;
0039 import org.apache.spark.network.client.RpcResponseCallback;
0040 import org.apache.spark.network.client.TransportClient;
0041 import org.apache.spark.network.server.OneForOneStreamManager;
0042 import org.apache.spark.network.server.RpcHandler;
0043 import org.apache.spark.network.server.StreamManager;
0044 import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
0045 import org.apache.spark.network.shuffle.protocol.*;
0046 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
0047 import org.apache.spark.network.util.TransportConf;
0048 
0049 /**
0050  * RPC Handler for a server which can serve both RDD blocks and shuffle blocks from outside
0051  * of an Executor process.
0052  *
0053  * Handles registering executors and opening shuffle or disk persisted RDD blocks from them.
0054  * Blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk
0055  * is equivalent to one block.
0056  */
0057 public class ExternalBlockHandler extends RpcHandler {
0058   private static final Logger logger = LoggerFactory.getLogger(ExternalBlockHandler.class);
0059 
0060   @VisibleForTesting
0061   final ExternalShuffleBlockResolver blockManager;
0062   private final OneForOneStreamManager streamManager;
0063   private final ShuffleMetrics metrics;
0064 
0065   public ExternalBlockHandler(TransportConf conf, File registeredExecutorFile)
0066     throws IOException {
0067     this(new OneForOneStreamManager(),
0068       new ExternalShuffleBlockResolver(conf, registeredExecutorFile));
0069   }
0070 
0071   @VisibleForTesting
0072   public ExternalShuffleBlockResolver getBlockResolver() {
0073     return blockManager;
0074   }
0075 
0076   /** Enables mocking out the StreamManager and BlockManager. */
0077   @VisibleForTesting
0078   public ExternalBlockHandler(
0079       OneForOneStreamManager streamManager,
0080       ExternalShuffleBlockResolver blockManager) {
0081     this.metrics = new ShuffleMetrics();
0082     this.streamManager = streamManager;
0083     this.blockManager = blockManager;
0084   }
0085 
0086   @Override
0087   public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
0088     BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message);
0089     handleMessage(msgObj, client, callback);
0090   }
0091 
0092   protected void handleMessage(
0093       BlockTransferMessage msgObj,
0094       TransportClient client,
0095       RpcResponseCallback callback) {
0096     if (msgObj instanceof FetchShuffleBlocks || msgObj instanceof OpenBlocks) {
0097       final Timer.Context responseDelayContext = metrics.openBlockRequestLatencyMillis.time();
0098       try {
0099         int numBlockIds;
0100         long streamId;
0101         if (msgObj instanceof FetchShuffleBlocks) {
0102           FetchShuffleBlocks msg = (FetchShuffleBlocks) msgObj;
0103           checkAuth(client, msg.appId);
0104           numBlockIds = 0;
0105           if (msg.batchFetchEnabled) {
0106             numBlockIds = msg.mapIds.length;
0107           } else {
0108             for (int[] ids: msg.reduceIds) {
0109               numBlockIds += ids.length;
0110             }
0111           }
0112           streamId = streamManager.registerStream(client.getClientId(),
0113             new ShuffleManagedBufferIterator(msg), client.getChannel());
0114         } else {
0115           // For the compatibility with the old version, still keep the support for OpenBlocks.
0116           OpenBlocks msg = (OpenBlocks) msgObj;
0117           numBlockIds = msg.blockIds.length;
0118           checkAuth(client, msg.appId);
0119           streamId = streamManager.registerStream(client.getClientId(),
0120             new ManagedBufferIterator(msg), client.getChannel());
0121         }
0122         if (logger.isTraceEnabled()) {
0123           logger.trace(
0124             "Registered streamId {} with {} buffers for client {} from host {}",
0125             streamId,
0126             numBlockIds,
0127             client.getClientId(),
0128             getRemoteAddress(client.getChannel()));
0129         }
0130         callback.onSuccess(new StreamHandle(streamId, numBlockIds).toByteBuffer());
0131       } finally {
0132         responseDelayContext.stop();
0133       }
0134 
0135     } else if (msgObj instanceof RegisterExecutor) {
0136       final Timer.Context responseDelayContext =
0137         metrics.registerExecutorRequestLatencyMillis.time();
0138       try {
0139         RegisterExecutor msg = (RegisterExecutor) msgObj;
0140         checkAuth(client, msg.appId);
0141         blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
0142         callback.onSuccess(ByteBuffer.wrap(new byte[0]));
0143       } finally {
0144         responseDelayContext.stop();
0145       }
0146 
0147     } else if (msgObj instanceof RemoveBlocks) {
0148       RemoveBlocks msg = (RemoveBlocks) msgObj;
0149       checkAuth(client, msg.appId);
0150       int numRemovedBlocks = blockManager.removeBlocks(msg.appId, msg.execId, msg.blockIds);
0151       callback.onSuccess(new BlocksRemoved(numRemovedBlocks).toByteBuffer());
0152 
0153     } else if (msgObj instanceof GetLocalDirsForExecutors) {
0154       GetLocalDirsForExecutors msg = (GetLocalDirsForExecutors) msgObj;
0155       checkAuth(client, msg.appId);
0156       Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, msg.execIds);
0157       callback.onSuccess(new LocalDirsForExecutors(localDirs).toByteBuffer());
0158 
0159     } else {
0160       throw new UnsupportedOperationException("Unexpected message: " + msgObj);
0161     }
0162   }
0163 
0164   @Override
0165   public void exceptionCaught(Throwable cause, TransportClient client) {
0166     metrics.caughtExceptions.inc();
0167   }
0168 
0169   public MetricSet getAllMetrics() {
0170     return metrics;
0171   }
0172 
0173   @Override
0174   public StreamManager getStreamManager() {
0175     return streamManager;
0176   }
0177 
0178   /**
0179    * Removes an application (once it has been terminated), and optionally will clean up any
0180    * local directories associated with the executors of that application in a separate thread.
0181    */
0182   public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
0183     blockManager.applicationRemoved(appId, cleanupLocalDirs);
0184   }
0185 
0186   /**
0187    * Clean up any non-shuffle files in any local directories associated with an finished executor.
0188    */
0189   public void executorRemoved(String executorId, String appId) {
0190     blockManager.executorRemoved(executorId, appId);
0191   }
0192 
0193   /**
0194    * Register an (application, executor) with the given shuffle info.
0195    *
0196    * The "re-" is meant to highlight the intended use of this method -- when this service is
0197    * restarted, this is used to restore the state of executors from before the restart.  Normal
0198    * registration will happen via a message handled in receive()
0199    *
0200    * @param appExecId
0201    * @param executorInfo
0202    */
0203   public void reregisterExecutor(AppExecId appExecId, ExecutorShuffleInfo executorInfo) {
0204     blockManager.registerExecutor(appExecId.appId, appExecId.execId, executorInfo);
0205   }
0206 
0207   public void close() {
0208     blockManager.close();
0209   }
0210 
0211   private void checkAuth(TransportClient client, String appId) {
0212     if (client.getClientId() != null && !client.getClientId().equals(appId)) {
0213       throw new SecurityException(String.format(
0214         "Client for %s not authorized for application %s.", client.getClientId(), appId));
0215     }
0216   }
0217 
0218   /**
0219    * A simple class to wrap all shuffle service wrapper metrics
0220    */
0221   @VisibleForTesting
0222   public class ShuffleMetrics implements MetricSet {
0223     private final Map<String, Metric> allMetrics;
0224     // Time latency for open block request in ms
0225     private final Timer openBlockRequestLatencyMillis = new Timer();
0226     // Time latency for executor registration latency in ms
0227     private final Timer registerExecutorRequestLatencyMillis = new Timer();
0228     // Block transfer rate in byte per second
0229     private final Meter blockTransferRateBytes = new Meter();
0230     // Number of active connections to the shuffle service
0231     private Counter activeConnections = new Counter();
0232     // Number of exceptions caught in connections to the shuffle service
0233     private Counter caughtExceptions = new Counter();
0234 
0235     public ShuffleMetrics() {
0236       allMetrics = new HashMap<>();
0237       allMetrics.put("openBlockRequestLatencyMillis", openBlockRequestLatencyMillis);
0238       allMetrics.put("registerExecutorRequestLatencyMillis", registerExecutorRequestLatencyMillis);
0239       allMetrics.put("blockTransferRateBytes", blockTransferRateBytes);
0240       allMetrics.put("registeredExecutorsSize",
0241                      (Gauge<Integer>) () -> blockManager.getRegisteredExecutorsSize());
0242       allMetrics.put("numActiveConnections", activeConnections);
0243       allMetrics.put("numCaughtExceptions", caughtExceptions);
0244     }
0245 
0246     @Override
0247     public Map<String, Metric> getMetrics() {
0248       return allMetrics;
0249     }
0250   }
0251 
0252   private class ManagedBufferIterator implements Iterator<ManagedBuffer> {
0253 
0254     private int index = 0;
0255     private final Function<Integer, ManagedBuffer> blockDataForIndexFn;
0256     private final int size;
0257 
0258     ManagedBufferIterator(OpenBlocks msg) {
0259       String appId = msg.appId;
0260       String execId = msg.execId;
0261       String[] blockIds = msg.blockIds;
0262       String[] blockId0Parts = blockIds[0].split("_");
0263       if (blockId0Parts.length == 4 && blockId0Parts[0].equals("shuffle")) {
0264         final int shuffleId = Integer.parseInt(blockId0Parts[1]);
0265         final int[] mapIdAndReduceIds = shuffleMapIdAndReduceIds(blockIds, shuffleId);
0266         size = mapIdAndReduceIds.length;
0267         blockDataForIndexFn = index -> blockManager.getBlockData(appId, execId, shuffleId,
0268           mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
0269       } else if (blockId0Parts.length == 3 && blockId0Parts[0].equals("rdd")) {
0270         final int[] rddAndSplitIds = rddAndSplitIds(blockIds);
0271         size = rddAndSplitIds.length;
0272         blockDataForIndexFn = index -> blockManager.getRddBlockData(appId, execId,
0273           rddAndSplitIds[index], rddAndSplitIds[index + 1]);
0274       } else {
0275         throw new IllegalArgumentException("Unexpected block id format: " + blockIds[0]);
0276       }
0277     }
0278 
0279     private int[] rddAndSplitIds(String[] blockIds) {
0280       final int[] rddAndSplitIds = new int[2 * blockIds.length];
0281       for (int i = 0; i < blockIds.length; i++) {
0282         String[] blockIdParts = blockIds[i].split("_");
0283         if (blockIdParts.length != 3 || !blockIdParts[0].equals("rdd")) {
0284           throw new IllegalArgumentException("Unexpected RDD block id format: " + blockIds[i]);
0285         }
0286         rddAndSplitIds[2 * i] = Integer.parseInt(blockIdParts[1]);
0287         rddAndSplitIds[2 * i + 1] = Integer.parseInt(blockIdParts[2]);
0288       }
0289       return rddAndSplitIds;
0290     }
0291 
0292     private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) {
0293       final int[] mapIdAndReduceIds = new int[2 * blockIds.length];
0294       for (int i = 0; i < blockIds.length; i++) {
0295         String[] blockIdParts = blockIds[i].split("_");
0296         if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
0297           throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
0298         }
0299         if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
0300           throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
0301             ", got:" + blockIds[i]);
0302         }
0303         mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
0304         mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
0305       }
0306       return mapIdAndReduceIds;
0307     }
0308 
0309     @Override
0310     public boolean hasNext() {
0311       return index < size;
0312     }
0313 
0314     @Override
0315     public ManagedBuffer next() {
0316       final ManagedBuffer block = blockDataForIndexFn.apply(index);
0317       index += 2;
0318       metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);
0319       return block;
0320     }
0321   }
0322 
0323   private class ShuffleManagedBufferIterator implements Iterator<ManagedBuffer> {
0324 
0325     private int mapIdx = 0;
0326     private int reduceIdx = 0;
0327 
0328     private final String appId;
0329     private final String execId;
0330     private final int shuffleId;
0331     private final long[] mapIds;
0332     private final int[][] reduceIds;
0333     private final boolean batchFetchEnabled;
0334 
0335     ShuffleManagedBufferIterator(FetchShuffleBlocks msg) {
0336       appId = msg.appId;
0337       execId = msg.execId;
0338       shuffleId = msg.shuffleId;
0339       mapIds = msg.mapIds;
0340       reduceIds = msg.reduceIds;
0341       batchFetchEnabled = msg.batchFetchEnabled;
0342     }
0343 
0344     @Override
0345     public boolean hasNext() {
0346       // mapIds.length must equal to reduceIds.length, and the passed in FetchShuffleBlocks
0347       // must have non-empty mapIds and reduceIds, see the checking logic in
0348       // OneForOneBlockFetcher.
0349       assert(mapIds.length != 0 && mapIds.length == reduceIds.length);
0350       return mapIdx < mapIds.length && reduceIdx < reduceIds[mapIdx].length;
0351     }
0352 
0353     @Override
0354     public ManagedBuffer next() {
0355       ManagedBuffer block;
0356       if (!batchFetchEnabled) {
0357         block = blockManager.getBlockData(
0358           appId, execId, shuffleId, mapIds[mapIdx], reduceIds[mapIdx][reduceIdx]);
0359         if (reduceIdx < reduceIds[mapIdx].length - 1) {
0360           reduceIdx += 1;
0361         } else {
0362           reduceIdx = 0;
0363           mapIdx += 1;
0364         }
0365       } else {
0366         assert(reduceIds[mapIdx].length == 2);
0367         block = blockManager.getContinuousBlocksData(appId, execId, shuffleId, mapIds[mapIdx],
0368           reduceIds[mapIdx][0], reduceIds[mapIdx][1]);
0369         mapIdx += 1;
0370       }
0371       metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);
0372       return block;
0373     }
0374   }
0375 
0376   @Override
0377   public void channelActive(TransportClient client) {
0378     metrics.activeConnections.inc();
0379     super.channelActive(client);
0380   }
0381 
0382   @Override
0383   public void channelInactive(TransportClient client) {
0384     metrics.activeConnections.dec();
0385     super.channelInactive(client);
0386   }
0387 
0388 }