Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network.util;
0019 
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import org.junit.AfterClass;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
0035 
0036 public class TransportFrameDecoderSuite {
0037 
0038   private static final Logger logger = LoggerFactory.getLogger(TransportFrameDecoderSuite.class);
0039   private static Random RND = new Random();
0040 
0041   @AfterClass
0042   public static void cleanup() {
0043     RND = null;
0044   }
0045 
0046   @Test
0047   public void testFrameDecoding() throws Exception {
0048     TransportFrameDecoder decoder = new TransportFrameDecoder();
0049     ChannelHandlerContext ctx = mockChannelHandlerContext();
0050     ByteBuf data = createAndFeedFrames(100, decoder, ctx);
0051     verifyAndCloseDecoder(decoder, ctx, data);
0052   }
0053 
0054   @Test
0055   public void testConsolidationPerf() throws Exception {
0056     long[] testingConsolidateThresholds = new long[] {
0057         ByteUnit.MiB.toBytes(1),
0058         ByteUnit.MiB.toBytes(5),
0059         ByteUnit.MiB.toBytes(10),
0060         ByteUnit.MiB.toBytes(20),
0061         ByteUnit.MiB.toBytes(30),
0062         ByteUnit.MiB.toBytes(50),
0063         ByteUnit.MiB.toBytes(80),
0064         ByteUnit.MiB.toBytes(100),
0065         ByteUnit.MiB.toBytes(300),
0066         ByteUnit.MiB.toBytes(500),
0067         Long.MAX_VALUE };
0068     for (long threshold : testingConsolidateThresholds) {
0069       TransportFrameDecoder decoder = new TransportFrameDecoder(threshold);
0070       ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
0071       List<ByteBuf> retained = new ArrayList<>();
0072       when(ctx.fireChannelRead(any())).thenAnswer(in -> {
0073         ByteBuf buf = (ByteBuf) in.getArguments()[0];
0074         retained.add(buf);
0075         return null;
0076       });
0077 
0078       // Testing multiple messages
0079       int numMessages = 3;
0080       long targetBytes = ByteUnit.MiB.toBytes(300);
0081       int pieceBytes = (int) ByteUnit.KiB.toBytes(32);
0082       for (int i = 0; i < numMessages; i++) {
0083         try {
0084           long writtenBytes = 0;
0085           long totalTime = 0;
0086           ByteBuf buf = Unpooled.buffer(8);
0087           buf.writeLong(8 + targetBytes);
0088           decoder.channelRead(ctx, buf);
0089           while (writtenBytes < targetBytes) {
0090             buf = Unpooled.buffer(pieceBytes * 2);
0091             ByteBuf writtenBuf = Unpooled.buffer(pieceBytes).writerIndex(pieceBytes);
0092             buf.writeBytes(writtenBuf);
0093             writtenBuf.release();
0094             long start = System.currentTimeMillis();
0095             decoder.channelRead(ctx, buf);
0096             long elapsedTime = System.currentTimeMillis() - start;
0097             totalTime += elapsedTime;
0098             writtenBytes += pieceBytes;
0099           }
0100           logger.info("Writing 300MiB frame buf with consolidation of threshold " + threshold
0101               + " took " + totalTime + " milis");
0102         } finally {
0103           for (ByteBuf buf : retained) {
0104             release(buf);
0105           }
0106         }
0107       }
0108       long totalBytesGot = 0;
0109       for (ByteBuf buf : retained) {
0110         totalBytesGot += buf.capacity();
0111       }
0112       assertEquals(numMessages, retained.size());
0113       assertEquals(targetBytes * numMessages, totalBytesGot);
0114     }
0115   }
0116 
0117   @Test
0118   public void testInterception() throws Exception {
0119     int interceptedReads = 3;
0120     TransportFrameDecoder decoder = new TransportFrameDecoder();
0121     TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads));
0122     ChannelHandlerContext ctx = mockChannelHandlerContext();
0123 
0124     byte[] data = new byte[8];
0125     ByteBuf len = Unpooled.copyLong(8 + data.length);
0126     ByteBuf dataBuf = Unpooled.wrappedBuffer(data);
0127 
0128     try {
0129       decoder.setInterceptor(interceptor);
0130       for (int i = 0; i < interceptedReads; i++) {
0131         decoder.channelRead(ctx, dataBuf);
0132         assertEquals(0, dataBuf.refCnt());
0133         dataBuf = Unpooled.wrappedBuffer(data);
0134       }
0135       decoder.channelRead(ctx, len);
0136       decoder.channelRead(ctx, dataBuf);
0137       verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class));
0138       verify(ctx).fireChannelRead(any(ByteBuf.class));
0139       assertEquals(0, len.refCnt());
0140       assertEquals(0, dataBuf.refCnt());
0141     } finally {
0142       release(len);
0143       release(dataBuf);
0144     }
0145   }
0146 
0147   @Test
0148   public void testRetainedFrames() throws Exception {
0149     TransportFrameDecoder decoder = new TransportFrameDecoder();
0150 
0151     AtomicInteger count = new AtomicInteger();
0152     List<ByteBuf> retained = new ArrayList<>();
0153 
0154     ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
0155     when(ctx.fireChannelRead(any())).thenAnswer(in -> {
0156       // Retain a few frames but not others.
0157       ByteBuf buf = (ByteBuf) in.getArguments()[0];
0158       if (count.incrementAndGet() % 2 == 0) {
0159         retained.add(buf);
0160       } else {
0161         buf.release();
0162       }
0163       return null;
0164     });
0165 
0166     ByteBuf data = createAndFeedFrames(100, decoder, ctx);
0167     try {
0168       // Verify all retained buffers are readable.
0169       for (ByteBuf b : retained) {
0170         byte[] tmp = new byte[b.readableBytes()];
0171         b.readBytes(tmp);
0172         b.release();
0173       }
0174       verifyAndCloseDecoder(decoder, ctx, data);
0175     } finally {
0176       for (ByteBuf b : retained) {
0177         release(b);
0178       }
0179     }
0180   }
0181 
0182   @Test
0183   public void testSplitLengthField() throws Exception {
0184     byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
0185     ByteBuf buf = Unpooled.buffer(frame.length + 8);
0186     buf.writeLong(frame.length + 8);
0187     buf.writeBytes(frame);
0188 
0189     TransportFrameDecoder decoder = new TransportFrameDecoder();
0190     ChannelHandlerContext ctx = mockChannelHandlerContext();
0191     try {
0192       decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain());
0193       verify(ctx, never()).fireChannelRead(any(ByteBuf.class));
0194       decoder.channelRead(ctx, buf);
0195       verify(ctx).fireChannelRead(any(ByteBuf.class));
0196       assertEquals(0, buf.refCnt());
0197     } finally {
0198       decoder.channelInactive(ctx);
0199       release(buf);
0200     }
0201   }
0202 
0203   @Test(expected = IllegalArgumentException.class)
0204   public void testNegativeFrameSize() throws Exception {
0205     testInvalidFrame(-1);
0206   }
0207 
0208   @Test(expected = IllegalArgumentException.class)
0209   public void testEmptyFrame() throws Exception {
0210     // 8 because frame size includes the frame length.
0211     testInvalidFrame(8);
0212   }
0213 
0214   /**
0215    * Creates a number of randomly sized frames and feed them to the given decoder, verifying
0216    * that the frames were read.
0217    */
0218   private ByteBuf createAndFeedFrames(
0219       int frameCount,
0220       TransportFrameDecoder decoder,
0221       ChannelHandlerContext ctx) throws Exception {
0222     ByteBuf data = Unpooled.buffer();
0223     for (int i = 0; i < frameCount; i++) {
0224       byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
0225       data.writeLong(frame.length + 8);
0226       data.writeBytes(frame);
0227     }
0228 
0229     try {
0230       while (data.isReadable()) {
0231         int size = RND.nextInt(4 * 1024) + 256;
0232         decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain());
0233       }
0234 
0235       verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
0236     } catch (Exception e) {
0237       release(data);
0238       throw e;
0239     }
0240     return data;
0241   }
0242 
0243   private void verifyAndCloseDecoder(
0244       TransportFrameDecoder decoder,
0245       ChannelHandlerContext ctx,
0246       ByteBuf data) throws Exception {
0247     try {
0248       decoder.channelInactive(ctx);
0249       assertTrue("There shouldn't be dangling references to the data.", data.release());
0250     } finally {
0251       release(data);
0252     }
0253   }
0254 
0255   private void testInvalidFrame(long size) throws Exception {
0256     TransportFrameDecoder decoder = new TransportFrameDecoder();
0257     ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
0258     ByteBuf frame = Unpooled.copyLong(size);
0259     try {
0260       decoder.channelRead(ctx, frame);
0261     } finally {
0262       release(frame);
0263     }
0264   }
0265 
0266   private ChannelHandlerContext mockChannelHandlerContext() {
0267     ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
0268     when(ctx.fireChannelRead(any())).thenAnswer(in -> {
0269       ByteBuf buf = (ByteBuf) in.getArguments()[0];
0270       buf.release();
0271       return null;
0272     });
0273     return ctx;
0274   }
0275 
0276   private void release(ByteBuf buf) {
0277     if (buf.refCnt() > 0) {
0278       buf.release(buf.refCnt());
0279     }
0280   }
0281 
0282   private static class MockInterceptor implements TransportFrameDecoder.Interceptor {
0283 
0284     private int remainingReads;
0285 
0286     MockInterceptor(int readCount) {
0287       this.remainingReads = readCount;
0288     }
0289 
0290     @Override
0291     public boolean handle(ByteBuf data) throws Exception {
0292       data.readerIndex(data.readerIndex() + data.readableBytes());
0293       assertFalse(data.isReadable());
0294       remainingReads -= 1;
0295       return remainingReads != 0;
0296     }
0297 
0298     @Override
0299     public void exceptionCaught(Throwable cause) throws Exception {
0300 
0301     }
0302 
0303     @Override
0304     public void channelInactive() throws Exception {
0305 
0306     }
0307 
0308   }
0309 
0310 }