Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network.client;
0019 
0020 import java.io.Closeable;
0021 import java.io.IOException;
0022 import java.net.SocketAddress;
0023 import java.nio.ByteBuffer;
0024 import java.util.UUID;
0025 import java.util.concurrent.ExecutionException;
0026 import java.util.concurrent.TimeUnit;
0027 import javax.annotation.Nullable;
0028 
0029 import com.google.common.annotations.VisibleForTesting;
0030 import com.google.common.base.Preconditions;
0031 import com.google.common.base.Throwables;
0032 import com.google.common.util.concurrent.SettableFuture;
0033 import io.netty.channel.Channel;
0034 import io.netty.util.concurrent.Future;
0035 import io.netty.util.concurrent.GenericFutureListener;
0036 import org.apache.commons.lang3.builder.ToStringBuilder;
0037 import org.apache.commons.lang3.builder.ToStringStyle;
0038 import org.slf4j.Logger;
0039 import org.slf4j.LoggerFactory;
0040 
0041 import org.apache.spark.network.buffer.ManagedBuffer;
0042 import org.apache.spark.network.buffer.NioManagedBuffer;
0043 import org.apache.spark.network.protocol.*;
0044 
0045 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
0046 
0047 /**
0048  * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow
0049  * efficient transfer of a large amount of data, broken up into chunks with size ranging from
0050  * hundreds of KB to a few MB.
0051  *
0052  * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane),
0053  * the actual setup of the streams is done outside the scope of the transport layer. The convenience
0054  * method "sendRPC" is provided to enable control plane communication between the client and server
0055  * to perform this setup.
0056  *
0057  * For example, a typical workflow might be:
0058  * client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100
0059  * client.fetchChunk(streamId = 100, chunkIndex = 0, callback)
0060  * client.fetchChunk(streamId = 100, chunkIndex = 1, callback)
0061  * ...
0062  * client.sendRPC(new CloseStream(100))
0063  *
0064  * Construct an instance of TransportClient using {@link TransportClientFactory}. A single
0065  * TransportClient may be used for multiple streams, but any given stream must be restricted to a
0066  * single client, in order to avoid out-of-order responses.
0067  *
0068  * NB: This class is used to make requests to the server, while {@link TransportResponseHandler} is
0069  * responsible for handling responses from the server.
0070  *
0071  * Concurrency: thread safe and can be called from multiple threads.
0072  */
0073 public class TransportClient implements Closeable {
0074   private static final Logger logger = LoggerFactory.getLogger(TransportClient.class);
0075 
0076   private final Channel channel;
0077   private final TransportResponseHandler handler;
0078   @Nullable private String clientId;
0079   private volatile boolean timedOut;
0080 
0081   public TransportClient(Channel channel, TransportResponseHandler handler) {
0082     this.channel = Preconditions.checkNotNull(channel);
0083     this.handler = Preconditions.checkNotNull(handler);
0084     this.timedOut = false;
0085   }
0086 
0087   public Channel getChannel() {
0088     return channel;
0089   }
0090 
0091   public boolean isActive() {
0092     return !timedOut && (channel.isOpen() || channel.isActive());
0093   }
0094 
0095   public SocketAddress getSocketAddress() {
0096     return channel.remoteAddress();
0097   }
0098 
0099   /**
0100    * Returns the ID used by the client to authenticate itself when authentication is enabled.
0101    *
0102    * @return The client ID, or null if authentication is disabled.
0103    */
0104   public String getClientId() {
0105     return clientId;
0106   }
0107 
0108   /**
0109    * Sets the authenticated client ID. This is meant to be used by the authentication layer.
0110    *
0111    * Trying to set a different client ID after it's been set will result in an exception.
0112    */
0113   public void setClientId(String id) {
0114     Preconditions.checkState(clientId == null, "Client ID has already been set.");
0115     this.clientId = id;
0116   }
0117 
0118   /**
0119    * Requests a single chunk from the remote side, from the pre-negotiated streamId.
0120    *
0121    * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
0122    * some streams may not support this.
0123    *
0124    * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed
0125    * to be returned in the same order that they were requested, assuming only a single
0126    * TransportClient is used to fetch the chunks.
0127    *
0128    * @param streamId Identifier that refers to a stream in the remote StreamManager. This should
0129    *                 be agreed upon by client and server beforehand.
0130    * @param chunkIndex 0-based index of the chunk to fetch
0131    * @param callback Callback invoked upon successful receipt of chunk, or upon any failure.
0132    */
0133   public void fetchChunk(
0134       long streamId,
0135       int chunkIndex,
0136       ChunkReceivedCallback callback) {
0137     if (logger.isDebugEnabled()) {
0138       logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel));
0139     }
0140 
0141     StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
0142     StdChannelListener listener = new StdChannelListener(streamChunkId) {
0143       @Override
0144       void handleFailure(String errorMsg, Throwable cause) {
0145         handler.removeFetchRequest(streamChunkId);
0146         callback.onFailure(chunkIndex, new IOException(errorMsg, cause));
0147       }
0148     };
0149     handler.addFetchRequest(streamChunkId, callback);
0150 
0151     channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener);
0152   }
0153 
0154   /**
0155    * Request to stream the data with the given stream ID from the remote end.
0156    *
0157    * @param streamId The stream to fetch.
0158    * @param callback Object to call with the stream data.
0159    */
0160   public void stream(String streamId, StreamCallback callback) {
0161     StdChannelListener listener = new StdChannelListener(streamId) {
0162       @Override
0163       void handleFailure(String errorMsg, Throwable cause) throws Exception {
0164         callback.onFailure(streamId, new IOException(errorMsg, cause));
0165       }
0166     };
0167     if (logger.isDebugEnabled()) {
0168       logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel));
0169     }
0170 
0171     // Need to synchronize here so that the callback is added to the queue and the RPC is
0172     // written to the socket atomically, so that callbacks are called in the right order
0173     // when responses arrive.
0174     synchronized (this) {
0175       handler.addStreamCallback(streamId, callback);
0176       channel.writeAndFlush(new StreamRequest(streamId)).addListener(listener);
0177     }
0178   }
0179 
0180   /**
0181    * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked
0182    * with the server's response or upon any failure.
0183    *
0184    * @param message The message to send.
0185    * @param callback Callback to handle the RPC's reply.
0186    * @return The RPC's id.
0187    */
0188   public long sendRpc(ByteBuffer message, RpcResponseCallback callback) {
0189     if (logger.isTraceEnabled()) {
0190       logger.trace("Sending RPC to {}", getRemoteAddress(channel));
0191     }
0192 
0193     long requestId = requestId();
0194     handler.addRpcRequest(requestId, callback);
0195 
0196     RpcChannelListener listener = new RpcChannelListener(requestId, callback);
0197     channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message)))
0198       .addListener(listener);
0199 
0200     return requestId;
0201   }
0202 
0203   /**
0204    * Send data to the remote end as a stream.  This differs from stream() in that this is a request
0205    * to *send* data to the remote end, not to receive it from the remote.
0206    *
0207    * @param meta meta data associated with the stream, which will be read completely on the
0208    *             receiving end before the stream itself.
0209    * @param data this will be streamed to the remote end to allow for transferring large amounts
0210    *             of data without reading into memory.
0211    * @param callback handles the reply -- onSuccess will only be called when both message and data
0212    *                 are received successfully.
0213    */
0214   public long uploadStream(
0215       ManagedBuffer meta,
0216       ManagedBuffer data,
0217       RpcResponseCallback callback) {
0218     if (logger.isTraceEnabled()) {
0219       logger.trace("Sending RPC to {}", getRemoteAddress(channel));
0220     }
0221 
0222     long requestId = requestId();
0223     handler.addRpcRequest(requestId, callback);
0224 
0225     RpcChannelListener listener = new RpcChannelListener(requestId, callback);
0226     channel.writeAndFlush(new UploadStream(requestId, meta, data)).addListener(listener);
0227 
0228     return requestId;
0229   }
0230 
0231   /**
0232    * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
0233    * a specified timeout for a response.
0234    */
0235   public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
0236     final SettableFuture<ByteBuffer> result = SettableFuture.create();
0237 
0238     sendRpc(message, new RpcResponseCallback() {
0239       @Override
0240       public void onSuccess(ByteBuffer response) {
0241         try {
0242           ByteBuffer copy = ByteBuffer.allocate(response.remaining());
0243           copy.put(response);
0244           // flip "copy" to make it readable
0245           copy.flip();
0246           result.set(copy);
0247         } catch (Throwable t) {
0248           logger.warn("Error in responding PRC callback", t);
0249           result.setException(t);
0250         }
0251       }
0252 
0253       @Override
0254       public void onFailure(Throwable e) {
0255         result.setException(e);
0256       }
0257     });
0258 
0259     try {
0260       return result.get(timeoutMs, TimeUnit.MILLISECONDS);
0261     } catch (ExecutionException e) {
0262       throw Throwables.propagate(e.getCause());
0263     } catch (Exception e) {
0264       throw Throwables.propagate(e);
0265     }
0266   }
0267 
0268   /**
0269    * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the
0270    * message, and no delivery guarantees are made.
0271    *
0272    * @param message The message to send.
0273    */
0274   public void send(ByteBuffer message) {
0275     channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message)));
0276   }
0277 
0278   /**
0279    * Removes any state associated with the given RPC.
0280    *
0281    * @param requestId The RPC id returned by {@link #sendRpc(ByteBuffer, RpcResponseCallback)}.
0282    */
0283   public void removeRpcRequest(long requestId) {
0284     handler.removeRpcRequest(requestId);
0285   }
0286 
0287   /** Mark this channel as having timed out. */
0288   public void timeOut() {
0289     this.timedOut = true;
0290   }
0291 
0292   @VisibleForTesting
0293   public TransportResponseHandler getHandler() {
0294     return handler;
0295   }
0296 
0297   @Override
0298   public void close() {
0299     // close is a local operation and should finish with milliseconds; timeout just to be safe
0300     channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
0301   }
0302 
0303   @Override
0304   public String toString() {
0305     return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
0306       .append("remoteAdress", channel.remoteAddress())
0307       .append("clientId", clientId)
0308       .append("isActive", isActive())
0309       .toString();
0310   }
0311 
0312   private static long requestId() {
0313     return Math.abs(UUID.randomUUID().getLeastSignificantBits());
0314   }
0315 
0316   private class StdChannelListener
0317       implements GenericFutureListener<Future<? super Void>> {
0318     final long startTime;
0319     final Object requestId;
0320 
0321     StdChannelListener(Object requestId) {
0322       this.startTime = System.currentTimeMillis();
0323       this.requestId = requestId;
0324     }
0325 
0326     @Override
0327     public void operationComplete(Future<? super Void> future) throws Exception {
0328       if (future.isSuccess()) {
0329         if (logger.isTraceEnabled()) {
0330           long timeTaken = System.currentTimeMillis() - startTime;
0331           logger.trace("Sending request {} to {} took {} ms", requestId,
0332               getRemoteAddress(channel), timeTaken);
0333         }
0334       } else {
0335         String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
0336             getRemoteAddress(channel), future.cause());
0337         logger.error(errorMsg, future.cause());
0338         channel.close();
0339         try {
0340           handleFailure(errorMsg, future.cause());
0341         } catch (Exception e) {
0342           logger.error("Uncaught exception in RPC response callback handler!", e);
0343         }
0344       }
0345     }
0346 
0347     void handleFailure(String errorMsg, Throwable cause) throws Exception {}
0348   }
0349 
0350   private class RpcChannelListener extends StdChannelListener {
0351     final long rpcRequestId;
0352     final RpcResponseCallback callback;
0353 
0354     RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) {
0355       super("RPC " + rpcRequestId);
0356       this.rpcRequestId = rpcRequestId;
0357       this.callback = callback;
0358     }
0359 
0360     @Override
0361     void handleFailure(String errorMsg, Throwable cause) {
0362       handler.removeRpcRequest(rpcRequestId);
0363       callback.onFailure(new IOException(errorMsg, cause));
0364     }
0365   }
0366 
0367 }