Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network;
0019 
0020 import java.util.ArrayList;
0021 import java.util.List;
0022 
0023 import io.netty.channel.Channel;
0024 import org.junit.Assert;
0025 import org.junit.Test;
0026 
0027 import static org.mockito.Mockito.*;
0028 
0029 import org.apache.commons.lang3.tuple.ImmutablePair;
0030 import org.apache.commons.lang3.tuple.Pair;
0031 import org.apache.spark.network.buffer.ManagedBuffer;
0032 import org.apache.spark.network.client.TransportClient;
0033 import org.apache.spark.network.protocol.*;
0034 import org.apache.spark.network.server.NoOpRpcHandler;
0035 import org.apache.spark.network.server.OneForOneStreamManager;
0036 import org.apache.spark.network.server.RpcHandler;
0037 import org.apache.spark.network.server.TransportRequestHandler;
0038 
0039 public class TransportRequestHandlerSuite {
0040 
0041   @Test
0042   public void handleStreamRequest() throws Exception {
0043     RpcHandler rpcHandler = new NoOpRpcHandler();
0044     OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager());
0045     Channel channel = mock(Channel.class);
0046     List<Pair<Object, ExtendedChannelPromise>> responseAndPromisePairs =
0047       new ArrayList<>();
0048     when(channel.writeAndFlush(any()))
0049       .thenAnswer(invocationOnMock0 -> {
0050         Object response = invocationOnMock0.getArguments()[0];
0051         ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel);
0052         responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture));
0053         return channelFuture;
0054       });
0055 
0056     // Prepare the stream.
0057     List<ManagedBuffer> managedBuffers = new ArrayList<>();
0058     managedBuffers.add(new TestManagedBuffer(10));
0059     managedBuffers.add(new TestManagedBuffer(20));
0060     managedBuffers.add(null);
0061     managedBuffers.add(new TestManagedBuffer(30));
0062     managedBuffers.add(new TestManagedBuffer(40));
0063     long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel);
0064 
0065     Assert.assertEquals(1, streamManager.numStreamStates());
0066 
0067     TransportClient reverseClient = mock(TransportClient.class);
0068     TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
0069       rpcHandler, 2L, null);
0070 
0071     RequestMessage request0 = new StreamRequest(String.format("%d_%d", streamId, 0));
0072     requestHandler.handle(request0);
0073     Assert.assertEquals(1, responseAndPromisePairs.size());
0074     Assert.assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof StreamResponse);
0075     Assert.assertEquals(managedBuffers.get(0),
0076       ((StreamResponse) (responseAndPromisePairs.get(0).getLeft())).body());
0077 
0078     RequestMessage request1 = new StreamRequest(String.format("%d_%d", streamId, 1));
0079     requestHandler.handle(request1);
0080     Assert.assertEquals(2, responseAndPromisePairs.size());
0081     Assert.assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof StreamResponse);
0082     Assert.assertEquals(managedBuffers.get(1),
0083       ((StreamResponse) (responseAndPromisePairs.get(1).getLeft())).body());
0084 
0085     // Finish flushing the response for request0.
0086     responseAndPromisePairs.get(0).getRight().finish(true);
0087 
0088     StreamRequest request2 = new StreamRequest(String.format("%d_%d", streamId, 2));
0089     requestHandler.handle(request2);
0090     Assert.assertEquals(3, responseAndPromisePairs.size());
0091     Assert.assertTrue(responseAndPromisePairs.get(2).getLeft() instanceof StreamFailure);
0092     Assert.assertEquals(String.format("Stream '%s' was not found.", request2.streamId),
0093         ((StreamFailure) (responseAndPromisePairs.get(2).getLeft())).error);
0094 
0095     RequestMessage request3 = new StreamRequest(String.format("%d_%d", streamId, 3));
0096     requestHandler.handle(request3);
0097     Assert.assertEquals(4, responseAndPromisePairs.size());
0098     Assert.assertTrue(responseAndPromisePairs.get(3).getLeft() instanceof StreamResponse);
0099     Assert.assertEquals(managedBuffers.get(3),
0100       ((StreamResponse) (responseAndPromisePairs.get(3).getLeft())).body());
0101 
0102     // Request4 will trigger the close of channel, because the number of max chunks being
0103     // transferred is 2;
0104     RequestMessage request4 = new StreamRequest(String.format("%d_%d", streamId, 4));
0105     requestHandler.handle(request4);
0106     verify(channel, times(1)).close();
0107     Assert.assertEquals(4, responseAndPromisePairs.size());
0108 
0109     streamManager.connectionTerminated(channel);
0110     Assert.assertEquals(0, streamManager.numStreamStates());
0111   }
0112 }