0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.client;
0019
0020 import java.io.Closeable;
0021 import java.io.IOException;
0022 import java.net.SocketAddress;
0023 import java.nio.ByteBuffer;
0024 import java.util.UUID;
0025 import java.util.concurrent.ExecutionException;
0026 import java.util.concurrent.TimeUnit;
0027 import javax.annotation.Nullable;
0028
0029 import com.google.common.annotations.VisibleForTesting;
0030 import com.google.common.base.Preconditions;
0031 import com.google.common.base.Throwables;
0032 import com.google.common.util.concurrent.SettableFuture;
0033 import io.netty.channel.Channel;
0034 import io.netty.util.concurrent.Future;
0035 import io.netty.util.concurrent.GenericFutureListener;
0036 import org.apache.commons.lang3.builder.ToStringBuilder;
0037 import org.apache.commons.lang3.builder.ToStringStyle;
0038 import org.slf4j.Logger;
0039 import org.slf4j.LoggerFactory;
0040
0041 import org.apache.spark.network.buffer.ManagedBuffer;
0042 import org.apache.spark.network.buffer.NioManagedBuffer;
0043 import org.apache.spark.network.protocol.*;
0044
0045 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067
0068
0069
0070
0071
0072
0073 public class TransportClient implements Closeable {
0074 private static final Logger logger = LoggerFactory.getLogger(TransportClient.class);
0075
0076 private final Channel channel;
0077 private final TransportResponseHandler handler;
0078 @Nullable private String clientId;
0079 private volatile boolean timedOut;
0080
0081 public TransportClient(Channel channel, TransportResponseHandler handler) {
0082 this.channel = Preconditions.checkNotNull(channel);
0083 this.handler = Preconditions.checkNotNull(handler);
0084 this.timedOut = false;
0085 }
0086
0087 public Channel getChannel() {
0088 return channel;
0089 }
0090
0091 public boolean isActive() {
0092 return !timedOut && (channel.isOpen() || channel.isActive());
0093 }
0094
0095 public SocketAddress getSocketAddress() {
0096 return channel.remoteAddress();
0097 }
0098
0099
0100
0101
0102
0103
0104 public String getClientId() {
0105 return clientId;
0106 }
0107
0108
0109
0110
0111
0112
0113 public void setClientId(String id) {
0114 Preconditions.checkState(clientId == null, "Client ID has already been set.");
0115 this.clientId = id;
0116 }
0117
0118
0119
0120
0121
0122
0123
0124
0125
0126
0127
0128
0129
0130
0131
0132
0133 public void fetchChunk(
0134 long streamId,
0135 int chunkIndex,
0136 ChunkReceivedCallback callback) {
0137 if (logger.isDebugEnabled()) {
0138 logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel));
0139 }
0140
0141 StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
0142 StdChannelListener listener = new StdChannelListener(streamChunkId) {
0143 @Override
0144 void handleFailure(String errorMsg, Throwable cause) {
0145 handler.removeFetchRequest(streamChunkId);
0146 callback.onFailure(chunkIndex, new IOException(errorMsg, cause));
0147 }
0148 };
0149 handler.addFetchRequest(streamChunkId, callback);
0150
0151 channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener);
0152 }
0153
0154
0155
0156
0157
0158
0159
0160 public void stream(String streamId, StreamCallback callback) {
0161 StdChannelListener listener = new StdChannelListener(streamId) {
0162 @Override
0163 void handleFailure(String errorMsg, Throwable cause) throws Exception {
0164 callback.onFailure(streamId, new IOException(errorMsg, cause));
0165 }
0166 };
0167 if (logger.isDebugEnabled()) {
0168 logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel));
0169 }
0170
0171
0172
0173
0174 synchronized (this) {
0175 handler.addStreamCallback(streamId, callback);
0176 channel.writeAndFlush(new StreamRequest(streamId)).addListener(listener);
0177 }
0178 }
0179
0180
0181
0182
0183
0184
0185
0186
0187
0188 public long sendRpc(ByteBuffer message, RpcResponseCallback callback) {
0189 if (logger.isTraceEnabled()) {
0190 logger.trace("Sending RPC to {}", getRemoteAddress(channel));
0191 }
0192
0193 long requestId = requestId();
0194 handler.addRpcRequest(requestId, callback);
0195
0196 RpcChannelListener listener = new RpcChannelListener(requestId, callback);
0197 channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message)))
0198 .addListener(listener);
0199
0200 return requestId;
0201 }
0202
0203
0204
0205
0206
0207
0208
0209
0210
0211
0212
0213
0214 public long uploadStream(
0215 ManagedBuffer meta,
0216 ManagedBuffer data,
0217 RpcResponseCallback callback) {
0218 if (logger.isTraceEnabled()) {
0219 logger.trace("Sending RPC to {}", getRemoteAddress(channel));
0220 }
0221
0222 long requestId = requestId();
0223 handler.addRpcRequest(requestId, callback);
0224
0225 RpcChannelListener listener = new RpcChannelListener(requestId, callback);
0226 channel.writeAndFlush(new UploadStream(requestId, meta, data)).addListener(listener);
0227
0228 return requestId;
0229 }
0230
0231
0232
0233
0234
0235 public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
0236 final SettableFuture<ByteBuffer> result = SettableFuture.create();
0237
0238 sendRpc(message, new RpcResponseCallback() {
0239 @Override
0240 public void onSuccess(ByteBuffer response) {
0241 try {
0242 ByteBuffer copy = ByteBuffer.allocate(response.remaining());
0243 copy.put(response);
0244
0245 copy.flip();
0246 result.set(copy);
0247 } catch (Throwable t) {
0248 logger.warn("Error in responding PRC callback", t);
0249 result.setException(t);
0250 }
0251 }
0252
0253 @Override
0254 public void onFailure(Throwable e) {
0255 result.setException(e);
0256 }
0257 });
0258
0259 try {
0260 return result.get(timeoutMs, TimeUnit.MILLISECONDS);
0261 } catch (ExecutionException e) {
0262 throw Throwables.propagate(e.getCause());
0263 } catch (Exception e) {
0264 throw Throwables.propagate(e);
0265 }
0266 }
0267
0268
0269
0270
0271
0272
0273
0274 public void send(ByteBuffer message) {
0275 channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message)));
0276 }
0277
0278
0279
0280
0281
0282
0283 public void removeRpcRequest(long requestId) {
0284 handler.removeRpcRequest(requestId);
0285 }
0286
0287
0288 public void timeOut() {
0289 this.timedOut = true;
0290 }
0291
0292 @VisibleForTesting
0293 public TransportResponseHandler getHandler() {
0294 return handler;
0295 }
0296
0297 @Override
0298 public void close() {
0299
0300 channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
0301 }
0302
0303 @Override
0304 public String toString() {
0305 return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
0306 .append("remoteAdress", channel.remoteAddress())
0307 .append("clientId", clientId)
0308 .append("isActive", isActive())
0309 .toString();
0310 }
0311
0312 private static long requestId() {
0313 return Math.abs(UUID.randomUUID().getLeastSignificantBits());
0314 }
0315
0316 private class StdChannelListener
0317 implements GenericFutureListener<Future<? super Void>> {
0318 final long startTime;
0319 final Object requestId;
0320
0321 StdChannelListener(Object requestId) {
0322 this.startTime = System.currentTimeMillis();
0323 this.requestId = requestId;
0324 }
0325
0326 @Override
0327 public void operationComplete(Future<? super Void> future) throws Exception {
0328 if (future.isSuccess()) {
0329 if (logger.isTraceEnabled()) {
0330 long timeTaken = System.currentTimeMillis() - startTime;
0331 logger.trace("Sending request {} to {} took {} ms", requestId,
0332 getRemoteAddress(channel), timeTaken);
0333 }
0334 } else {
0335 String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
0336 getRemoteAddress(channel), future.cause());
0337 logger.error(errorMsg, future.cause());
0338 channel.close();
0339 try {
0340 handleFailure(errorMsg, future.cause());
0341 } catch (Exception e) {
0342 logger.error("Uncaught exception in RPC response callback handler!", e);
0343 }
0344 }
0345 }
0346
0347 void handleFailure(String errorMsg, Throwable cause) throws Exception {}
0348 }
0349
0350 private class RpcChannelListener extends StdChannelListener {
0351 final long rpcRequestId;
0352 final RpcResponseCallback callback;
0353
0354 RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) {
0355 super("RPC " + rpcRequestId);
0356 this.rpcRequestId = rpcRequestId;
0357 this.callback = callback;
0358 }
0359
0360 @Override
0361 void handleFailure(String errorMsg, Throwable cause) {
0362 handler.removeRpcRequest(rpcRequestId);
0363 callback.onFailure(new IOException(errorMsg, cause));
0364 }
0365 }
0366
0367 }