/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.shuffle;

import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import com.google.common.collect.Maps;
import io.netty.buffer.Unpooled;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NettyManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

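/**
 * Tests that OneForOneBlockFetcher sends the expected OpenBlocks or FetchShuffleBlocks RPC and
 * forwards each fetched chunk (or failure) to the supplied BlockFetchingListener.
 */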
public class OneForOneBlockFetcherSuite {

  private static final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);

  @Test
  public void testFetchOne() {
    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
    blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);

    BlockFetchingListener listener = fetchBlocks(
      blocks,
      blockIds,
      new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0 }}, false),
      conf);

    verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0"));
  }

  @Test
  public void testUseOldProtocol() {
    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
    blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);

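    // With spark.shuffle.useOldFetchProtocol set, the fetcher is expected to fall back to the
    // legacy OpenBlocks message rather than FetchShuffleBlocks.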
    BlockFetchingListener listener = fetchBlocks(
      blocks,
      blockIds,
      new OpenBlocks("app-id", "exec-id", blockIds),
      new TransportConf("shuffle", new MapConfigProvider(
        new HashMap<String, String>() {{
          put("spark.shuffle.useOldFetchProtocol", "true");
        }}
      )));

    verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0"));
  }

  @Test
  public void testFetchThreeShuffleBlocks() {
    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
    blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
    blocks.put("shuffle_0_0_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
    blocks.put("shuffle_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);

    BlockFetchingListener listener = fetchBlocks(
      blocks,
      blockIds,
      new FetchShuffleBlocks(
        "app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0, 1, 2 }}, false),
      conf);

    for (int i = 0; i < 3; i++) {
      verify(listener, times(1)).onBlockFetchSuccess(
        "shuffle_0_0_" + i, blocks.get("shuffle_0_0_" + i));
    }
  }

  @Test
  public void testBatchFetchThreeShuffleBlocks() {
    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
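    // The single batch block id "shuffle_0_0_0_3" stands in for reduce partitions [0, 3) of map
    // output 0, returned as one continuous buffer when batch fetching is enabled.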
    blocks.put("shuffle_0_0_0_3", new NioManagedBuffer(ByteBuffer.wrap(new byte[58])));
    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);

    BlockFetchingListener listener = fetchBlocks(
      blocks,
      blockIds,
      new FetchShuffleBlocks(
        "app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0, 3 }}, true),
      conf);

    verify(listener, times(1)).onBlockFetchSuccess(
      "shuffle_0_0_0_3", blocks.get("shuffle_0_0_0_3"));
  }

  @Test
  public void testFetchThree() {
    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
    blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
    blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
    blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);

    BlockFetchingListener listener = fetchBlocks(
      blocks,
      blockIds,
      new OpenBlocks("app-id", "exec-id", blockIds),
      conf);

    for (int i = 0; i < 3; i++) {
      verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i));
    }
  }

  @Test
  public void testFailure() {
    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
    blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
    blocks.put("b1", null);
    blocks.put("b2", null);
    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);

    BlockFetchingListener listener = fetchBlocks(
      blocks,
      blockIds,
      new OpenBlocks("app-id", "exec-id", blockIds),
      conf);

    // Each failure causes onBlockFetchFailure to be invoked for every remaining block fetch.
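    // Here the null buffer for b1 fails chunk 1, which also fails b2; b2's own null buffer then
    // fails it a second time, hence the times(2) expectation below.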
    verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
    verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any());
    verify(listener, times(2)).onBlockFetchFailure(eq("b2"), any());
  }

  @Test
  public void testFailureAndSuccess() {
    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
    blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
    blocks.put("b1", null);
    blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21])));
    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);

    BlockFetchingListener listener = fetchBlocks(
      blocks,
      blockIds,
      new OpenBlocks("app-id", "exec-id", blockIds),
      conf);

    // We may call both success and failure for the same block.
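    // b2's own chunk succeeds, but b1's earlier failure also reports b2 as failed, so b2 sees
    // one success and one failure callback.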
    verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
    verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any());
    verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2"));
    verify(listener, times(1)).onBlockFetchFailure(eq("b2"), any());
  }

  @Test
  public void testEmptyBlockFetch() {
    try {
      fetchBlocks(
        Maps.newLinkedHashMap(),
        new String[] {},
        new OpenBlocks("app-id", "exec-id", new String[] {}),
        conf);
      fail();
    } catch (IllegalArgumentException e) {
      assertEquals("Zero-sized blockIds array", e.getMessage());
    }
  }

  /**
   * Begins a fetch on the given set of blocks by mocking out the server side of the RPC so that
   * it simply returns the given (BlockId, Block) pairs.
   * As "blocks" is a LinkedHashMap, the blocks are guaranteed to be returned in the same order
   * in which they were inserted.
   *
   * If a block's buffer is "null", a fetch failure is delivered for that chunk instead.
   */
  private static BlockFetchingListener fetchBlocks(
      LinkedHashMap<String, ManagedBuffer> blocks,
      String[] blockIds,
      BlockTransferMessage expectMessage,
      TransportConf transportConf) {
    TransportClient client = mock(TransportClient.class);
    BlockFetchingListener listener = mock(BlockFetchingListener.class);
    OneForOneBlockFetcher fetcher =
      new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, transportConf);

    // Respond to the block fetch RPC (OpenBlocks or FetchShuffleBlocks) with a StreamHandle
    // carrying streamId 123.
    doAnswer(invocationOnMock -> {
      BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer(
        (ByteBuffer) invocationOnMock.getArguments()[0]);
      RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
      callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer());
      assertEquals(expectMessage, message);
      return null;
    }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class));

    // Respond to each chunk request with a single buffer from our blocks map.
    AtomicInteger expectedChunkIndex = new AtomicInteger(0);
    Iterator<ManagedBuffer> blockIterator = blocks.values().iterator();
    doAnswer(invocation -> {
      try {
        long streamId = (Long) invocation.getArguments()[0];
        int myChunkIndex = (Integer) invocation.getArguments()[1];
        assertEquals(123, streamId);
        assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex);

        ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2];
        ManagedBuffer result = blockIterator.next();
        if (result != null) {
          callback.onSuccess(myChunkIndex, result);
        } else {
          callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex));
        }
      } catch (Exception e) {
        e.printStackTrace();
        fail("Unexpected failure");
      }
      return null;
    }).when(client).fetchChunk(anyLong(), anyInt(), any());

    fetcher.start();
    return listener;
  }
}