0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network;
0019
0020 import java.util.ArrayList;
0021 import java.util.List;
0022
0023 import io.netty.channel.Channel;
0024 import org.junit.Assert;
0025 import org.junit.Test;
0026
0027 import static org.mockito.Mockito.*;
0028
0029 import org.apache.commons.lang3.tuple.ImmutablePair;
0030 import org.apache.commons.lang3.tuple.Pair;
0031 import org.apache.spark.network.buffer.ManagedBuffer;
0032 import org.apache.spark.network.client.TransportClient;
0033 import org.apache.spark.network.protocol.*;
0034 import org.apache.spark.network.server.NoOpRpcHandler;
0035 import org.apache.spark.network.server.OneForOneStreamManager;
0036 import org.apache.spark.network.server.RpcHandler;
0037 import org.apache.spark.network.server.TransportRequestHandler;
0038
0039 public class TransportRequestHandlerSuite {
0040
0041 @Test
0042 public void handleStreamRequest() throws Exception {
0043 RpcHandler rpcHandler = new NoOpRpcHandler();
0044 OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager());
0045 Channel channel = mock(Channel.class);
0046 List<Pair<Object, ExtendedChannelPromise>> responseAndPromisePairs =
0047 new ArrayList<>();
0048 when(channel.writeAndFlush(any()))
0049 .thenAnswer(invocationOnMock0 -> {
0050 Object response = invocationOnMock0.getArguments()[0];
0051 ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel);
0052 responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture));
0053 return channelFuture;
0054 });
0055
0056
0057 List<ManagedBuffer> managedBuffers = new ArrayList<>();
0058 managedBuffers.add(new TestManagedBuffer(10));
0059 managedBuffers.add(new TestManagedBuffer(20));
0060 managedBuffers.add(null);
0061 managedBuffers.add(new TestManagedBuffer(30));
0062 managedBuffers.add(new TestManagedBuffer(40));
0063 long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel);
0064
0065 Assert.assertEquals(1, streamManager.numStreamStates());
0066
0067 TransportClient reverseClient = mock(TransportClient.class);
0068 TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
0069 rpcHandler, 2L, null);
0070
0071 RequestMessage request0 = new StreamRequest(String.format("%d_%d", streamId, 0));
0072 requestHandler.handle(request0);
0073 Assert.assertEquals(1, responseAndPromisePairs.size());
0074 Assert.assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof StreamResponse);
0075 Assert.assertEquals(managedBuffers.get(0),
0076 ((StreamResponse) (responseAndPromisePairs.get(0).getLeft())).body());
0077
0078 RequestMessage request1 = new StreamRequest(String.format("%d_%d", streamId, 1));
0079 requestHandler.handle(request1);
0080 Assert.assertEquals(2, responseAndPromisePairs.size());
0081 Assert.assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof StreamResponse);
0082 Assert.assertEquals(managedBuffers.get(1),
0083 ((StreamResponse) (responseAndPromisePairs.get(1).getLeft())).body());
0084
0085
0086 responseAndPromisePairs.get(0).getRight().finish(true);
0087
0088 StreamRequest request2 = new StreamRequest(String.format("%d_%d", streamId, 2));
0089 requestHandler.handle(request2);
0090 Assert.assertEquals(3, responseAndPromisePairs.size());
0091 Assert.assertTrue(responseAndPromisePairs.get(2).getLeft() instanceof StreamFailure);
0092 Assert.assertEquals(String.format("Stream '%s' was not found.", request2.streamId),
0093 ((StreamFailure) (responseAndPromisePairs.get(2).getLeft())).error);
0094
0095 RequestMessage request3 = new StreamRequest(String.format("%d_%d", streamId, 3));
0096 requestHandler.handle(request3);
0097 Assert.assertEquals(4, responseAndPromisePairs.size());
0098 Assert.assertTrue(responseAndPromisePairs.get(3).getLeft() instanceof StreamResponse);
0099 Assert.assertEquals(managedBuffers.get(3),
0100 ((StreamResponse) (responseAndPromisePairs.get(3).getLeft())).body());
0101
0102
0103
0104 RequestMessage request4 = new StreamRequest(String.format("%d_%d", streamId, 4));
0105 requestHandler.handle(request4);
0106 verify(channel, times(1)).close();
0107 Assert.assertEquals(4, responseAndPromisePairs.size());
0108
0109 streamManager.connectionTerminated(channel);
0110 Assert.assertEquals(0, streamManager.numStreamStates());
0111 }
0112 }