0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.client;
0019
0020 import java.io.IOException;
0021 import java.util.Map;
0022 import java.util.Queue;
0023 import java.util.concurrent.ConcurrentHashMap;
0024 import java.util.concurrent.ConcurrentLinkedQueue;
0025 import java.util.concurrent.atomic.AtomicLong;
0026
0027 import com.google.common.annotations.VisibleForTesting;
0028 import io.netty.channel.Channel;
0029 import org.apache.commons.lang3.tuple.ImmutablePair;
0030 import org.apache.commons.lang3.tuple.Pair;
0031 import org.slf4j.Logger;
0032 import org.slf4j.LoggerFactory;
0033
0034 import org.apache.spark.network.protocol.ChunkFetchFailure;
0035 import org.apache.spark.network.protocol.ChunkFetchSuccess;
0036 import org.apache.spark.network.protocol.ResponseMessage;
0037 import org.apache.spark.network.protocol.RpcFailure;
0038 import org.apache.spark.network.protocol.RpcResponse;
0039 import org.apache.spark.network.protocol.StreamChunkId;
0040 import org.apache.spark.network.protocol.StreamFailure;
0041 import org.apache.spark.network.protocol.StreamResponse;
0042 import org.apache.spark.network.server.MessageHandler;
0043 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
0044 import org.apache.spark.network.util.TransportFrameDecoder;
0045
0046
0047
0048
0049
0050
0051
0052 public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
0053 private static final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class);
0054
0055 private final Channel channel;
0056
0057 private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
0058
0059 private final Map<Long, RpcResponseCallback> outstandingRpcs;
0060
0061 private final Queue<Pair<String, StreamCallback>> streamCallbacks;
0062 private volatile boolean streamActive;
0063
0064
0065 private final AtomicLong timeOfLastRequestNs;
0066
0067 public TransportResponseHandler(Channel channel) {
0068 this.channel = channel;
0069 this.outstandingFetches = new ConcurrentHashMap<>();
0070 this.outstandingRpcs = new ConcurrentHashMap<>();
0071 this.streamCallbacks = new ConcurrentLinkedQueue<>();
0072 this.timeOfLastRequestNs = new AtomicLong(0);
0073 }
0074
0075 public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
0076 updateTimeOfLastRequest();
0077 outstandingFetches.put(streamChunkId, callback);
0078 }
0079
0080 public void removeFetchRequest(StreamChunkId streamChunkId) {
0081 outstandingFetches.remove(streamChunkId);
0082 }
0083
0084 public void addRpcRequest(long requestId, RpcResponseCallback callback) {
0085 updateTimeOfLastRequest();
0086 outstandingRpcs.put(requestId, callback);
0087 }
0088
0089 public void removeRpcRequest(long requestId) {
0090 outstandingRpcs.remove(requestId);
0091 }
0092
0093 public void addStreamCallback(String streamId, StreamCallback callback) {
0094 updateTimeOfLastRequest();
0095 streamCallbacks.offer(ImmutablePair.of(streamId, callback));
0096 }
0097
0098 @VisibleForTesting
0099 public void deactivateStream() {
0100 streamActive = false;
0101 }
0102
0103
0104
0105
0106
0107 private void failOutstandingRequests(Throwable cause) {
0108 for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
0109 try {
0110 entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
0111 } catch (Exception e) {
0112 logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
0113 }
0114 }
0115 for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
0116 try {
0117 entry.getValue().onFailure(cause);
0118 } catch (Exception e) {
0119 logger.warn("RpcResponseCallback.onFailure throws exception", e);
0120 }
0121 }
0122 for (Pair<String, StreamCallback> entry : streamCallbacks) {
0123 try {
0124 entry.getValue().onFailure(entry.getKey(), cause);
0125 } catch (Exception e) {
0126 logger.warn("StreamCallback.onFailure throws exception", e);
0127 }
0128 }
0129
0130
0131 outstandingFetches.clear();
0132 outstandingRpcs.clear();
0133 streamCallbacks.clear();
0134 }
0135
0136 @Override
0137 public void channelActive() {
0138 }
0139
0140 @Override
0141 public void channelInactive() {
0142 if (numOutstandingRequests() > 0) {
0143 String remoteAddress = getRemoteAddress(channel);
0144 logger.error("Still have {} requests outstanding when connection from {} is closed",
0145 numOutstandingRequests(), remoteAddress);
0146 failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
0147 }
0148 }
0149
0150 @Override
0151 public void exceptionCaught(Throwable cause) {
0152 if (numOutstandingRequests() > 0) {
0153 String remoteAddress = getRemoteAddress(channel);
0154 logger.error("Still have {} requests outstanding when connection from {} is closed",
0155 numOutstandingRequests(), remoteAddress);
0156 failOutstandingRequests(cause);
0157 }
0158 }
0159
0160 @Override
0161 public void handle(ResponseMessage message) throws Exception {
0162 if (message instanceof ChunkFetchSuccess) {
0163 ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
0164 ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
0165 if (listener == null) {
0166 logger.warn("Ignoring response for block {} from {} since it is not outstanding",
0167 resp.streamChunkId, getRemoteAddress(channel));
0168 resp.body().release();
0169 } else {
0170 outstandingFetches.remove(resp.streamChunkId);
0171 listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
0172 resp.body().release();
0173 }
0174 } else if (message instanceof ChunkFetchFailure) {
0175 ChunkFetchFailure resp = (ChunkFetchFailure) message;
0176 ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
0177 if (listener == null) {
0178 logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding",
0179 resp.streamChunkId, getRemoteAddress(channel), resp.errorString);
0180 } else {
0181 outstandingFetches.remove(resp.streamChunkId);
0182 listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException(
0183 "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString));
0184 }
0185 } else if (message instanceof RpcResponse) {
0186 RpcResponse resp = (RpcResponse) message;
0187 RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
0188 if (listener == null) {
0189 logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
0190 resp.requestId, getRemoteAddress(channel), resp.body().size());
0191 } else {
0192 outstandingRpcs.remove(resp.requestId);
0193 try {
0194 listener.onSuccess(resp.body().nioByteBuffer());
0195 } finally {
0196 resp.body().release();
0197 }
0198 }
0199 } else if (message instanceof RpcFailure) {
0200 RpcFailure resp = (RpcFailure) message;
0201 RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
0202 if (listener == null) {
0203 logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
0204 resp.requestId, getRemoteAddress(channel), resp.errorString);
0205 } else {
0206 outstandingRpcs.remove(resp.requestId);
0207 listener.onFailure(new RuntimeException(resp.errorString));
0208 }
0209 } else if (message instanceof StreamResponse) {
0210 StreamResponse resp = (StreamResponse) message;
0211 Pair<String, StreamCallback> entry = streamCallbacks.poll();
0212 if (entry != null) {
0213 StreamCallback callback = entry.getValue();
0214 if (resp.byteCount > 0) {
0215 StreamInterceptor<ResponseMessage> interceptor = new StreamInterceptor<>(
0216 this, resp.streamId, resp.byteCount, callback);
0217 try {
0218 TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
0219 channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
0220 frameDecoder.setInterceptor(interceptor);
0221 streamActive = true;
0222 } catch (Exception e) {
0223 logger.error("Error installing stream handler.", e);
0224 deactivateStream();
0225 }
0226 } else {
0227 try {
0228 callback.onComplete(resp.streamId);
0229 } catch (Exception e) {
0230 logger.warn("Error in stream handler onComplete().", e);
0231 }
0232 }
0233 } else {
0234 logger.error("Could not find callback for StreamResponse.");
0235 }
0236 } else if (message instanceof StreamFailure) {
0237 StreamFailure resp = (StreamFailure) message;
0238 Pair<String, StreamCallback> entry = streamCallbacks.poll();
0239 if (entry != null) {
0240 StreamCallback callback = entry.getValue();
0241 try {
0242 callback.onFailure(resp.streamId, new RuntimeException(resp.error));
0243 } catch (IOException ioe) {
0244 logger.warn("Error in stream failure handler.", ioe);
0245 }
0246 } else {
0247 logger.warn("Stream failure with unknown callback: {}", resp.error);
0248 }
0249 } else {
0250 throw new IllegalStateException("Unknown response type: " + message.type());
0251 }
0252 }
0253
0254
0255 public int numOutstandingRequests() {
0256 return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() +
0257 (streamActive ? 1 : 0);
0258 }
0259
0260
0261 public long getTimeOfLastRequestNs() {
0262 return timeOfLastRequestNs.get();
0263 }
0264
0265
0266 public void updateTimeOfLastRequest() {
0267 timeOfLastRequestNs.set(System.nanoTime());
0268 }
0269
0270 }