0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network;
0019
0020 import java.io.IOException;
0021 import java.nio.ByteBuffer;
0022
0023 import io.netty.channel.Channel;
0024 import io.netty.channel.local.LocalChannel;
0025 import org.junit.Test;
0026
0027 import static org.junit.Assert.assertEquals;
0028 import static org.mockito.Mockito.*;
0029
0030 import org.apache.spark.network.buffer.NioManagedBuffer;
0031 import org.apache.spark.network.client.ChunkReceivedCallback;
0032 import org.apache.spark.network.client.RpcResponseCallback;
0033 import org.apache.spark.network.client.StreamCallback;
0034 import org.apache.spark.network.client.TransportResponseHandler;
0035 import org.apache.spark.network.protocol.ChunkFetchFailure;
0036 import org.apache.spark.network.protocol.ChunkFetchSuccess;
0037 import org.apache.spark.network.protocol.RpcFailure;
0038 import org.apache.spark.network.protocol.RpcResponse;
0039 import org.apache.spark.network.protocol.StreamChunkId;
0040 import org.apache.spark.network.protocol.StreamFailure;
0041 import org.apache.spark.network.protocol.StreamResponse;
0042 import org.apache.spark.network.util.TransportFrameDecoder;
0043
0044 public class TransportResponseHandlerSuite {
0045 @Test
0046 public void handleSuccessfulFetch() throws Exception {
0047 StreamChunkId streamChunkId = new StreamChunkId(1, 0);
0048
0049 TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
0050 ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
0051 handler.addFetchRequest(streamChunkId, callback);
0052 assertEquals(1, handler.numOutstandingRequests());
0053
0054 handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123)));
0055 verify(callback, times(1)).onSuccess(eq(0), any());
0056 assertEquals(0, handler.numOutstandingRequests());
0057 }
0058
0059 @Test
0060 public void handleFailedFetch() throws Exception {
0061 StreamChunkId streamChunkId = new StreamChunkId(1, 0);
0062 TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
0063 ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
0064 handler.addFetchRequest(streamChunkId, callback);
0065 assertEquals(1, handler.numOutstandingRequests());
0066
0067 handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg"));
0068 verify(callback, times(1)).onFailure(eq(0), any());
0069 assertEquals(0, handler.numOutstandingRequests());
0070 }
0071
0072 @Test
0073 public void clearAllOutstandingRequests() throws Exception {
0074 TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
0075 ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
0076 handler.addFetchRequest(new StreamChunkId(1, 0), callback);
0077 handler.addFetchRequest(new StreamChunkId(1, 1), callback);
0078 handler.addFetchRequest(new StreamChunkId(1, 2), callback);
0079 assertEquals(3, handler.numOutstandingRequests());
0080
0081 handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12)));
0082 handler.exceptionCaught(new Exception("duh duh duhhhh"));
0083
0084
0085 verify(callback, times(1)).onSuccess(eq(0), any());
0086 verify(callback, times(1)).onFailure(eq(1), any());
0087 verify(callback, times(1)).onFailure(eq(2), any());
0088 assertEquals(0, handler.numOutstandingRequests());
0089 }
0090
0091 @Test
0092 public void handleSuccessfulRPC() throws Exception {
0093 TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
0094 RpcResponseCallback callback = mock(RpcResponseCallback.class);
0095 handler.addRpcRequest(12345, callback);
0096 assertEquals(1, handler.numOutstandingRequests());
0097
0098
0099 handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7))));
0100 assertEquals(1, handler.numOutstandingRequests());
0101
0102 ByteBuffer resp = ByteBuffer.allocate(10);
0103 handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp)));
0104 verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10)));
0105 assertEquals(0, handler.numOutstandingRequests());
0106 }
0107
0108 @Test
0109 public void handleFailedRPC() throws Exception {
0110 TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
0111 RpcResponseCallback callback = mock(RpcResponseCallback.class);
0112 handler.addRpcRequest(12345, callback);
0113 assertEquals(1, handler.numOutstandingRequests());
0114
0115 handler.handle(new RpcFailure(54321, "uh-oh!"));
0116 assertEquals(1, handler.numOutstandingRequests());
0117
0118 handler.handle(new RpcFailure(12345, "oh no"));
0119 verify(callback, times(1)).onFailure(any());
0120 assertEquals(0, handler.numOutstandingRequests());
0121 }
0122
0123 @Test
0124 public void testActiveStreams() throws Exception {
0125 Channel c = new LocalChannel();
0126 c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
0127 TransportResponseHandler handler = new TransportResponseHandler(c);
0128
0129 StreamResponse response = new StreamResponse("stream", 1234L, null);
0130 StreamCallback cb = mock(StreamCallback.class);
0131 handler.addStreamCallback("stream", cb);
0132 assertEquals(1, handler.numOutstandingRequests());
0133 handler.handle(response);
0134 assertEquals(1, handler.numOutstandingRequests());
0135 handler.deactivateStream();
0136 assertEquals(0, handler.numOutstandingRequests());
0137
0138 StreamFailure failure = new StreamFailure("stream", "uh-oh");
0139 handler.addStreamCallback("stream", cb);
0140 assertEquals(1, handler.numOutstandingRequests());
0141 handler.handle(failure);
0142 assertEquals(0, handler.numOutstandingRequests());
0143 }
0144
0145 @Test
0146 public void failOutstandingStreamCallbackOnClose() throws Exception {
0147 Channel c = new LocalChannel();
0148 c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
0149 TransportResponseHandler handler = new TransportResponseHandler(c);
0150
0151 StreamCallback cb = mock(StreamCallback.class);
0152 handler.addStreamCallback("stream-1", cb);
0153 handler.channelInactive();
0154
0155 verify(cb).onFailure(eq("stream-1"), isA(IOException.class));
0156 }
0157
0158 @Test
0159 public void failOutstandingStreamCallbackOnException() throws Exception {
0160 Channel c = new LocalChannel();
0161 c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
0162 TransportResponseHandler handler = new TransportResponseHandler(c);
0163
0164 StreamCallback cb = mock(StreamCallback.class);
0165 handler.addStreamCallback("stream-1", cb);
0166 handler.exceptionCaught(new IOException("Oops!"));
0167
0168 verify(cb).onFailure(eq("stream-1"), isA(IOException.class));
0169 }
0170 }