0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.shuffle;
0019
0020 import java.io.IOException;
0021 import java.nio.ByteBuffer;
0022 import java.util.Arrays;
0023 import java.util.List;
0024 import java.util.Map;
0025 import java.util.concurrent.CompletableFuture;
0026 import java.util.concurrent.Future;
0027
0028 import com.codahale.metrics.MetricSet;
0029 import com.google.common.collect.Lists;
0030 import org.apache.spark.network.client.RpcResponseCallback;
0031 import org.apache.spark.network.client.TransportClient;
0032 import org.apache.spark.network.client.TransportClientBootstrap;
0033 import org.apache.spark.network.client.TransportClientFactory;
0034 import org.apache.spark.network.shuffle.protocol.*;
0035 import org.slf4j.Logger;
0036 import org.slf4j.LoggerFactory;
0037
0038 import org.apache.spark.network.TransportContext;
0039 import org.apache.spark.network.crypto.AuthClientBootstrap;
0040 import org.apache.spark.network.sasl.SecretKeyHolder;
0041 import org.apache.spark.network.server.NoOpRpcHandler;
0042 import org.apache.spark.network.util.TransportConf;
0043
0044
0045
0046
0047
0048
0049 public class ExternalBlockStoreClient extends BlockStoreClient {
0050 private static final Logger logger = LoggerFactory.getLogger(ExternalBlockStoreClient.class);
0051
0052 private final TransportConf conf;
0053 private final boolean authEnabled;
0054 private final SecretKeyHolder secretKeyHolder;
0055 private final long registrationTimeoutMs;
0056
0057 protected volatile TransportClientFactory clientFactory;
0058 protected String appId;
0059
0060
0061
0062
0063
0064 public ExternalBlockStoreClient(
0065 TransportConf conf,
0066 SecretKeyHolder secretKeyHolder,
0067 boolean authEnabled,
0068 long registrationTimeoutMs) {
0069 this.conf = conf;
0070 this.secretKeyHolder = secretKeyHolder;
0071 this.authEnabled = authEnabled;
0072 this.registrationTimeoutMs = registrationTimeoutMs;
0073 }
0074
0075 protected void checkInit() {
0076 assert appId != null : "Called before init()";
0077 }
0078
0079
0080
0081
0082
0083 public void init(String appId) {
0084 this.appId = appId;
0085 TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true, true);
0086 List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
0087 if (authEnabled) {
0088 bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder));
0089 }
0090 clientFactory = context.createClientFactory(bootstraps);
0091 }
0092
0093 @Override
0094 public void fetchBlocks(
0095 String host,
0096 int port,
0097 String execId,
0098 String[] blockIds,
0099 BlockFetchingListener listener,
0100 DownloadFileManager downloadFileManager) {
0101 checkInit();
0102 logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
0103 try {
0104 RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
0105 (blockIds1, listener1) -> {
0106
0107 if (clientFactory != null) {
0108 TransportClient client = clientFactory.createClient(host, port);
0109 new OneForOneBlockFetcher(client, appId, execId,
0110 blockIds1, listener1, conf, downloadFileManager).start();
0111 } else {
0112 logger.info("This clientFactory was closed. Skipping further block fetch retries.");
0113 }
0114 };
0115
0116 int maxRetries = conf.maxIORetries();
0117 if (maxRetries > 0) {
0118
0119
0120 new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start();
0121 } else {
0122 blockFetchStarter.createAndStart(blockIds, listener);
0123 }
0124 } catch (Exception e) {
0125 logger.error("Exception while beginning fetchBlocks", e);
0126 for (String blockId : blockIds) {
0127 listener.onBlockFetchFailure(blockId, e);
0128 }
0129 }
0130 }
0131
0132 @Override
0133 public MetricSet shuffleMetrics() {
0134 checkInit();
0135 return clientFactory.getAllMetrics();
0136 }
0137
0138
0139
0140
0141
0142
0143
0144
0145
0146
0147 public void registerWithShuffleServer(
0148 String host,
0149 int port,
0150 String execId,
0151 ExecutorShuffleInfo executorInfo) throws IOException, InterruptedException {
0152 checkInit();
0153 try (TransportClient client = clientFactory.createClient(host, port)) {
0154 ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer();
0155 client.sendRpcSync(registerMessage, registrationTimeoutMs);
0156 }
0157 }
0158
0159 public Future<Integer> removeBlocks(
0160 String host,
0161 int port,
0162 String execId,
0163 String[] blockIds) throws IOException, InterruptedException {
0164 checkInit();
0165 CompletableFuture<Integer> numRemovedBlocksFuture = new CompletableFuture<>();
0166 ByteBuffer removeBlocksMessage = new RemoveBlocks(appId, execId, blockIds).toByteBuffer();
0167 final TransportClient client = clientFactory.createClient(host, port);
0168 client.sendRpc(removeBlocksMessage, new RpcResponseCallback() {
0169 @Override
0170 public void onSuccess(ByteBuffer response) {
0171 try {
0172 BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(response);
0173 numRemovedBlocksFuture.complete(((BlocksRemoved) msgObj).numRemovedBlocks);
0174 } catch (Throwable t) {
0175 logger.warn("Error trying to remove RDD blocks " + Arrays.toString(blockIds) +
0176 " via external shuffle service from executor: " + execId, t);
0177 numRemovedBlocksFuture.complete(0);
0178 } finally {
0179 client.close();
0180 }
0181 }
0182
0183 @Override
0184 public void onFailure(Throwable e) {
0185 logger.warn("Error trying to remove RDD blocks " + Arrays.toString(blockIds) +
0186 " via external shuffle service from executor: " + execId, e);
0187 numRemovedBlocksFuture.complete(0);
0188 client.close();
0189 }
0190 });
0191 return numRemovedBlocksFuture;
0192 }
0193
0194 public void getHostLocalDirs(
0195 String host,
0196 int port,
0197 String[] execIds,
0198 CompletableFuture<Map<String, String[]>> hostLocalDirsCompletable) {
0199 checkInit();
0200 GetLocalDirsForExecutors getLocalDirsMessage = new GetLocalDirsForExecutors(appId, execIds);
0201 try {
0202 TransportClient client = clientFactory.createClient(host, port);
0203 client.sendRpc(getLocalDirsMessage.toByteBuffer(), new RpcResponseCallback() {
0204 @Override
0205 public void onSuccess(ByteBuffer response) {
0206 try {
0207 BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(response);
0208 hostLocalDirsCompletable.complete(
0209 ((LocalDirsForExecutors) msgObj).getLocalDirsByExec());
0210 } catch (Throwable t) {
0211 logger.warn("Error trying to get the host local dirs for " +
0212 Arrays.toString(getLocalDirsMessage.execIds) + " via external shuffle service",
0213 t.getCause());
0214 hostLocalDirsCompletable.completeExceptionally(t);
0215 } finally {
0216 client.close();
0217 }
0218 }
0219
0220 @Override
0221 public void onFailure(Throwable t) {
0222 logger.warn("Error trying to get the host local dirs for " +
0223 Arrays.toString(getLocalDirsMessage.execIds) + " via external shuffle service",
0224 t.getCause());
0225 hostLocalDirsCompletable.completeExceptionally(t);
0226 client.close();
0227 }
0228 });
0229 } catch (IOException | InterruptedException e) {
0230 hostLocalDirsCompletable.completeExceptionally(e);
0231 }
0232 }
0233
0234 @Override
0235 public void close() {
0236 checkInit();
0237 if (clientFactory != null) {
0238 clientFactory.close();
0239 clientFactory = null;
0240 }
0241 }
0242 }