0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.client;
0019
0020 import java.nio.ByteBuffer;
0021 import java.nio.channels.ClosedChannelException;
0022
0023 import io.netty.buffer.ByteBuf;
0024
0025 import org.apache.spark.network.protocol.Message;
0026 import org.apache.spark.network.server.MessageHandler;
0027 import org.apache.spark.network.util.TransportFrameDecoder;
0028
0029
0030
0031
0032
0033 public class StreamInterceptor<T extends Message> implements TransportFrameDecoder.Interceptor {
0034
0035 private final MessageHandler<T> handler;
0036 private final String streamId;
0037 private final long byteCount;
0038 private final StreamCallback callback;
0039 private long bytesRead;
0040
0041 public StreamInterceptor(
0042 MessageHandler<T> handler,
0043 String streamId,
0044 long byteCount,
0045 StreamCallback callback) {
0046 this.handler = handler;
0047 this.streamId = streamId;
0048 this.byteCount = byteCount;
0049 this.callback = callback;
0050 this.bytesRead = 0;
0051 }
0052
0053 @Override
0054 public void exceptionCaught(Throwable cause) throws Exception {
0055 deactivateStream();
0056 callback.onFailure(streamId, cause);
0057 }
0058
0059 @Override
0060 public void channelInactive() throws Exception {
0061 deactivateStream();
0062 callback.onFailure(streamId, new ClosedChannelException());
0063 }
0064
0065 private void deactivateStream() {
0066 if (handler instanceof TransportResponseHandler) {
0067
0068
0069 ((TransportResponseHandler) handler).deactivateStream();
0070 }
0071 }
0072
0073 @Override
0074 public boolean handle(ByteBuf buf) throws Exception {
0075 int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead);
0076 ByteBuffer nioBuffer = buf.readSlice(toRead).nioBuffer();
0077
0078 int available = nioBuffer.remaining();
0079 callback.onData(streamId, nioBuffer);
0080 bytesRead += available;
0081 if (bytesRead > byteCount) {
0082 RuntimeException re = new IllegalStateException(String.format(
0083 "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
0084 callback.onFailure(streamId, re);
0085 deactivateStream();
0086 throw re;
0087 } else if (bytesRead == byteCount) {
0088 deactivateStream();
0089 callback.onComplete(streamId);
0090 }
0091
0092 return bytesRead != byteCount;
0093 }
0094
0095 }