Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network.shuffle;
0019 
0020 import java.nio.ByteBuffer;
0021 import java.util.Arrays;
0022 import java.util.HashMap;
0023 import java.util.Map;
0024 import java.util.concurrent.CountDownLatch;
0025 import java.util.concurrent.atomic.AtomicReference;
0026 import java.util.function.Function;
0027 import java.util.function.Supplier;
0028 
0029 import org.junit.BeforeClass;
0030 import org.junit.Test;
0031 
0032 import static org.junit.Assert.*;
0033 import static org.mockito.Mockito.*;
0034 
0035 import org.apache.spark.network.TestUtils;
0036 import org.apache.spark.network.TransportContext;
0037 import org.apache.spark.network.buffer.ManagedBuffer;
0038 import org.apache.spark.network.client.ChunkReceivedCallback;
0039 import org.apache.spark.network.client.TransportClient;
0040 import org.apache.spark.network.client.TransportClientBootstrap;
0041 import org.apache.spark.network.client.TransportClientFactory;
0042 import org.apache.spark.network.crypto.AuthClientBootstrap;
0043 import org.apache.spark.network.crypto.AuthServerBootstrap;
0044 import org.apache.spark.network.sasl.SaslClientBootstrap;
0045 import org.apache.spark.network.sasl.SaslServerBootstrap;
0046 import org.apache.spark.network.sasl.SecretKeyHolder;
0047 import org.apache.spark.network.server.OneForOneStreamManager;
0048 import org.apache.spark.network.server.TransportServer;
0049 import org.apache.spark.network.server.TransportServerBootstrap;
0050 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
0051 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
0052 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
0053 import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
0054 import org.apache.spark.network.shuffle.protocol.StreamHandle;
0055 import org.apache.spark.network.util.MapConfigProvider;
0056 import org.apache.spark.network.util.TransportConf;
0057 
0058 public class AppIsolationSuite {
0059 
0060   // Use a long timeout to account for slow / overloaded build machines. In the normal case,
0061   // tests should finish way before the timeout expires.
0062   private static final long TIMEOUT_MS = 10_000;
0063 
0064   private static SecretKeyHolder secretKeyHolder;
0065   private static TransportConf conf;
0066 
0067   @BeforeClass
0068   public static void beforeAll() {
0069     Map<String, String> confMap = new HashMap<>();
0070     confMap.put("spark.network.crypto.enabled", "true");
0071     confMap.put("spark.network.crypto.saslFallback", "false");
0072     conf = new TransportConf("shuffle", new MapConfigProvider(confMap));
0073 
0074     secretKeyHolder = mock(SecretKeyHolder.class);
0075     when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
0076     when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
0077     when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
0078     when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
0079   }
0080 
0081   @Test
0082   public void testSaslAppIsolation() throws Exception {
0083     testAppIsolation(
0084       () -> new SaslServerBootstrap(conf, secretKeyHolder),
0085       appId -> new SaslClientBootstrap(conf, appId, secretKeyHolder));
0086   }
0087 
0088   @Test
0089   public void testAuthEngineAppIsolation() throws Exception {
0090     testAppIsolation(
0091       () -> new AuthServerBootstrap(conf, secretKeyHolder),
0092       appId -> new AuthClientBootstrap(conf, appId, secretKeyHolder));
0093   }
0094 
0095   private void testAppIsolation(
0096       Supplier<TransportServerBootstrap> serverBootstrap,
0097       Function<String, TransportClientBootstrap> clientBootstrapFactory) throws Exception {
0098     // Start a new server with the correct RPC handler to serve block data.
0099     ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
0100     ExternalBlockHandler blockHandler = new ExternalBlockHandler(
0101       new OneForOneStreamManager(), blockResolver);
0102     TransportServerBootstrap bootstrap = serverBootstrap.get();
0103 
0104     try (
0105       TransportContext blockServerContext = new TransportContext(conf, blockHandler);
0106       TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
0107       // Create a client, and make a request to fetch blocks from a different app.
0108       TransportClientFactory clientFactory1 = blockServerContext.createClientFactory(
0109           Arrays.asList(clientBootstrapFactory.apply("app-1")));
0110       TransportClient client1 = clientFactory1.createClient(
0111           TestUtils.getLocalHost(), blockServer.getPort())) {
0112 
0113       AtomicReference<Throwable> exception = new AtomicReference<>();
0114 
0115       CountDownLatch blockFetchLatch = new CountDownLatch(1);
0116       BlockFetchingListener listener = new BlockFetchingListener() {
0117         @Override
0118         public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
0119           blockFetchLatch.countDown();
0120         }
0121         @Override
0122         public void onBlockFetchFailure(String blockId, Throwable t) {
0123           exception.set(t);
0124           blockFetchLatch.countDown();
0125         }
0126       };
0127 
0128       String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" };
0129       OneForOneBlockFetcher fetcher =
0130           new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf);
0131       fetcher.start();
0132       blockFetchLatch.await();
0133       checkSecurityException(exception.get());
0134 
0135       // Register an executor so that the next steps work.
0136       ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
0137         new String[] { System.getProperty("java.io.tmpdir") }, 1,
0138           "org.apache.spark.shuffle.sort.SortShuffleManager");
0139       RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
0140       client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);
0141 
0142       // Make a successful request to fetch blocks, which creates a new stream. But do not actually
0143       // fetch any blocks, to keep the stream open.
0144       OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
0145       ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
0146       StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
0147       long streamId = stream.streamId;
0148 
0149       try (
0150         // Create a second client, authenticated with a different app ID, and try to read from
0151         // the stream created for the previous app.
0152         TransportClientFactory clientFactory2 = blockServerContext.createClientFactory(
0153             Arrays.asList(clientBootstrapFactory.apply("app-2")));
0154         TransportClient client2 = clientFactory2.createClient(
0155             TestUtils.getLocalHost(), blockServer.getPort())
0156       ) {
0157         CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
0158         ChunkReceivedCallback callback = new ChunkReceivedCallback() {
0159           @Override
0160           public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
0161             chunkReceivedLatch.countDown();
0162           }
0163 
0164           @Override
0165           public void onFailure(int chunkIndex, Throwable t) {
0166             exception.set(t);
0167             chunkReceivedLatch.countDown();
0168           }
0169         };
0170 
0171         exception.set(null);
0172         client2.fetchChunk(streamId, 0, callback);
0173         chunkReceivedLatch.await();
0174         checkSecurityException(exception.get());
0175       }
0176     }
0177   }
0178 
0179   private static void checkSecurityException(Throwable t) {
0180     assertNotNull("No exception was caught.", t);
0181     assertTrue("Expected SecurityException.",
0182       t.getMessage().contains(SecurityException.class.getName()));
0183   }
0184 }