0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.protocol;
0019
0020 import java.io.IOException;
0021 import java.nio.ByteBuffer;
0022 import java.nio.channels.WritableByteChannel;
0023
0024 import io.netty.buffer.ByteBuf;
0025 import io.netty.buffer.CompositeByteBuf;
0026 import io.netty.buffer.Unpooled;
0027 import org.apache.spark.network.util.AbstractFileRegion;
0028 import org.junit.Test;
0029 import org.mockito.Mockito;
0030
0031 import static org.junit.Assert.*;
0032
0033 import org.apache.spark.network.TestManagedBuffer;
0034 import org.apache.spark.network.buffer.ManagedBuffer;
0035 import org.apache.spark.network.buffer.NettyManagedBuffer;
0036 import org.apache.spark.network.util.ByteArrayWritableChannel;
0037
0038 public class MessageWithHeaderSuite {
0039
0040 @Test
0041 public void testSingleWrite() throws Exception {
0042 testFileRegionBody(8, 8);
0043 }
0044
0045 @Test
0046 public void testShortWrite() throws Exception {
0047 testFileRegionBody(8, 1);
0048 }
0049
0050 @Test
0051 public void testByteBufBody() throws Exception {
0052 testByteBufBody(Unpooled.copyLong(42));
0053 }
0054
0055 @Test
0056 public void testCompositeByteBufBodySingleBuffer() throws Exception {
0057 ByteBuf header = Unpooled.copyLong(42);
0058 CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
0059 compositeByteBuf.addComponent(true, header);
0060 assertEquals(1, compositeByteBuf.nioBufferCount());
0061 testByteBufBody(compositeByteBuf);
0062 }
0063
0064 @Test
0065 public void testCompositeByteBufBodyMultipleBuffers() throws Exception {
0066 ByteBuf header = Unpooled.copyLong(42);
0067 CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
0068 compositeByteBuf.addComponent(true, header.retainedSlice(0, 4));
0069 compositeByteBuf.addComponent(true, header.slice(4, 4));
0070 assertEquals(2, compositeByteBuf.nioBufferCount());
0071 testByteBufBody(compositeByteBuf);
0072 }
0073
0074
0075
0076
0077
0078
0079
0080 private void testByteBufBody(ByteBuf header) throws Exception {
0081 long expectedHeaderValue = header.getLong(header.readerIndex());
0082 ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84);
0083 assertEquals(1, header.refCnt());
0084 assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt());
0085 ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer);
0086
0087 Object body = managedBuf.convertToNetty();
0088 assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt());
0089 assertEquals(1, header.refCnt());
0090
0091 MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size());
0092 ByteBuf result = doWrite(msg, 1);
0093 assertEquals(msg.count(), result.readableBytes());
0094 assertEquals(expectedHeaderValue, result.readLong());
0095 assertEquals(84, result.readLong());
0096
0097 assertTrue(msg.release());
0098 assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt());
0099 assertEquals(0, header.refCnt());
0100 }
0101
0102 @Test
0103 public void testDeallocateReleasesManagedBuffer() throws Exception {
0104 ByteBuf header = Unpooled.copyLong(42);
0105 ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84));
0106 ByteBuf body = (ByteBuf) managedBuf.convertToNetty();
0107 assertEquals(2, body.refCnt());
0108 MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes());
0109 assertTrue(msg.release());
0110 Mockito.verify(managedBuf, Mockito.times(1)).release();
0111 assertEquals(0, body.refCnt());
0112 }
0113
0114 private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception {
0115 ByteBuf header = Unpooled.copyLong(42);
0116 int headerLength = header.readableBytes();
0117 TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall);
0118 MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count());
0119
0120 ByteBuf result = doWrite(msg, totalWrites / writesPerCall);
0121 assertEquals(headerLength + region.count(), result.readableBytes());
0122 assertEquals(42, result.readLong());
0123 for (long i = 0; i < 8; i++) {
0124 assertEquals(i, result.readLong());
0125 }
0126 assertTrue(msg.release());
0127 }
0128
0129 private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception {
0130 int writes = 0;
0131 ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count());
0132 while (msg.transfered() < msg.count()) {
0133 msg.transferTo(channel, msg.transfered());
0134 writes++;
0135 }
0136 assertTrue("Not enough writes!", minExpectedWrites <= writes);
0137 return Unpooled.wrappedBuffer(channel.getData());
0138 }
0139
0140 private static class TestFileRegion extends AbstractFileRegion {
0141
0142 private final int writeCount;
0143 private final int writesPerCall;
0144 private int written;
0145
0146 TestFileRegion(int totalWrites, int writesPerCall) {
0147 this.writeCount = totalWrites;
0148 this.writesPerCall = writesPerCall;
0149 }
0150
0151 @Override
0152 public long count() {
0153 return 8 * writeCount;
0154 }
0155
0156 @Override
0157 public long position() {
0158 return 0;
0159 }
0160
0161 @Override
0162 public long transferred() {
0163 return 8 * written;
0164 }
0165
0166 @Override
0167 public long transferTo(WritableByteChannel target, long position) throws IOException {
0168 for (int i = 0; i < writesPerCall; i++) {
0169 ByteBuf buf = Unpooled.copyLong((position / 8) + i);
0170 ByteBuffer nio = buf.nioBuffer();
0171 while (nio.remaining() > 0) {
0172 target.write(nio);
0173 }
0174 buf.release();
0175 written++;
0176 }
0177 return 8 * writesPerCall;
0178 }
0179
0180 @Override
0181 protected void deallocate() {
0182 }
0183
0184 }
0185
0186 }