0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.shuffle;
0019
0020 import java.nio.ByteBuffer;
0021 import java.util.Arrays;
0022 import java.util.HashMap;
0023 import java.util.Map;
0024 import java.util.concurrent.CountDownLatch;
0025 import java.util.concurrent.atomic.AtomicReference;
0026 import java.util.function.Function;
0027 import java.util.function.Supplier;
0028
0029 import org.junit.BeforeClass;
0030 import org.junit.Test;
0031
0032 import static org.junit.Assert.*;
0033 import static org.mockito.Mockito.*;
0034
0035 import org.apache.spark.network.TestUtils;
0036 import org.apache.spark.network.TransportContext;
0037 import org.apache.spark.network.buffer.ManagedBuffer;
0038 import org.apache.spark.network.client.ChunkReceivedCallback;
0039 import org.apache.spark.network.client.TransportClient;
0040 import org.apache.spark.network.client.TransportClientBootstrap;
0041 import org.apache.spark.network.client.TransportClientFactory;
0042 import org.apache.spark.network.crypto.AuthClientBootstrap;
0043 import org.apache.spark.network.crypto.AuthServerBootstrap;
0044 import org.apache.spark.network.sasl.SaslClientBootstrap;
0045 import org.apache.spark.network.sasl.SaslServerBootstrap;
0046 import org.apache.spark.network.sasl.SecretKeyHolder;
0047 import org.apache.spark.network.server.OneForOneStreamManager;
0048 import org.apache.spark.network.server.TransportServer;
0049 import org.apache.spark.network.server.TransportServerBootstrap;
0050 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
0051 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
0052 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
0053 import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
0054 import org.apache.spark.network.shuffle.protocol.StreamHandle;
0055 import org.apache.spark.network.util.MapConfigProvider;
0056 import org.apache.spark.network.util.TransportConf;
0057
0058 public class AppIsolationSuite {
0059
0060
0061
0062 private static final long TIMEOUT_MS = 10_000;
0063
0064 private static SecretKeyHolder secretKeyHolder;
0065 private static TransportConf conf;
0066
0067 @BeforeClass
0068 public static void beforeAll() {
0069 Map<String, String> confMap = new HashMap<>();
0070 confMap.put("spark.network.crypto.enabled", "true");
0071 confMap.put("spark.network.crypto.saslFallback", "false");
0072 conf = new TransportConf("shuffle", new MapConfigProvider(confMap));
0073
0074 secretKeyHolder = mock(SecretKeyHolder.class);
0075 when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
0076 when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
0077 when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
0078 when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
0079 }
0080
0081 @Test
0082 public void testSaslAppIsolation() throws Exception {
0083 testAppIsolation(
0084 () -> new SaslServerBootstrap(conf, secretKeyHolder),
0085 appId -> new SaslClientBootstrap(conf, appId, secretKeyHolder));
0086 }
0087
0088 @Test
0089 public void testAuthEngineAppIsolation() throws Exception {
0090 testAppIsolation(
0091 () -> new AuthServerBootstrap(conf, secretKeyHolder),
0092 appId -> new AuthClientBootstrap(conf, appId, secretKeyHolder));
0093 }
0094
0095 private void testAppIsolation(
0096 Supplier<TransportServerBootstrap> serverBootstrap,
0097 Function<String, TransportClientBootstrap> clientBootstrapFactory) throws Exception {
0098
0099 ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
0100 ExternalBlockHandler blockHandler = new ExternalBlockHandler(
0101 new OneForOneStreamManager(), blockResolver);
0102 TransportServerBootstrap bootstrap = serverBootstrap.get();
0103
0104 try (
0105 TransportContext blockServerContext = new TransportContext(conf, blockHandler);
0106 TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
0107
0108 TransportClientFactory clientFactory1 = blockServerContext.createClientFactory(
0109 Arrays.asList(clientBootstrapFactory.apply("app-1")));
0110 TransportClient client1 = clientFactory1.createClient(
0111 TestUtils.getLocalHost(), blockServer.getPort())) {
0112
0113 AtomicReference<Throwable> exception = new AtomicReference<>();
0114
0115 CountDownLatch blockFetchLatch = new CountDownLatch(1);
0116 BlockFetchingListener listener = new BlockFetchingListener() {
0117 @Override
0118 public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
0119 blockFetchLatch.countDown();
0120 }
0121 @Override
0122 public void onBlockFetchFailure(String blockId, Throwable t) {
0123 exception.set(t);
0124 blockFetchLatch.countDown();
0125 }
0126 };
0127
0128 String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" };
0129 OneForOneBlockFetcher fetcher =
0130 new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf);
0131 fetcher.start();
0132 blockFetchLatch.await();
0133 checkSecurityException(exception.get());
0134
0135
0136 ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
0137 new String[] { System.getProperty("java.io.tmpdir") }, 1,
0138 "org.apache.spark.shuffle.sort.SortShuffleManager");
0139 RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
0140 client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);
0141
0142
0143
0144 OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
0145 ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
0146 StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
0147 long streamId = stream.streamId;
0148
0149 try (
0150
0151
0152 TransportClientFactory clientFactory2 = blockServerContext.createClientFactory(
0153 Arrays.asList(clientBootstrapFactory.apply("app-2")));
0154 TransportClient client2 = clientFactory2.createClient(
0155 TestUtils.getLocalHost(), blockServer.getPort())
0156 ) {
0157 CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
0158 ChunkReceivedCallback callback = new ChunkReceivedCallback() {
0159 @Override
0160 public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
0161 chunkReceivedLatch.countDown();
0162 }
0163
0164 @Override
0165 public void onFailure(int chunkIndex, Throwable t) {
0166 exception.set(t);
0167 chunkReceivedLatch.countDown();
0168 }
0169 };
0170
0171 exception.set(null);
0172 client2.fetchChunk(streamId, 0, callback);
0173 chunkReceivedLatch.await();
0174 checkSecurityException(exception.get());
0175 }
0176 }
0177 }
0178
0179 private static void checkSecurityException(Throwable t) {
0180 assertNotNull("No exception was caught.", t);
0181 assertTrue("Expected SecurityException.",
0182 t.getMessage().contains(SecurityException.class.getName()));
0183 }
0184 }