0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.sasl;
0019
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

import java.io.File;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import javax.security.sasl.SaslException;

import com.google.common.collect.ImmutableMap;
import com.google.common.io.ByteStreams;
import com.google.common.io.Files;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import org.junit.Test;

import org.apache.spark.network.TestUtils;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
0065
0066
0067
0068
0069 public class SparkSaslSuite {
0070
0071
0072 private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() {
0073 @Override
0074 public String getSaslUser(String appId) {
0075 return "user";
0076 }
0077
0078 @Override
0079 public String getSecretKey(String appId) {
0080 return appId;
0081 }
0082 };
0083
0084 @Test
0085 public void testMatching() {
0086 SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder, false);
0087 SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder, false);
0088
0089 assertFalse(client.isComplete());
0090 assertFalse(server.isComplete());
0091
0092 byte[] clientMessage = client.firstToken();
0093
0094 while (!client.isComplete()) {
0095 clientMessage = client.response(server.response(clientMessage));
0096 }
0097 assertTrue(server.isComplete());
0098
0099
0100 server.dispose();
0101 assertFalse(server.isComplete());
0102 client.dispose();
0103 assertFalse(client.isComplete());
0104 }
0105
0106 @Test
0107 public void testNonMatching() {
0108 SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder, false);
0109 SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder, false);
0110
0111 assertFalse(client.isComplete());
0112 assertFalse(server.isComplete());
0113
0114 byte[] clientMessage = client.firstToken();
0115
0116 try {
0117 while (!client.isComplete()) {
0118 clientMessage = client.response(server.response(clientMessage));
0119 }
0120 fail("Should not have completed");
0121 } catch (Exception e) {
0122 assertTrue(e.getMessage().contains("Mismatched response"));
0123 assertFalse(client.isComplete());
0124 assertFalse(server.isComplete());
0125 }
0126 }
0127
0128 @Test
0129 public void testSaslAuthentication() throws Throwable {
0130 testBasicSasl(false);
0131 }
0132
0133 @Test
0134 public void testSaslEncryption() throws Throwable {
0135 testBasicSasl(true);
0136 }
0137
0138 private static void testBasicSasl(boolean encrypt) throws Throwable {
0139 RpcHandler rpcHandler = mock(RpcHandler.class);
0140 doAnswer(invocation -> {
0141 ByteBuffer message = (ByteBuffer) invocation.getArguments()[1];
0142 RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2];
0143 assertEquals("Ping", JavaUtils.bytesToString(message));
0144 cb.onSuccess(JavaUtils.stringToBytes("Pong"));
0145 return null;
0146 })
0147 .when(rpcHandler)
0148 .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class));
0149
0150 SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
0151 try {
0152 ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
0153 TimeUnit.SECONDS.toMillis(10));
0154 assertEquals("Pong", JavaUtils.bytesToString(response));
0155 } finally {
0156 ctx.close();
0157
0158 Throwable error = null;
0159 long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
0160 while (deadline > System.nanoTime()) {
0161 try {
0162 verify(rpcHandler, times(2)).channelInactive(any(TransportClient.class));
0163 error = null;
0164 break;
0165 } catch (Throwable t) {
0166 error = t;
0167 TimeUnit.MILLISECONDS.sleep(10);
0168 }
0169 }
0170 if (error != null) {
0171 throw error;
0172 }
0173 }
0174 }
0175
0176 @Test
0177 public void testEncryptedMessage() throws Exception {
0178 SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
0179 byte[] data = new byte[1024];
0180 new Random().nextBytes(data);
0181 when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
0182
0183 ByteBuf msg = Unpooled.buffer();
0184 try {
0185 msg.writeBytes(data);
0186
0187
0188
0189
0190 ByteArrayWritableChannel channel = new ByteArrayWritableChannel(32);
0191
0192 SaslEncryption.EncryptedMessage emsg =
0193 new SaslEncryption.EncryptedMessage(backend, msg, 1024);
0194 long count = emsg.transferTo(channel, emsg.transfered());
0195 assertTrue(count < data.length);
0196 assertTrue(count > 0);
0197
0198
0199 assertEquals(0, emsg.transferTo(channel, emsg.transfered()));
0200
0201
0202
0203 channel.reset();
0204 assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
0205
0206
0207 for (int i = 0; i < data.length / 32 - 2; i++) {
0208 channel.reset();
0209 assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
0210 }
0211
0212 channel.reset();
0213 count = emsg.transferTo(channel, emsg.transfered());
0214 assertTrue("Unexpected count: " + count, count > 1 && count < data.length);
0215 assertEquals(data.length, emsg.transfered());
0216 } finally {
0217 msg.release();
0218 }
0219 }
0220
0221 @Test
0222 public void testEncryptedMessageChunking() throws Exception {
0223 File file = File.createTempFile("sasltest", ".txt");
0224 try {
0225 TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
0226
0227 byte[] data = new byte[8 * 1024];
0228 new Random().nextBytes(data);
0229 Files.write(data, file);
0230
0231 SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
0232
0233 when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
0234
0235 FileSegmentManagedBuffer msg = new FileSegmentManagedBuffer(conf, file, 0, file.length());
0236 SaslEncryption.EncryptedMessage emsg =
0237 new SaslEncryption.EncryptedMessage(backend, msg.convertToNetty(), data.length / 8);
0238
0239 ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length);
0240 while (emsg.transfered() < emsg.count()) {
0241 channel.reset();
0242 emsg.transferTo(channel, emsg.transfered());
0243 }
0244
0245 verify(backend, times(8)).wrap(any(byte[].class), anyInt(), anyInt());
0246 } finally {
0247 file.delete();
0248 }
0249 }
0250
0251 @Test
0252 public void testFileRegionEncryption() throws Exception {
0253 Map<String, String> testConf = ImmutableMap.of(
0254 "spark.network.sasl.maxEncryptedBlockSize", "1k");
0255
0256 AtomicReference<ManagedBuffer> response = new AtomicReference<>();
0257 File file = File.createTempFile("sasltest", ".txt");
0258 SaslTestCtx ctx = null;
0259 try {
0260 TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf));
0261 StreamManager sm = mock(StreamManager.class);
0262 when(sm.getChunk(anyLong(), anyInt())).thenAnswer(invocation ->
0263 new FileSegmentManagedBuffer(conf, file, 0, file.length()));
0264
0265 RpcHandler rpcHandler = mock(RpcHandler.class);
0266 when(rpcHandler.getStreamManager()).thenReturn(sm);
0267
0268 byte[] data = new byte[8 * 1024];
0269 new Random().nextBytes(data);
0270 Files.write(data, file);
0271
0272 ctx = new SaslTestCtx(rpcHandler, true, false, testConf);
0273
0274 CountDownLatch lock = new CountDownLatch(1);
0275
0276 ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
0277 doAnswer(invocation -> {
0278 response.set((ManagedBuffer) invocation.getArguments()[1]);
0279 response.get().retain();
0280 lock.countDown();
0281 return null;
0282 }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
0283
0284 ctx.client.fetchChunk(0, 0, callback);
0285 lock.await(10, TimeUnit.SECONDS);
0286
0287 verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
0288 verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
0289
0290 byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
0291 assertTrue(Arrays.equals(data, received));
0292 } finally {
0293 file.delete();
0294 if (ctx != null) {
0295 ctx.close();
0296 }
0297 if (response.get() != null) {
0298 response.get().release();
0299 }
0300 }
0301 }
0302
0303 @Test
0304 public void testServerAlwaysEncrypt() throws Exception {
0305 SaslTestCtx ctx = null;
0306 try {
0307 ctx = new SaslTestCtx(mock(RpcHandler.class), false, false,
0308 ImmutableMap.of("spark.network.sasl.serverAlwaysEncrypt", "true"));
0309 fail("Should have failed to connect without encryption.");
0310 } catch (Exception e) {
0311 assertTrue(e.getCause() instanceof SaslException);
0312 } finally {
0313 if (ctx != null) {
0314 ctx.close();
0315 }
0316 }
0317 }
0318
0319 @Test
0320 public void testDataEncryptionIsActuallyEnabled() throws Exception {
0321
0322
0323
0324 SaslTestCtx ctx = null;
0325 try {
0326 ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
0327 ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
0328 TimeUnit.SECONDS.toMillis(10));
0329 fail("Should have failed to send RPC to server.");
0330 } catch (Exception e) {
0331 assertFalse(e.getCause() instanceof TimeoutException);
0332 } finally {
0333 if (ctx != null) {
0334 ctx.close();
0335 }
0336 }
0337 }
0338
0339 @Test
0340 public void testRpcHandlerDelegate() throws Exception {
0341
0342
0343 RpcHandler handler = mock(RpcHandler.class);
0344 RpcHandler saslHandler = new SaslRpcHandler(null, null, handler, null);
0345
0346 saslHandler.getStreamManager();
0347 verify(handler).getStreamManager();
0348
0349 saslHandler.channelInactive(null);
0350 verify(handler).channelInactive(isNull());
0351
0352 saslHandler.exceptionCaught(null, null);
0353 verify(handler).exceptionCaught(isNull(), isNull());
0354 }
0355
0356 @Test
0357 public void testDelegates() throws Exception {
0358 Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods();
0359 for (Method m : rpcHandlerMethods) {
0360 Method delegate = SaslRpcHandler.class.getMethod(m.getName(), m.getParameterTypes());
0361 assertNotEquals(delegate.getDeclaringClass(), RpcHandler.class);
0362 }
0363 }
0364
0365 private static class SaslTestCtx {
0366
0367 final TransportClient client;
0368 final TransportServer server;
0369 final TransportContext ctx;
0370
0371 private final boolean encrypt;
0372 private final boolean disableClientEncryption;
0373 private final EncryptionCheckerBootstrap checker;
0374
0375 SaslTestCtx(
0376 RpcHandler rpcHandler,
0377 boolean encrypt,
0378 boolean disableClientEncryption)
0379 throws Exception {
0380
0381 this(rpcHandler, encrypt, disableClientEncryption, Collections.emptyMap());
0382 }
0383
0384 SaslTestCtx(
0385 RpcHandler rpcHandler,
0386 boolean encrypt,
0387 boolean disableClientEncryption,
0388 Map<String, String> extraConf)
0389 throws Exception {
0390
0391 Map<String, String> testConf = ImmutableMap.<String, String>builder()
0392 .putAll(extraConf)
0393 .put("spark.authenticate.enableSaslEncryption", String.valueOf(encrypt))
0394 .build();
0395 TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf));
0396
0397 SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
0398 when(keyHolder.getSaslUser(anyString())).thenReturn("user");
0399 when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
0400
0401 this.ctx = new TransportContext(conf, rpcHandler);
0402
0403 this.checker = new EncryptionCheckerBootstrap(SaslEncryption.ENCRYPTION_HANDLER_NAME);
0404
0405 this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder),
0406 checker));
0407
0408 try {
0409 List<TransportClientBootstrap> clientBootstraps = new ArrayList<>();
0410 clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder));
0411 if (disableClientEncryption) {
0412 clientBootstraps.add(new EncryptionDisablerBootstrap());
0413 }
0414
0415 this.client = ctx.createClientFactory(clientBootstraps)
0416 .createClient(TestUtils.getLocalHost(), server.getPort());
0417 } catch (Exception e) {
0418 close();
0419 throw e;
0420 }
0421
0422 this.encrypt = encrypt;
0423 this.disableClientEncryption = disableClientEncryption;
0424 }
0425
0426 void close() {
0427 if (!disableClientEncryption) {
0428 assertEquals(encrypt, checker.foundEncryptionHandler);
0429 }
0430 if (client != null) {
0431 client.close();
0432 }
0433 if (server != null) {
0434 server.close();
0435 }
0436 if (ctx != null) {
0437 ctx.close();
0438 }
0439 }
0440
0441 }
0442
0443 private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAdapter
0444 implements TransportServerBootstrap {
0445
0446 boolean foundEncryptionHandler;
0447 String encryptHandlerName;
0448
0449 EncryptionCheckerBootstrap(String encryptHandlerName) {
0450 this.encryptHandlerName = encryptHandlerName;
0451 }
0452
0453 @Override
0454 public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
0455 throws Exception {
0456 if (!foundEncryptionHandler) {
0457 foundEncryptionHandler =
0458 ctx.channel().pipeline().get(encryptHandlerName) != null;
0459 }
0460 ctx.write(msg, promise);
0461 }
0462
0463 @Override
0464 public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
0465 channel.pipeline().addFirst("encryptionChecker", this);
0466 return rpcHandler;
0467 }
0468
0469 }
0470
0471 private static class EncryptionDisablerBootstrap implements TransportClientBootstrap {
0472
0473 @Override
0474 public void doBootstrap(TransportClient client, Channel channel) {
0475 channel.pipeline().remove(SaslEncryption.ENCRYPTION_HANDLER_NAME);
0476 }
0477
0478 }
0479
0480 }