0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.server;
0019
0020 import java.io.IOException;
0021 import java.net.SocketAddress;
0022 import java.nio.ByteBuffer;
0023
0024 import com.google.common.base.Throwables;
0025 import io.netty.channel.Channel;
0026 import io.netty.channel.ChannelFuture;
0027
0028 import org.slf4j.Logger;
0029 import org.slf4j.LoggerFactory;
0030
0031 import org.apache.spark.network.buffer.ManagedBuffer;
0032 import org.apache.spark.network.buffer.NioManagedBuffer;
0033 import org.apache.spark.network.client.*;
0034 import org.apache.spark.network.protocol.*;
0035 import org.apache.spark.network.util.TransportFrameDecoder;
0036
0037 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
0038
0039
0040
0041
0042
0043
0044
0045
0046 public class TransportRequestHandler extends MessageHandler<RequestMessage> {
0047
0048 private static final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class);
0049
0050
0051 private final Channel channel;
0052
0053
0054 private final TransportClient reverseClient;
0055
0056
0057 private final RpcHandler rpcHandler;
0058
0059
0060 private final StreamManager streamManager;
0061
0062
0063 private final long maxChunksBeingTransferred;
0064
0065
0066 private final ChunkFetchRequestHandler chunkFetchRequestHandler;
0067
0068 public TransportRequestHandler(
0069 Channel channel,
0070 TransportClient reverseClient,
0071 RpcHandler rpcHandler,
0072 Long maxChunksBeingTransferred,
0073 ChunkFetchRequestHandler chunkFetchRequestHandler) {
0074 this.channel = channel;
0075 this.reverseClient = reverseClient;
0076 this.rpcHandler = rpcHandler;
0077 this.streamManager = rpcHandler.getStreamManager();
0078 this.maxChunksBeingTransferred = maxChunksBeingTransferred;
0079 this.chunkFetchRequestHandler = chunkFetchRequestHandler;
0080 }
0081
0082 @Override
0083 public void exceptionCaught(Throwable cause) {
0084 rpcHandler.exceptionCaught(cause, reverseClient);
0085 }
0086
0087 @Override
0088 public void channelActive() {
0089 rpcHandler.channelActive(reverseClient);
0090 }
0091
0092 @Override
0093 public void channelInactive() {
0094 if (streamManager != null) {
0095 try {
0096 streamManager.connectionTerminated(channel);
0097 } catch (RuntimeException e) {
0098 logger.error("StreamManager connectionTerminated() callback failed.", e);
0099 }
0100 }
0101 rpcHandler.channelInactive(reverseClient);
0102 }
0103
0104 @Override
0105 public void handle(RequestMessage request) throws Exception {
0106 if (request instanceof ChunkFetchRequest) {
0107 chunkFetchRequestHandler.processFetchRequest(channel, (ChunkFetchRequest) request);
0108 } else if (request instanceof RpcRequest) {
0109 processRpcRequest((RpcRequest) request);
0110 } else if (request instanceof OneWayMessage) {
0111 processOneWayMessage((OneWayMessage) request);
0112 } else if (request instanceof StreamRequest) {
0113 processStreamRequest((StreamRequest) request);
0114 } else if (request instanceof UploadStream) {
0115 processStreamUpload((UploadStream) request);
0116 } else {
0117 throw new IllegalArgumentException("Unknown request type: " + request);
0118 }
0119 }
0120
0121 private void processStreamRequest(final StreamRequest req) {
0122 if (logger.isTraceEnabled()) {
0123 logger.trace("Received req from {} to fetch stream {}", getRemoteAddress(channel),
0124 req.streamId);
0125 }
0126
0127 long chunksBeingTransferred = streamManager.chunksBeingTransferred();
0128 if (chunksBeingTransferred >= maxChunksBeingTransferred) {
0129 logger.warn("The number of chunks being transferred {} is above {}, close the connection.",
0130 chunksBeingTransferred, maxChunksBeingTransferred);
0131 channel.close();
0132 return;
0133 }
0134 ManagedBuffer buf;
0135 try {
0136 buf = streamManager.openStream(req.streamId);
0137 } catch (Exception e) {
0138 logger.error(String.format(
0139 "Error opening stream %s for request from %s", req.streamId, getRemoteAddress(channel)), e);
0140 respond(new StreamFailure(req.streamId, Throwables.getStackTraceAsString(e)));
0141 return;
0142 }
0143
0144 if (buf != null) {
0145 streamManager.streamBeingSent(req.streamId);
0146 respond(new StreamResponse(req.streamId, buf.size(), buf)).addListener(future -> {
0147 streamManager.streamSent(req.streamId);
0148 });
0149 } else {
0150
0151
0152 respond(new StreamFailure(req.streamId, String.format(
0153 "Stream '%s' was not found.", req.streamId)));
0154 }
0155 }
0156
0157 private void processRpcRequest(final RpcRequest req) {
0158 try {
0159 rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() {
0160 @Override
0161 public void onSuccess(ByteBuffer response) {
0162 respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
0163 }
0164
0165 @Override
0166 public void onFailure(Throwable e) {
0167 respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
0168 }
0169 });
0170 } catch (Exception e) {
0171 logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
0172 respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
0173 } finally {
0174 req.body().release();
0175 }
0176 }
0177
0178
0179
0180
0181 private void processStreamUpload(final UploadStream req) {
0182 assert (req.body() == null);
0183 try {
0184 RpcResponseCallback callback = new RpcResponseCallback() {
0185 @Override
0186 public void onSuccess(ByteBuffer response) {
0187 respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
0188 }
0189
0190 @Override
0191 public void onFailure(Throwable e) {
0192 respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
0193 }
0194 };
0195 TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
0196 channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
0197 ByteBuffer meta = req.meta.nioByteBuffer();
0198 StreamCallbackWithID streamHandler = rpcHandler.receiveStream(reverseClient, meta, callback);
0199 if (streamHandler == null) {
0200 throw new NullPointerException("rpcHandler returned a null streamHandler");
0201 }
0202 StreamCallbackWithID wrappedCallback = new StreamCallbackWithID() {
0203 @Override
0204 public void onData(String streamId, ByteBuffer buf) throws IOException {
0205 streamHandler.onData(streamId, buf);
0206 }
0207
0208 @Override
0209 public void onComplete(String streamId) throws IOException {
0210 try {
0211 streamHandler.onComplete(streamId);
0212 callback.onSuccess(ByteBuffer.allocate(0));
0213 } catch (Exception ex) {
0214 IOException ioExc = new IOException("Failure post-processing complete stream;" +
0215 " failing this rpc and leaving channel active", ex);
0216 callback.onFailure(ioExc);
0217 streamHandler.onFailure(streamId, ioExc);
0218 }
0219 }
0220
0221 @Override
0222 public void onFailure(String streamId, Throwable cause) throws IOException {
0223 callback.onFailure(new IOException("Destination failed while reading stream", cause));
0224 streamHandler.onFailure(streamId, cause);
0225 }
0226
0227 @Override
0228 public String getID() {
0229 return streamHandler.getID();
0230 }
0231 };
0232 if (req.bodyByteCount > 0) {
0233 StreamInterceptor<RequestMessage> interceptor = new StreamInterceptor<>(
0234 this, wrappedCallback.getID(), req.bodyByteCount, wrappedCallback);
0235 frameDecoder.setInterceptor(interceptor);
0236 } else {
0237 wrappedCallback.onComplete(wrappedCallback.getID());
0238 }
0239 } catch (Exception e) {
0240 logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
0241 respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
0242
0243
0244
0245 channel.pipeline().fireExceptionCaught(e);
0246 } finally {
0247 req.meta.release();
0248 }
0249 }
0250
0251 private void processOneWayMessage(OneWayMessage req) {
0252 try {
0253 rpcHandler.receive(reverseClient, req.body().nioByteBuffer());
0254 } catch (Exception e) {
0255 logger.error("Error while invoking RpcHandler#receive() for one-way message.", e);
0256 } finally {
0257 req.body().release();
0258 }
0259 }
0260
0261
0262
0263
0264
0265 private ChannelFuture respond(Encodable result) {
0266 SocketAddress remoteAddress = channel.remoteAddress();
0267 return channel.writeAndFlush(result).addListener(future -> {
0268 if (future.isSuccess()) {
0269 logger.trace("Sent result {} to client {}", result, remoteAddress);
0270 } else {
0271 logger.error(String.format("Error sending result %s to %s; closing connection",
0272 result, remoteAddress), future.cause());
0273 channel.close();
0274 }
0275 });
0276 }
0277 }