Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network;
0019 
0020 import java.util.List;
0021 
0022 import com.google.common.primitives.Ints;
0023 import io.netty.buffer.Unpooled;
0024 import io.netty.channel.ChannelHandlerContext;
0025 import io.netty.channel.FileRegion;
0026 import io.netty.channel.embedded.EmbeddedChannel;
0027 import io.netty.handler.codec.MessageToMessageEncoder;
0028 import org.junit.Test;
0029 
0030 import static org.junit.Assert.assertEquals;
0031 
0032 import org.apache.spark.network.protocol.ChunkFetchFailure;
0033 import org.apache.spark.network.protocol.ChunkFetchRequest;
0034 import org.apache.spark.network.protocol.ChunkFetchSuccess;
0035 import org.apache.spark.network.protocol.Message;
0036 import org.apache.spark.network.protocol.MessageDecoder;
0037 import org.apache.spark.network.protocol.MessageEncoder;
0038 import org.apache.spark.network.protocol.OneWayMessage;
0039 import org.apache.spark.network.protocol.RpcFailure;
0040 import org.apache.spark.network.protocol.RpcRequest;
0041 import org.apache.spark.network.protocol.RpcResponse;
0042 import org.apache.spark.network.protocol.StreamChunkId;
0043 import org.apache.spark.network.protocol.StreamFailure;
0044 import org.apache.spark.network.protocol.StreamRequest;
0045 import org.apache.spark.network.protocol.StreamResponse;
0046 import org.apache.spark.network.util.ByteArrayWritableChannel;
0047 import org.apache.spark.network.util.NettyUtils;
0048 
0049 public class ProtocolSuite {
0050   private void testServerToClient(Message msg) {
0051     EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
0052       MessageEncoder.INSTANCE);
0053     serverChannel.writeOutbound(msg);
0054 
0055     EmbeddedChannel clientChannel = new EmbeddedChannel(
0056         NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
0057 
0058     while (!serverChannel.outboundMessages().isEmpty()) {
0059       clientChannel.writeOneInbound(serverChannel.readOutbound());
0060     }
0061 
0062     assertEquals(1, clientChannel.inboundMessages().size());
0063     assertEquals(msg, clientChannel.readInbound());
0064   }
0065 
0066   private void testClientToServer(Message msg) {
0067     EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
0068       MessageEncoder.INSTANCE);
0069     clientChannel.writeOutbound(msg);
0070 
0071     EmbeddedChannel serverChannel = new EmbeddedChannel(
0072         NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
0073 
0074     while (!clientChannel.outboundMessages().isEmpty()) {
0075       serverChannel.writeOneInbound(clientChannel.readOutbound());
0076     }
0077 
0078     assertEquals(1, serverChannel.inboundMessages().size());
0079     assertEquals(msg, serverChannel.readInbound());
0080   }
0081 
0082   @Test
0083   public void requests() {
0084     testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
0085     testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0)));
0086     testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10)));
0087     testClientToServer(new StreamRequest("abcde"));
0088     testClientToServer(new OneWayMessage(new TestManagedBuffer(10)));
0089   }
0090 
0091   @Test
0092   public void responses() {
0093     testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10)));
0094     testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0)));
0095     testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error"));
0096     testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), ""));
0097     testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0)));
0098     testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100)));
0099     testServerToClient(new RpcFailure(0, "this is an error"));
0100     testServerToClient(new RpcFailure(0, ""));
0101     // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the
0102     // channel and cannot be tested like this.
0103     testServerToClient(new StreamResponse("anId", 12345L, new TestManagedBuffer(0)));
0104     testServerToClient(new StreamFailure("anId", "this is an error"));
0105   }
0106 
0107   /**
0108    * Handler to transform a FileRegion into a byte buffer. EmbeddedChannel doesn't actually transfer
0109    * bytes, but messages, so this is needed so that the frame decoder on the receiving side can
0110    * understand what MessageWithHeader actually contains.
0111    */
0112   private static class FileRegionEncoder extends MessageToMessageEncoder<FileRegion> {
0113 
0114     @Override
0115     public void encode(ChannelHandlerContext ctx, FileRegion in, List<Object> out)
0116       throws Exception {
0117 
0118       ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count()));
0119       while (in.transferred() < in.count()) {
0120         in.transferTo(channel, in.transferred());
0121       }
0122       out.add(Unpooled.wrappedBuffer(channel.getData()));
0123     }
0124 
0125   }
0126 
0127 }