0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network;
0019
0020 import java.util.List;
0021
0022 import com.google.common.primitives.Ints;
0023 import io.netty.buffer.Unpooled;
0024 import io.netty.channel.ChannelHandlerContext;
0025 import io.netty.channel.FileRegion;
0026 import io.netty.channel.embedded.EmbeddedChannel;
0027 import io.netty.handler.codec.MessageToMessageEncoder;
0028 import org.junit.Test;
0029
0030 import static org.junit.Assert.assertEquals;
0031
0032 import org.apache.spark.network.protocol.ChunkFetchFailure;
0033 import org.apache.spark.network.protocol.ChunkFetchRequest;
0034 import org.apache.spark.network.protocol.ChunkFetchSuccess;
0035 import org.apache.spark.network.protocol.Message;
0036 import org.apache.spark.network.protocol.MessageDecoder;
0037 import org.apache.spark.network.protocol.MessageEncoder;
0038 import org.apache.spark.network.protocol.OneWayMessage;
0039 import org.apache.spark.network.protocol.RpcFailure;
0040 import org.apache.spark.network.protocol.RpcRequest;
0041 import org.apache.spark.network.protocol.RpcResponse;
0042 import org.apache.spark.network.protocol.StreamChunkId;
0043 import org.apache.spark.network.protocol.StreamFailure;
0044 import org.apache.spark.network.protocol.StreamRequest;
0045 import org.apache.spark.network.protocol.StreamResponse;
0046 import org.apache.spark.network.util.ByteArrayWritableChannel;
0047 import org.apache.spark.network.util.NettyUtils;
0048
0049 public class ProtocolSuite {
0050 private void testServerToClient(Message msg) {
0051 EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
0052 MessageEncoder.INSTANCE);
0053 serverChannel.writeOutbound(msg);
0054
0055 EmbeddedChannel clientChannel = new EmbeddedChannel(
0056 NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
0057
0058 while (!serverChannel.outboundMessages().isEmpty()) {
0059 clientChannel.writeOneInbound(serverChannel.readOutbound());
0060 }
0061
0062 assertEquals(1, clientChannel.inboundMessages().size());
0063 assertEquals(msg, clientChannel.readInbound());
0064 }
0065
0066 private void testClientToServer(Message msg) {
0067 EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
0068 MessageEncoder.INSTANCE);
0069 clientChannel.writeOutbound(msg);
0070
0071 EmbeddedChannel serverChannel = new EmbeddedChannel(
0072 NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
0073
0074 while (!clientChannel.outboundMessages().isEmpty()) {
0075 serverChannel.writeOneInbound(clientChannel.readOutbound());
0076 }
0077
0078 assertEquals(1, serverChannel.inboundMessages().size());
0079 assertEquals(msg, serverChannel.readInbound());
0080 }
0081
0082 @Test
0083 public void requests() {
0084 testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
0085 testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0)));
0086 testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10)));
0087 testClientToServer(new StreamRequest("abcde"));
0088 testClientToServer(new OneWayMessage(new TestManagedBuffer(10)));
0089 }
0090
0091 @Test
0092 public void responses() {
0093 testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10)));
0094 testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0)));
0095 testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error"));
0096 testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), ""));
0097 testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0)));
0098 testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100)));
0099 testServerToClient(new RpcFailure(0, "this is an error"));
0100 testServerToClient(new RpcFailure(0, ""));
0101
0102
0103 testServerToClient(new StreamResponse("anId", 12345L, new TestManagedBuffer(0)));
0104 testServerToClient(new StreamFailure("anId", "this is an error"));
0105 }
0106
0107
0108
0109
0110
0111
0112 private static class FileRegionEncoder extends MessageToMessageEncoder<FileRegion> {
0113
0114 @Override
0115 public void encode(ChannelHandlerContext ctx, FileRegion in, List<Object> out)
0116 throws Exception {
0117
0118 ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count()));
0119 while (in.transferred() < in.count()) {
0120 in.transferTo(channel, in.transferred());
0121 }
0122 out.add(Unpooled.wrappedBuffer(channel.getData()));
0123 }
0124
0125 }
0126
0127 }