0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.util;
0019
0020 import java.util.ArrayList;
0021 import java.util.List;
0022 import java.util.Random;
0023 import java.util.concurrent.atomic.AtomicInteger;
0024
0025 import io.netty.buffer.ByteBuf;
0026 import io.netty.buffer.Unpooled;
0027 import io.netty.channel.ChannelHandlerContext;
0028 import org.junit.AfterClass;
0029 import org.junit.Test;
0030 import org.slf4j.Logger;
0031 import org.slf4j.LoggerFactory;
0032
0033 import static org.junit.Assert.*;
0034 import static org.mockito.Mockito.*;
0035
0036 public class TransportFrameDecoderSuite {
0037
0038 private static final Logger logger = LoggerFactory.getLogger(TransportFrameDecoderSuite.class);
0039 private static Random RND = new Random();
0040
0041 @AfterClass
0042 public static void cleanup() {
0043 RND = null;
0044 }
0045
0046 @Test
0047 public void testFrameDecoding() throws Exception {
0048 TransportFrameDecoder decoder = new TransportFrameDecoder();
0049 ChannelHandlerContext ctx = mockChannelHandlerContext();
0050 ByteBuf data = createAndFeedFrames(100, decoder, ctx);
0051 verifyAndCloseDecoder(decoder, ctx, data);
0052 }
0053
0054 @Test
0055 public void testConsolidationPerf() throws Exception {
0056 long[] testingConsolidateThresholds = new long[] {
0057 ByteUnit.MiB.toBytes(1),
0058 ByteUnit.MiB.toBytes(5),
0059 ByteUnit.MiB.toBytes(10),
0060 ByteUnit.MiB.toBytes(20),
0061 ByteUnit.MiB.toBytes(30),
0062 ByteUnit.MiB.toBytes(50),
0063 ByteUnit.MiB.toBytes(80),
0064 ByteUnit.MiB.toBytes(100),
0065 ByteUnit.MiB.toBytes(300),
0066 ByteUnit.MiB.toBytes(500),
0067 Long.MAX_VALUE };
0068 for (long threshold : testingConsolidateThresholds) {
0069 TransportFrameDecoder decoder = new TransportFrameDecoder(threshold);
0070 ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
0071 List<ByteBuf> retained = new ArrayList<>();
0072 when(ctx.fireChannelRead(any())).thenAnswer(in -> {
0073 ByteBuf buf = (ByteBuf) in.getArguments()[0];
0074 retained.add(buf);
0075 return null;
0076 });
0077
0078
0079 int numMessages = 3;
0080 long targetBytes = ByteUnit.MiB.toBytes(300);
0081 int pieceBytes = (int) ByteUnit.KiB.toBytes(32);
0082 for (int i = 0; i < numMessages; i++) {
0083 try {
0084 long writtenBytes = 0;
0085 long totalTime = 0;
0086 ByteBuf buf = Unpooled.buffer(8);
0087 buf.writeLong(8 + targetBytes);
0088 decoder.channelRead(ctx, buf);
0089 while (writtenBytes < targetBytes) {
0090 buf = Unpooled.buffer(pieceBytes * 2);
0091 ByteBuf writtenBuf = Unpooled.buffer(pieceBytes).writerIndex(pieceBytes);
0092 buf.writeBytes(writtenBuf);
0093 writtenBuf.release();
0094 long start = System.currentTimeMillis();
0095 decoder.channelRead(ctx, buf);
0096 long elapsedTime = System.currentTimeMillis() - start;
0097 totalTime += elapsedTime;
0098 writtenBytes += pieceBytes;
0099 }
0100 logger.info("Writing 300MiB frame buf with consolidation of threshold " + threshold
0101 + " took " + totalTime + " milis");
0102 } finally {
0103 for (ByteBuf buf : retained) {
0104 release(buf);
0105 }
0106 }
0107 }
0108 long totalBytesGot = 0;
0109 for (ByteBuf buf : retained) {
0110 totalBytesGot += buf.capacity();
0111 }
0112 assertEquals(numMessages, retained.size());
0113 assertEquals(targetBytes * numMessages, totalBytesGot);
0114 }
0115 }
0116
0117 @Test
0118 public void testInterception() throws Exception {
0119 int interceptedReads = 3;
0120 TransportFrameDecoder decoder = new TransportFrameDecoder();
0121 TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads));
0122 ChannelHandlerContext ctx = mockChannelHandlerContext();
0123
0124 byte[] data = new byte[8];
0125 ByteBuf len = Unpooled.copyLong(8 + data.length);
0126 ByteBuf dataBuf = Unpooled.wrappedBuffer(data);
0127
0128 try {
0129 decoder.setInterceptor(interceptor);
0130 for (int i = 0; i < interceptedReads; i++) {
0131 decoder.channelRead(ctx, dataBuf);
0132 assertEquals(0, dataBuf.refCnt());
0133 dataBuf = Unpooled.wrappedBuffer(data);
0134 }
0135 decoder.channelRead(ctx, len);
0136 decoder.channelRead(ctx, dataBuf);
0137 verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class));
0138 verify(ctx).fireChannelRead(any(ByteBuf.class));
0139 assertEquals(0, len.refCnt());
0140 assertEquals(0, dataBuf.refCnt());
0141 } finally {
0142 release(len);
0143 release(dataBuf);
0144 }
0145 }
0146
0147 @Test
0148 public void testRetainedFrames() throws Exception {
0149 TransportFrameDecoder decoder = new TransportFrameDecoder();
0150
0151 AtomicInteger count = new AtomicInteger();
0152 List<ByteBuf> retained = new ArrayList<>();
0153
0154 ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
0155 when(ctx.fireChannelRead(any())).thenAnswer(in -> {
0156
0157 ByteBuf buf = (ByteBuf) in.getArguments()[0];
0158 if (count.incrementAndGet() % 2 == 0) {
0159 retained.add(buf);
0160 } else {
0161 buf.release();
0162 }
0163 return null;
0164 });
0165
0166 ByteBuf data = createAndFeedFrames(100, decoder, ctx);
0167 try {
0168
0169 for (ByteBuf b : retained) {
0170 byte[] tmp = new byte[b.readableBytes()];
0171 b.readBytes(tmp);
0172 b.release();
0173 }
0174 verifyAndCloseDecoder(decoder, ctx, data);
0175 } finally {
0176 for (ByteBuf b : retained) {
0177 release(b);
0178 }
0179 }
0180 }
0181
0182 @Test
0183 public void testSplitLengthField() throws Exception {
0184 byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
0185 ByteBuf buf = Unpooled.buffer(frame.length + 8);
0186 buf.writeLong(frame.length + 8);
0187 buf.writeBytes(frame);
0188
0189 TransportFrameDecoder decoder = new TransportFrameDecoder();
0190 ChannelHandlerContext ctx = mockChannelHandlerContext();
0191 try {
0192 decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain());
0193 verify(ctx, never()).fireChannelRead(any(ByteBuf.class));
0194 decoder.channelRead(ctx, buf);
0195 verify(ctx).fireChannelRead(any(ByteBuf.class));
0196 assertEquals(0, buf.refCnt());
0197 } finally {
0198 decoder.channelInactive(ctx);
0199 release(buf);
0200 }
0201 }
0202
0203 @Test(expected = IllegalArgumentException.class)
0204 public void testNegativeFrameSize() throws Exception {
0205 testInvalidFrame(-1);
0206 }
0207
0208 @Test(expected = IllegalArgumentException.class)
0209 public void testEmptyFrame() throws Exception {
0210
0211 testInvalidFrame(8);
0212 }
0213
0214
0215
0216
0217
0218 private ByteBuf createAndFeedFrames(
0219 int frameCount,
0220 TransportFrameDecoder decoder,
0221 ChannelHandlerContext ctx) throws Exception {
0222 ByteBuf data = Unpooled.buffer();
0223 for (int i = 0; i < frameCount; i++) {
0224 byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
0225 data.writeLong(frame.length + 8);
0226 data.writeBytes(frame);
0227 }
0228
0229 try {
0230 while (data.isReadable()) {
0231 int size = RND.nextInt(4 * 1024) + 256;
0232 decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain());
0233 }
0234
0235 verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
0236 } catch (Exception e) {
0237 release(data);
0238 throw e;
0239 }
0240 return data;
0241 }
0242
0243 private void verifyAndCloseDecoder(
0244 TransportFrameDecoder decoder,
0245 ChannelHandlerContext ctx,
0246 ByteBuf data) throws Exception {
0247 try {
0248 decoder.channelInactive(ctx);
0249 assertTrue("There shouldn't be dangling references to the data.", data.release());
0250 } finally {
0251 release(data);
0252 }
0253 }
0254
0255 private void testInvalidFrame(long size) throws Exception {
0256 TransportFrameDecoder decoder = new TransportFrameDecoder();
0257 ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
0258 ByteBuf frame = Unpooled.copyLong(size);
0259 try {
0260 decoder.channelRead(ctx, frame);
0261 } finally {
0262 release(frame);
0263 }
0264 }
0265
0266 private ChannelHandlerContext mockChannelHandlerContext() {
0267 ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
0268 when(ctx.fireChannelRead(any())).thenAnswer(in -> {
0269 ByteBuf buf = (ByteBuf) in.getArguments()[0];
0270 buf.release();
0271 return null;
0272 });
0273 return ctx;
0274 }
0275
0276 private void release(ByteBuf buf) {
0277 if (buf.refCnt() > 0) {
0278 buf.release(buf.refCnt());
0279 }
0280 }
0281
0282 private static class MockInterceptor implements TransportFrameDecoder.Interceptor {
0283
0284 private int remainingReads;
0285
0286 MockInterceptor(int readCount) {
0287 this.remainingReads = readCount;
0288 }
0289
0290 @Override
0291 public boolean handle(ByteBuf data) throws Exception {
0292 data.readerIndex(data.readerIndex() + data.readableBytes());
0293 assertFalse(data.isReadable());
0294 remainingReads -= 1;
0295 return remainingReads != 0;
0296 }
0297
0298 @Override
0299 public void exceptionCaught(Throwable cause) throws Exception {
0300
0301 }
0302
0303 @Override
0304 public void channelInactive() throws Exception {
0305
0306 }
0307
0308 }
0309
0310 }