Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network.protocol;
0019 
0020 import java.io.IOException;
0021 import java.nio.ByteBuffer;
0022 import java.nio.channels.WritableByteChannel;
0023 
0024 import io.netty.buffer.ByteBuf;
0025 import io.netty.buffer.CompositeByteBuf;
0026 import io.netty.buffer.Unpooled;
0027 import org.apache.spark.network.util.AbstractFileRegion;
0028 import org.junit.Test;
0029 import org.mockito.Mockito;
0030 
0031 import static org.junit.Assert.*;
0032 
0033 import org.apache.spark.network.TestManagedBuffer;
0034 import org.apache.spark.network.buffer.ManagedBuffer;
0035 import org.apache.spark.network.buffer.NettyManagedBuffer;
0036 import org.apache.spark.network.util.ByteArrayWritableChannel;
0037 
0038 public class MessageWithHeaderSuite {
0039 
0040   @Test
0041   public void testSingleWrite() throws Exception {
0042     testFileRegionBody(8, 8);
0043   }
0044 
0045   @Test
0046   public void testShortWrite() throws Exception {
0047     testFileRegionBody(8, 1);
0048   }
0049 
0050   @Test
0051   public void testByteBufBody() throws Exception {
0052     testByteBufBody(Unpooled.copyLong(42));
0053   }
0054 
0055   @Test
0056   public void testCompositeByteBufBodySingleBuffer() throws Exception {
0057     ByteBuf header = Unpooled.copyLong(42);
0058     CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
0059     compositeByteBuf.addComponent(true, header);
0060     assertEquals(1, compositeByteBuf.nioBufferCount());
0061     testByteBufBody(compositeByteBuf);
0062   }
0063 
0064   @Test
0065   public void testCompositeByteBufBodyMultipleBuffers() throws Exception {
0066     ByteBuf header = Unpooled.copyLong(42);
0067     CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
0068     compositeByteBuf.addComponent(true, header.retainedSlice(0, 4));
0069     compositeByteBuf.addComponent(true, header.slice(4, 4));
0070     assertEquals(2, compositeByteBuf.nioBufferCount());
0071     testByteBufBody(compositeByteBuf);
0072   }
0073 
0074   /**
0075    * Test writing a {@link MessageWithHeader} using the given {@link ByteBuf} as header.
0076    *
0077    * @param header the header to use.
0078    * @throws Exception thrown on error.
0079    */
0080   private void testByteBufBody(ByteBuf header) throws Exception {
0081     long expectedHeaderValue = header.getLong(header.readerIndex());
0082     ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84);
0083     assertEquals(1, header.refCnt());
0084     assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt());
0085     ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer);
0086 
0087     Object body = managedBuf.convertToNetty();
0088     assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt());
0089     assertEquals(1, header.refCnt());
0090 
0091     MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size());
0092     ByteBuf result = doWrite(msg, 1);
0093     assertEquals(msg.count(), result.readableBytes());
0094     assertEquals(expectedHeaderValue, result.readLong());
0095     assertEquals(84, result.readLong());
0096 
0097     assertTrue(msg.release());
0098     assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt());
0099     assertEquals(0, header.refCnt());
0100   }
0101 
0102   @Test
0103   public void testDeallocateReleasesManagedBuffer() throws Exception {
0104     ByteBuf header = Unpooled.copyLong(42);
0105     ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84));
0106     ByteBuf body = (ByteBuf) managedBuf.convertToNetty();
0107     assertEquals(2, body.refCnt());
0108     MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes());
0109     assertTrue(msg.release());
0110     Mockito.verify(managedBuf, Mockito.times(1)).release();
0111     assertEquals(0, body.refCnt());
0112   }
0113 
0114   private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception {
0115     ByteBuf header = Unpooled.copyLong(42);
0116     int headerLength = header.readableBytes();
0117     TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall);
0118     MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count());
0119 
0120     ByteBuf result = doWrite(msg, totalWrites / writesPerCall);
0121     assertEquals(headerLength + region.count(), result.readableBytes());
0122     assertEquals(42, result.readLong());
0123     for (long i = 0; i < 8; i++) {
0124       assertEquals(i, result.readLong());
0125     }
0126     assertTrue(msg.release());
0127   }
0128 
0129   private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception {
0130     int writes = 0;
0131     ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count());
0132     while (msg.transfered() < msg.count()) {
0133       msg.transferTo(channel, msg.transfered());
0134       writes++;
0135     }
0136     assertTrue("Not enough writes!", minExpectedWrites <= writes);
0137     return Unpooled.wrappedBuffer(channel.getData());
0138   }
0139 
0140   private static class TestFileRegion extends AbstractFileRegion {
0141 
0142     private final int writeCount;
0143     private final int writesPerCall;
0144     private int written;
0145 
0146     TestFileRegion(int totalWrites, int writesPerCall) {
0147       this.writeCount = totalWrites;
0148       this.writesPerCall = writesPerCall;
0149     }
0150 
0151     @Override
0152     public long count() {
0153       return 8 * writeCount;
0154     }
0155 
0156     @Override
0157     public long position() {
0158       return 0;
0159     }
0160 
0161     @Override
0162     public long transferred() {
0163       return 8 * written;
0164     }
0165 
0166     @Override
0167     public long transferTo(WritableByteChannel target, long position) throws IOException {
0168       for (int i = 0; i < writesPerCall; i++) {
0169         ByteBuf buf = Unpooled.copyLong((position / 8) + i);
0170         ByteBuffer nio = buf.nioBuffer();
0171         while (nio.remaining() > 0) {
0172           target.write(nio);
0173         }
0174         buf.release();
0175         written++;
0176       }
0177       return 8 * writesPerCall;
0178     }
0179 
0180     @Override
0181     protected void deallocate() {
0182     }
0183 
0184   }
0185 
0186 }