/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.client;

import java.io.IOException;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;

import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.Channel;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.ResponseMessage;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.protocol.StreamFailure;
import org.apache.spark.network.protocol.StreamResponse;
import org.apache.spark.network.server.MessageHandler;
import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
import org.apache.spark.network.util.TransportFrameDecoder;

/**
 * Handler that processes server responses, in response to requests issued from a
 * {@link TransportClient}. It works by tracking the list of outstanding requests (and their
 * callbacks).
 *
 * Concurrency: thread safe and can be called from multiple threads.
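 *
 * <p>Illustrative usage sketch (simplified; requestId generation and the channel write are
 * elided): the caller registers a callback before the request is written, and this handler
 * fires that callback when the matching response arrives.
 *
 * <pre>{@code
 * handler.addRpcRequest(requestId, new RpcResponseCallback() {
 *   public void onSuccess(ByteBuffer response) {
 *     // decode the response payload
 *   }
 *   public void onFailure(Throwable e) {
 *     // surface the error to the caller
 *   }
 * });
 * // When the matching RpcResponse (or RpcFailure) is handled, exactly one of the callbacks
 * // fires and the entry is removed from the outstanding-RPC map.
 * }</pre>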
 */
public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
  private static final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class);

  private final Channel channel;

  private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;

  private final Map<Long, RpcResponseCallback> outstandingRpcs;

  private final Queue<Pair<String, StreamCallback>> streamCallbacks;
  private volatile boolean streamActive;

  /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
  private final AtomicLong timeOfLastRequestNs;

  public TransportResponseHandler(Channel channel) {
    this.channel = channel;
    this.outstandingFetches = new ConcurrentHashMap<>();
    this.outstandingRpcs = new ConcurrentHashMap<>();
    this.streamCallbacks = new ConcurrentLinkedQueue<>();
    this.timeOfLastRequestNs = new AtomicLong(0);
  }

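  /** Registers an outstanding chunk fetch so its callback can be fired when the response arrives. */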
  public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
    updateTimeOfLastRequest();
    outstandingFetches.put(streamChunkId, callback);
  }

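  /** Removes the callback registered for the given chunk fetch, if any. */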
  public void removeFetchRequest(StreamChunkId streamChunkId) {
    outstandingFetches.remove(streamChunkId);
  }

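  /** Registers an outstanding RPC so its callback can be fired when the response arrives. */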
  public void addRpcRequest(long requestId, RpcResponseCallback callback) {
    updateTimeOfLastRequest();
    outstandingRpcs.put(requestId, callback);
  }

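  /** Removes the callback registered for the given RPC, if any. */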
  public void removeRpcRequest(long requestId) {
    outstandingRpcs.remove(requestId);
  }

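  /** Queues a callback for a stream; stream responses are matched to callbacks in FIFO order. */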
  public void addStreamCallback(String streamId, StreamCallback callback) {
    updateTimeOfLastRequest();
    streamCallbacks.offer(ImmutablePair.of(streamId, callback));
  }

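  /** Marks the active stream as finished so it is no longer counted as an outstanding request. */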
  @VisibleForTesting
  public void deactivateStream() {
    streamActive = false;
  }

  /**
   * Fire the failure callback for all outstanding requests. This is called when we have an
   * uncaught exception or premature connection termination.
   */
  private void failOutstandingRequests(Throwable cause) {
    for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
      try {
        entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
      } catch (Exception e) {
        logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
      }
    }
    for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
      try {
        entry.getValue().onFailure(cause);
      } catch (Exception e) {
        logger.warn("RpcResponseCallback.onFailure throws exception", e);
      }
    }
    for (Pair<String, StreamCallback> entry : streamCallbacks) {
      try {
        entry.getValue().onFailure(entry.getKey(), cause);
      } catch (Exception e) {
        logger.warn("StreamCallback.onFailure throws exception", e);
      }
    }

    // It's OK if new fetches appear, as they will fail immediately.
    outstandingFetches.clear();
    outstandingRpcs.clear();
    streamCallbacks.clear();
  }

  @Override
  public void channelActive() {
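    // No-op: there is no per-connection state to initialize when the channel becomes active.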
  }

  @Override
  public void channelInactive() {
    if (numOutstandingRequests() > 0) {
      String remoteAddress = getRemoteAddress(channel);
      logger.error("Still have {} requests outstanding when connection from {} is closed",
        numOutstandingRequests(), remoteAddress);
      failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
    }
  }

  @Override
  public void exceptionCaught(Throwable cause) {
    if (numOutstandingRequests() > 0) {
      String remoteAddress = getRemoteAddress(channel);
      logger.error("Still have {} requests outstanding when connection from {} is closed",
        numOutstandingRequests(), remoteAddress);
      failOutstandingRequests(cause);
    }
  }

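  /**
   * Dispatches a single server response to the callback registered for the corresponding
   * request (if any), releasing or intercepting the message body as appropriate for its type.
   */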
  @Override
  public void handle(ResponseMessage message) throws Exception {
    if (message instanceof ChunkFetchSuccess) {
      ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
      ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
      if (listener == null) {
        logger.warn("Ignoring response for block {} from {} since it is not outstanding",
          resp.streamChunkId, getRemoteAddress(channel));
        resp.body().release();
      } else {
        outstandingFetches.remove(resp.streamChunkId);
        listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
        resp.body().release();
      }
    } else if (message instanceof ChunkFetchFailure) {
      ChunkFetchFailure resp = (ChunkFetchFailure) message;
      ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
      if (listener == null) {
        logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding",
          resp.streamChunkId, getRemoteAddress(channel), resp.errorString);
      } else {
        outstandingFetches.remove(resp.streamChunkId);
        listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException(
          "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString));
      }
    } else if (message instanceof RpcResponse) {
      RpcResponse resp = (RpcResponse) message;
      RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
      if (listener == null) {
        logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
          resp.requestId, getRemoteAddress(channel), resp.body().size());
      } else {
        outstandingRpcs.remove(resp.requestId);
        try {
          listener.onSuccess(resp.body().nioByteBuffer());
        } finally {
          resp.body().release();
        }
      }
    } else if (message instanceof RpcFailure) {
      RpcFailure resp = (RpcFailure) message;
      RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
      if (listener == null) {
        logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
          resp.requestId, getRemoteAddress(channel), resp.errorString);
      } else {
        outstandingRpcs.remove(resp.requestId);
        listener.onFailure(new RuntimeException(resp.errorString));
      }
    } else if (message instanceof StreamResponse) {
      StreamResponse resp = (StreamResponse) message;
      Pair<String, StreamCallback> entry = streamCallbacks.poll();
      if (entry != null) {
        StreamCallback callback = entry.getValue();
        if (resp.byteCount > 0) {
          StreamInterceptor<ResponseMessage> interceptor = new StreamInterceptor<>(
            this, resp.streamId, resp.byteCount, callback);
          try {
            TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
              channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
            frameDecoder.setInterceptor(interceptor);
            streamActive = true;
          } catch (Exception e) {
            logger.error("Error installing stream handler.", e);
            deactivateStream();
          }
        } else {
          try {
            callback.onComplete(resp.streamId);
          } catch (Exception e) {
            logger.warn("Error in stream handler onComplete().", e);
          }
        }
      } else {
        logger.error("Could not find callback for StreamResponse.");
      }
    } else if (message instanceof StreamFailure) {
      StreamFailure resp = (StreamFailure) message;
      Pair<String, StreamCallback> entry = streamCallbacks.poll();
      if (entry != null) {
        StreamCallback callback = entry.getValue();
        try {
          callback.onFailure(resp.streamId, new RuntimeException(resp.error));
        } catch (IOException ioe) {
          logger.warn("Error in stream failure handler.", ioe);
        }
      } else {
        logger.warn("Stream failure with unknown callback: {}", resp.error);
      }
    } else {
      throw new IllegalStateException("Unknown response type: " + message.type());
    }
  }

  /** Returns total number of outstanding requests (fetch requests + RPCs + uncompleted streams). */
  public int numOutstandingRequests() {
    return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() +
      (streamActive ? 1 : 0);
  }

  /** Returns the time (in system nanoseconds) at which the last request was sent. */
  public long getTimeOfLastRequestNs() {
    return timeOfLastRequestNs.get();
  }

  /** Updates the time of the last request to the current system time. */
  public void updateTimeOfLastRequest() {
    timeOfLastRequestNs.set(System.nanoTime());
  }

}