Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network.shuffle;
0019 
0020 import java.nio.ByteBuffer;
0021 import java.util.Iterator;
0022 
0023 import com.codahale.metrics.Meter;
0024 import com.codahale.metrics.Timer;
0025 import org.junit.Before;
0026 import org.junit.Test;
0027 import org.mockito.ArgumentCaptor;
0028 
0029 import static org.junit.Assert.*;
0030 import static org.mockito.ArgumentMatchers.any;
0031 import static org.mockito.Mockito.*;
0032 
0033 import org.apache.spark.network.buffer.ManagedBuffer;
0034 import org.apache.spark.network.buffer.NioManagedBuffer;
0035 import org.apache.spark.network.client.RpcResponseCallback;
0036 import org.apache.spark.network.client.TransportClient;
0037 import org.apache.spark.network.server.OneForOneStreamManager;
0038 import org.apache.spark.network.server.RpcHandler;
0039 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
0040 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
0041 import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
0042 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
0043 import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
0044 import org.apache.spark.network.shuffle.protocol.StreamHandle;
0045 import org.apache.spark.network.shuffle.protocol.UploadBlock;
0046 
0047 public class ExternalBlockHandlerSuite {
0048   TransportClient client = mock(TransportClient.class);
0049 
0050   OneForOneStreamManager streamManager;
0051   ExternalShuffleBlockResolver blockResolver;
0052   RpcHandler handler;
0053   ManagedBuffer[] blockMarkers = {
0054     new NioManagedBuffer(ByteBuffer.wrap(new byte[3])),
0055     new NioManagedBuffer(ByteBuffer.wrap(new byte[7]))
0056   };
0057 
0058   @Before
0059   public void beforeEach() {
0060     streamManager = mock(OneForOneStreamManager.class);
0061     blockResolver = mock(ExternalShuffleBlockResolver.class);
0062     handler = new ExternalBlockHandler(streamManager, blockResolver);
0063   }
0064 
0065   @Test
0066   public void testRegisterExecutor() {
0067     RpcResponseCallback callback = mock(RpcResponseCallback.class);
0068 
0069     ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort");
0070     ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer();
0071     handler.receive(client, registerMessage, callback);
0072     verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config);
0073 
0074     verify(callback, times(1)).onSuccess(any(ByteBuffer.class));
0075     verify(callback, never()).onFailure(any(Throwable.class));
0076     // Verify register executor request latency metrics
0077     Timer registerExecutorRequestLatencyMillis = (Timer) ((ExternalBlockHandler) handler)
0078         .getAllMetrics()
0079         .getMetrics()
0080         .get("registerExecutorRequestLatencyMillis");
0081     assertEquals(1, registerExecutorRequestLatencyMillis.getCount());
0082   }
0083 
0084   @Test
0085   public void testCompatibilityWithOldVersion() {
0086     when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(blockMarkers[0]);
0087     when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(blockMarkers[1]);
0088 
0089     OpenBlocks openBlocks = new OpenBlocks(
0090       "app0", "exec1", new String[] { "shuffle_0_0_0", "shuffle_0_0_1" });
0091     checkOpenBlocksReceive(openBlocks, blockMarkers);
0092 
0093     verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0);
0094     verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1);
0095     verifyOpenBlockLatencyMetrics();
0096   }
0097 
0098   @Test
0099   public void testFetchShuffleBlocks() {
0100     when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(blockMarkers[0]);
0101     when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(blockMarkers[1]);
0102 
0103     FetchShuffleBlocks fetchShuffleBlocks = new FetchShuffleBlocks(
0104       "app0", "exec1", 0, new long[] { 0 }, new int[][] {{ 0, 1 }}, false);
0105     checkOpenBlocksReceive(fetchShuffleBlocks, blockMarkers);
0106 
0107     verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0);
0108     verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1);
0109     verifyOpenBlockLatencyMetrics();
0110   }
0111 
0112   @Test
0113   public void testFetchShuffleBlocksInBatch() {
0114     ManagedBuffer[] batchBlockMarkers = {
0115       new NioManagedBuffer(ByteBuffer.wrap(new byte[10]))
0116     };
0117     when(blockResolver.getContinuousBlocksData(
0118       "app0", "exec1", 0, 0, 0, 1)).thenReturn(batchBlockMarkers[0]);
0119 
0120     FetchShuffleBlocks fetchShuffleBlocks = new FetchShuffleBlocks(
0121       "app0", "exec1", 0, new long[] { 0 }, new int[][] {{ 0, 1 }}, true);
0122     checkOpenBlocksReceive(fetchShuffleBlocks, batchBlockMarkers);
0123 
0124     verify(blockResolver, times(1)).getContinuousBlocksData("app0", "exec1", 0, 0, 0, 1);
0125     verifyOpenBlockLatencyMetrics();
0126   }
0127 
0128   @Test
0129   public void testOpenDiskPersistedRDDBlocks() {
0130     when(blockResolver.getRddBlockData("app0", "exec1", 0, 0)).thenReturn(blockMarkers[0]);
0131     when(blockResolver.getRddBlockData("app0", "exec1", 0, 1)).thenReturn(blockMarkers[1]);
0132 
0133     OpenBlocks openBlocks = new OpenBlocks(
0134       "app0", "exec1", new String[] { "rdd_0_0", "rdd_0_1" });
0135     checkOpenBlocksReceive(openBlocks, blockMarkers);
0136 
0137     verify(blockResolver, times(1)).getRddBlockData("app0", "exec1", 0, 0);
0138     verify(blockResolver, times(1)).getRddBlockData("app0", "exec1", 0, 1);
0139     verifyOpenBlockLatencyMetrics();
0140   }
0141 
0142   @Test
0143   public void testOpenDiskPersistedRDDBlocksWithMissingBlock() {
0144     ManagedBuffer[] blockMarkersWithMissingBlock = {
0145       new NioManagedBuffer(ByteBuffer.wrap(new byte[3])),
0146       null
0147     };
0148     when(blockResolver.getRddBlockData("app0", "exec1", 0, 0))
0149       .thenReturn(blockMarkersWithMissingBlock[0]);
0150     when(blockResolver.getRddBlockData("app0", "exec1", 0, 1))
0151       .thenReturn(null);
0152 
0153     OpenBlocks openBlocks = new OpenBlocks(
0154       "app0", "exec1", new String[] { "rdd_0_0", "rdd_0_1" });
0155     checkOpenBlocksReceive(openBlocks, blockMarkersWithMissingBlock);
0156 
0157     verify(blockResolver, times(1)).getRddBlockData("app0", "exec1", 0, 0);
0158     verify(blockResolver, times(1)).getRddBlockData("app0", "exec1", 0, 1);
0159   }
0160 
0161   private void checkOpenBlocksReceive(BlockTransferMessage msg, ManagedBuffer[] blockMarkers) {
0162     when(client.getClientId()).thenReturn("app0");
0163 
0164     RpcResponseCallback callback = mock(RpcResponseCallback.class);
0165     handler.receive(client, msg.toByteBuffer(), callback);
0166 
0167     ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class);
0168     verify(callback, times(1)).onSuccess(response.capture());
0169     verify(callback, never()).onFailure(any());
0170 
0171     StreamHandle handle =
0172       (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue());
0173     assertEquals(blockMarkers.length, handle.numChunks);
0174 
0175     @SuppressWarnings("unchecked")
0176     ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
0177         (ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
0178     verify(streamManager, times(1)).registerStream(anyString(), stream.capture(),
0179       any());
0180     Iterator<ManagedBuffer> buffers = stream.getValue();
0181     for (ManagedBuffer blockMarker : blockMarkers) {
0182       assertEquals(blockMarker, buffers.next());
0183     }
0184     assertFalse(buffers.hasNext());
0185   }
0186 
0187   private void verifyOpenBlockLatencyMetrics() {
0188     Timer openBlockRequestLatencyMillis = (Timer) ((ExternalBlockHandler) handler)
0189         .getAllMetrics()
0190         .getMetrics()
0191         .get("openBlockRequestLatencyMillis");
0192     assertEquals(1, openBlockRequestLatencyMillis.getCount());
0193     // Verify block transfer metrics
0194     Meter blockTransferRateBytes = (Meter) ((ExternalBlockHandler) handler)
0195         .getAllMetrics()
0196         .getMetrics()
0197         .get("blockTransferRateBytes");
0198     assertEquals(10, blockTransferRateBytes.getCount());
0199   }
0200 
0201   @Test
0202   public void testBadMessages() {
0203     RpcResponseCallback callback = mock(RpcResponseCallback.class);
0204 
0205     ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 });
0206     try {
0207       handler.receive(client, unserializableMsg, callback);
0208       fail("Should have thrown");
0209     } catch (Exception e) {
0210       // pass
0211     }
0212 
0213     ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1],
0214       new byte[2]).toByteBuffer();
0215     try {
0216       handler.receive(client, unexpectedMsg, callback);
0217       fail("Should have thrown");
0218     } catch (UnsupportedOperationException e) {
0219       // pass
0220     }
0221 
0222     verify(callback, never()).onSuccess(any(ByteBuffer.class));
0223     verify(callback, never()).onFailure(any(Throwable.class));
0224   }
0225 }