0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.shuffle;
0019
0020 import java.nio.ByteBuffer;
0021 import java.util.HashMap;
0022 import java.util.Iterator;
0023 import java.util.LinkedHashMap;
0024 import java.util.concurrent.atomic.AtomicInteger;
0025
0026 import com.google.common.collect.Maps;
0027 import io.netty.buffer.Unpooled;
0028 import org.junit.Test;
0029
0030 import static org.junit.Assert.assertEquals;
0031 import static org.junit.Assert.fail;
0032 import static org.mockito.ArgumentMatchers.any;
0033 import static org.mockito.ArgumentMatchers.anyInt;
0034 import static org.mockito.ArgumentMatchers.anyLong;
0035 import static org.mockito.ArgumentMatchers.eq;
0036 import static org.mockito.Mockito.doAnswer;
0037 import static org.mockito.Mockito.mock;
0038 import static org.mockito.Mockito.times;
0039 import static org.mockito.Mockito.verify;
0040
0041 import org.apache.spark.network.buffer.ManagedBuffer;
0042 import org.apache.spark.network.buffer.NettyManagedBuffer;
0043 import org.apache.spark.network.buffer.NioManagedBuffer;
0044 import org.apache.spark.network.client.ChunkReceivedCallback;
0045 import org.apache.spark.network.client.RpcResponseCallback;
0046 import org.apache.spark.network.client.TransportClient;
0047 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
0048 import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
0049 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
0050 import org.apache.spark.network.shuffle.protocol.StreamHandle;
0051 import org.apache.spark.network.util.MapConfigProvider;
0052 import org.apache.spark.network.util.TransportConf;
0053
0054 public class OneForOneBlockFetcherSuite {
0055
0056 private static final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
0057
0058 @Test
0059 public void testFetchOne() {
0060 LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
0061 blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
0062 String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
0063
0064 BlockFetchingListener listener = fetchBlocks(
0065 blocks,
0066 blockIds,
0067 new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0 }}, false),
0068 conf);
0069
0070 verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0"));
0071 }
0072
0073 @Test
0074 public void testUseOldProtocol() {
0075 LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
0076 blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
0077 String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
0078
0079 BlockFetchingListener listener = fetchBlocks(
0080 blocks,
0081 blockIds,
0082 new OpenBlocks("app-id", "exec-id", blockIds),
0083 new TransportConf("shuffle", new MapConfigProvider(
0084 new HashMap<String, String>() {{
0085 put("spark.shuffle.useOldFetchProtocol", "true");
0086 }}
0087 )));
0088
0089 verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0"));
0090 }
0091
0092 @Test
0093 public void testFetchThreeShuffleBlocks() {
0094 LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
0095 blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
0096 blocks.put("shuffle_0_0_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
0097 blocks.put("shuffle_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
0098 String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
0099
0100 BlockFetchingListener listener = fetchBlocks(
0101 blocks,
0102 blockIds,
0103 new FetchShuffleBlocks(
0104 "app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0, 1, 2 }}, false),
0105 conf);
0106
0107 for (int i = 0; i < 3; i ++) {
0108 verify(listener, times(1)).onBlockFetchSuccess(
0109 "shuffle_0_0_" + i, blocks.get("shuffle_0_0_" + i));
0110 }
0111 }
0112
0113 @Test
0114 public void testBatchFetchThreeShuffleBlocks() {
0115 LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
0116 blocks.put("shuffle_0_0_0_3", new NioManagedBuffer(ByteBuffer.wrap(new byte[58])));
0117 String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
0118
0119 BlockFetchingListener listener = fetchBlocks(
0120 blocks,
0121 blockIds,
0122 new FetchShuffleBlocks(
0123 "app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0, 3 }}, true),
0124 conf);
0125
0126 verify(listener, times(1)).onBlockFetchSuccess(
0127 "shuffle_0_0_0_3", blocks.get("shuffle_0_0_0_3"));
0128 }
0129
0130 @Test
0131 public void testFetchThree() {
0132 LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
0133 blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
0134 blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
0135 blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
0136 String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
0137
0138 BlockFetchingListener listener = fetchBlocks(
0139 blocks,
0140 blockIds,
0141 new OpenBlocks("app-id", "exec-id", blockIds),
0142 conf);
0143
0144 for (int i = 0; i < 3; i ++) {
0145 verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i));
0146 }
0147 }
0148
0149 @Test
0150 public void testFailure() {
0151 LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
0152 blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
0153 blocks.put("b1", null);
0154 blocks.put("b2", null);
0155 String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
0156
0157 BlockFetchingListener listener = fetchBlocks(
0158 blocks,
0159 blockIds,
0160 new OpenBlocks("app-id", "exec-id", blockIds),
0161 conf);
0162
0163
0164 verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
0165 verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any());
0166 verify(listener, times(2)).onBlockFetchFailure(eq("b2"), any());
0167 }
0168
0169 @Test
0170 public void testFailureAndSuccess() {
0171 LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
0172 blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
0173 blocks.put("b1", null);
0174 blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21])));
0175 String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
0176
0177 BlockFetchingListener listener = fetchBlocks(
0178 blocks,
0179 blockIds,
0180 new OpenBlocks("app-id", "exec-id", blockIds),
0181 conf);
0182
0183
0184 verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
0185 verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any());
0186 verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2"));
0187 verify(listener, times(1)).onBlockFetchFailure(eq("b2"), any());
0188 }
0189
0190 @Test
0191 public void testEmptyBlockFetch() {
0192 try {
0193 fetchBlocks(
0194 Maps.newLinkedHashMap(),
0195 new String[] {},
0196 new OpenBlocks("app-id", "exec-id", new String[] {}),
0197 conf);
0198 fail();
0199 } catch (IllegalArgumentException e) {
0200 assertEquals("Zero-sized blockIds array", e.getMessage());
0201 }
0202 }
0203
0204
0205
0206
0207
0208
0209
0210
0211
0212 private static BlockFetchingListener fetchBlocks(
0213 LinkedHashMap<String, ManagedBuffer> blocks,
0214 String[] blockIds,
0215 BlockTransferMessage expectMessage,
0216 TransportConf transportConf) {
0217 TransportClient client = mock(TransportClient.class);
0218 BlockFetchingListener listener = mock(BlockFetchingListener.class);
0219 OneForOneBlockFetcher fetcher =
0220 new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, transportConf);
0221
0222
0223 doAnswer(invocationOnMock -> {
0224 BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer(
0225 (ByteBuffer) invocationOnMock.getArguments()[0]);
0226 RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
0227 callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer());
0228 assertEquals(expectMessage, message);
0229 return null;
0230 }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class));
0231
0232
0233 AtomicInteger expectedChunkIndex = new AtomicInteger(0);
0234 Iterator<ManagedBuffer> blockIterator = blocks.values().iterator();
0235 doAnswer(invocation -> {
0236 try {
0237 long streamId = (Long) invocation.getArguments()[0];
0238 int myChunkIndex = (Integer) invocation.getArguments()[1];
0239 assertEquals(123, streamId);
0240 assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex);
0241
0242 ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2];
0243 ManagedBuffer result = blockIterator.next();
0244 if (result != null) {
0245 callback.onSuccess(myChunkIndex, result);
0246 } else {
0247 callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex));
0248 }
0249 } catch (Exception e) {
0250 e.printStackTrace();
0251 fail("Unexpected failure");
0252 }
0253 return null;
0254 }).when(client).fetchChunk(anyLong(), anyInt(), any());
0255
0256 fetcher.start();
0257 return listener;
0258 }
0259 }