Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network.server;
0019 
0020 import java.io.IOException;
0021 import java.net.SocketAddress;
0022 import java.nio.ByteBuffer;
0023 
0024 import com.google.common.base.Throwables;
0025 import io.netty.channel.Channel;
0026 import io.netty.channel.ChannelFuture;
0027 
0028 import org.slf4j.Logger;
0029 import org.slf4j.LoggerFactory;
0030 
0031 import org.apache.spark.network.buffer.ManagedBuffer;
0032 import org.apache.spark.network.buffer.NioManagedBuffer;
0033 import org.apache.spark.network.client.*;
0034 import org.apache.spark.network.protocol.*;
0035 import org.apache.spark.network.util.TransportFrameDecoder;
0036 
0037 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
0038 
0039 /**
0040  * A handler that processes requests from clients and writes chunk data back. Each handler is
0041  * attached to a single Netty channel, and keeps track of which streams have been fetched via this
0042  * channel, in order to clean them up if the channel is terminated (see #channelUnregistered).
0043  *
0044  * The messages should have been processed by the pipeline setup by {@link TransportServer}.
0045  */
0046 public class TransportRequestHandler extends MessageHandler<RequestMessage> {
0047 
0048   private static final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class);
0049 
0050   /** The Netty channel that this handler is associated with. */
0051   private final Channel channel;
0052 
0053   /** Client on the same channel allowing us to talk back to the requester. */
0054   private final TransportClient reverseClient;
0055 
0056   /** Handles all RPC messages. */
0057   private final RpcHandler rpcHandler;
0058 
0059   /** Returns each chunk part of a stream. */
0060   private final StreamManager streamManager;
0061 
0062   /** The max number of chunks being transferred and not finished yet. */
0063   private final long maxChunksBeingTransferred;
0064 
0065   /** The dedicated ChannelHandler for ChunkFetchRequest messages. */
0066   private final ChunkFetchRequestHandler chunkFetchRequestHandler;
0067 
0068   public TransportRequestHandler(
0069       Channel channel,
0070       TransportClient reverseClient,
0071       RpcHandler rpcHandler,
0072       Long maxChunksBeingTransferred,
0073       ChunkFetchRequestHandler chunkFetchRequestHandler) {
0074     this.channel = channel;
0075     this.reverseClient = reverseClient;
0076     this.rpcHandler = rpcHandler;
0077     this.streamManager = rpcHandler.getStreamManager();
0078     this.maxChunksBeingTransferred = maxChunksBeingTransferred;
0079     this.chunkFetchRequestHandler = chunkFetchRequestHandler;
0080   }
0081 
0082   @Override
0083   public void exceptionCaught(Throwable cause) {
0084     rpcHandler.exceptionCaught(cause, reverseClient);
0085   }
0086 
0087   @Override
0088   public void channelActive() {
0089     rpcHandler.channelActive(reverseClient);
0090   }
0091 
0092   @Override
0093   public void channelInactive() {
0094     if (streamManager != null) {
0095       try {
0096         streamManager.connectionTerminated(channel);
0097       } catch (RuntimeException e) {
0098         logger.error("StreamManager connectionTerminated() callback failed.", e);
0099       }
0100     }
0101     rpcHandler.channelInactive(reverseClient);
0102   }
0103 
0104   @Override
0105   public void handle(RequestMessage request) throws Exception {
0106     if (request instanceof ChunkFetchRequest) {
0107       chunkFetchRequestHandler.processFetchRequest(channel, (ChunkFetchRequest) request);
0108     } else if (request instanceof RpcRequest) {
0109       processRpcRequest((RpcRequest) request);
0110     } else if (request instanceof OneWayMessage) {
0111       processOneWayMessage((OneWayMessage) request);
0112     } else if (request instanceof StreamRequest) {
0113       processStreamRequest((StreamRequest) request);
0114     } else if (request instanceof UploadStream) {
0115       processStreamUpload((UploadStream) request);
0116     } else {
0117       throw new IllegalArgumentException("Unknown request type: " + request);
0118     }
0119   }
0120 
0121   private void processStreamRequest(final StreamRequest req) {
0122     if (logger.isTraceEnabled()) {
0123       logger.trace("Received req from {} to fetch stream {}", getRemoteAddress(channel),
0124         req.streamId);
0125     }
0126 
0127     long chunksBeingTransferred = streamManager.chunksBeingTransferred();
0128     if (chunksBeingTransferred >= maxChunksBeingTransferred) {
0129       logger.warn("The number of chunks being transferred {} is above {}, close the connection.",
0130         chunksBeingTransferred, maxChunksBeingTransferred);
0131       channel.close();
0132       return;
0133     }
0134     ManagedBuffer buf;
0135     try {
0136       buf = streamManager.openStream(req.streamId);
0137     } catch (Exception e) {
0138       logger.error(String.format(
0139         "Error opening stream %s for request from %s", req.streamId, getRemoteAddress(channel)), e);
0140       respond(new StreamFailure(req.streamId, Throwables.getStackTraceAsString(e)));
0141       return;
0142     }
0143 
0144     if (buf != null) {
0145       streamManager.streamBeingSent(req.streamId);
0146       respond(new StreamResponse(req.streamId, buf.size(), buf)).addListener(future -> {
0147         streamManager.streamSent(req.streamId);
0148       });
0149     } else {
0150       // org.apache.spark.repl.ExecutorClassLoader.STREAM_NOT_FOUND_REGEX should also be updated
0151       // when the following error message is changed.
0152       respond(new StreamFailure(req.streamId, String.format(
0153         "Stream '%s' was not found.", req.streamId)));
0154     }
0155   }
0156 
0157   private void processRpcRequest(final RpcRequest req) {
0158     try {
0159       rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() {
0160         @Override
0161         public void onSuccess(ByteBuffer response) {
0162           respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
0163         }
0164 
0165         @Override
0166         public void onFailure(Throwable e) {
0167           respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
0168         }
0169       });
0170     } catch (Exception e) {
0171       logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
0172       respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
0173     } finally {
0174       req.body().release();
0175     }
0176   }
0177 
0178   /**
0179    * Handle a request from the client to upload a stream of data.
0180    */
0181   private void processStreamUpload(final UploadStream req) {
0182     assert (req.body() == null);
0183     try {
0184       RpcResponseCallback callback = new RpcResponseCallback() {
0185         @Override
0186         public void onSuccess(ByteBuffer response) {
0187           respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
0188         }
0189 
0190         @Override
0191         public void onFailure(Throwable e) {
0192           respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
0193         }
0194       };
0195       TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
0196           channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
0197       ByteBuffer meta = req.meta.nioByteBuffer();
0198       StreamCallbackWithID streamHandler = rpcHandler.receiveStream(reverseClient, meta, callback);
0199       if (streamHandler == null) {
0200         throw new NullPointerException("rpcHandler returned a null streamHandler");
0201       }
0202       StreamCallbackWithID wrappedCallback = new StreamCallbackWithID() {
0203         @Override
0204         public void onData(String streamId, ByteBuffer buf) throws IOException {
0205           streamHandler.onData(streamId, buf);
0206         }
0207 
0208         @Override
0209         public void onComplete(String streamId) throws IOException {
0210            try {
0211              streamHandler.onComplete(streamId);
0212              callback.onSuccess(ByteBuffer.allocate(0));
0213            } catch (Exception ex) {
0214              IOException ioExc = new IOException("Failure post-processing complete stream;" +
0215                " failing this rpc and leaving channel active", ex);
0216              callback.onFailure(ioExc);
0217              streamHandler.onFailure(streamId, ioExc);
0218            }
0219         }
0220 
0221         @Override
0222         public void onFailure(String streamId, Throwable cause) throws IOException {
0223           callback.onFailure(new IOException("Destination failed while reading stream", cause));
0224           streamHandler.onFailure(streamId, cause);
0225         }
0226 
0227         @Override
0228         public String getID() {
0229           return streamHandler.getID();
0230         }
0231       };
0232       if (req.bodyByteCount > 0) {
0233         StreamInterceptor<RequestMessage> interceptor = new StreamInterceptor<>(
0234           this, wrappedCallback.getID(), req.bodyByteCount, wrappedCallback);
0235         frameDecoder.setInterceptor(interceptor);
0236       } else {
0237         wrappedCallback.onComplete(wrappedCallback.getID());
0238       }
0239     } catch (Exception e) {
0240       logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
0241       respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
0242       // We choose to totally fail the channel, rather than trying to recover as we do in other
0243       // cases.  We don't know how many bytes of the stream the client has already sent for the
0244       // stream, it's not worth trying to recover.
0245       channel.pipeline().fireExceptionCaught(e);
0246     } finally {
0247       req.meta.release();
0248     }
0249   }
0250 
0251   private void processOneWayMessage(OneWayMessage req) {
0252     try {
0253       rpcHandler.receive(reverseClient, req.body().nioByteBuffer());
0254     } catch (Exception e) {
0255       logger.error("Error while invoking RpcHandler#receive() for one-way message.", e);
0256     } finally {
0257       req.body().release();
0258     }
0259   }
0260 
0261   /**
0262    * Responds to a single message with some Encodable object. If a failure occurs while sending,
0263    * it will be logged and the channel closed.
0264    */
0265   private ChannelFuture respond(Encodable result) {
0266     SocketAddress remoteAddress = channel.remoteAddress();
0267     return channel.writeAndFlush(result).addListener(future -> {
0268       if (future.isSuccess()) {
0269         logger.trace("Sent result {} to client {}", result, remoteAddress);
0270       } else {
0271         logger.error(String.format("Error sending result %s to %s; closing connection",
0272           result, remoteAddress), future.cause());
0273         channel.close();
0274       }
0275     });
0276   }
0277 }