Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network.sasl;
0019 
0020 import static org.junit.Assert.*;
0021 import static org.mockito.Mockito.*;
0022 
0023 import java.io.File;
0024 import java.lang.reflect.Method;
0025 import java.nio.ByteBuffer;
0026 import java.util.ArrayList;
0027 import java.util.Arrays;
0028 import java.util.Collections;
0029 import java.util.List;
0030 import java.util.Map;
0031 import java.util.Random;
0032 import java.util.concurrent.CountDownLatch;
0033 import java.util.concurrent.TimeoutException;
0034 import java.util.concurrent.TimeUnit;
0035 import java.util.concurrent.atomic.AtomicReference;
0036 import javax.security.sasl.SaslException;
0037 
0038 import com.google.common.collect.ImmutableMap;
0039 import com.google.common.io.ByteStreams;
0040 import com.google.common.io.Files;
0041 import io.netty.buffer.ByteBuf;
0042 import io.netty.buffer.Unpooled;
0043 import io.netty.channel.Channel;
0044 import io.netty.channel.ChannelHandlerContext;
0045 import io.netty.channel.ChannelOutboundHandlerAdapter;
0046 import io.netty.channel.ChannelPromise;
0047 import org.junit.Test;
0048 
0049 import org.apache.spark.network.TestUtils;
0050 import org.apache.spark.network.TransportContext;
0051 import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
0052 import org.apache.spark.network.buffer.ManagedBuffer;
0053 import org.apache.spark.network.client.ChunkReceivedCallback;
0054 import org.apache.spark.network.client.RpcResponseCallback;
0055 import org.apache.spark.network.client.TransportClient;
0056 import org.apache.spark.network.client.TransportClientBootstrap;
0057 import org.apache.spark.network.server.RpcHandler;
0058 import org.apache.spark.network.server.StreamManager;
0059 import org.apache.spark.network.server.TransportServer;
0060 import org.apache.spark.network.server.TransportServerBootstrap;
0061 import org.apache.spark.network.util.ByteArrayWritableChannel;
0062 import org.apache.spark.network.util.JavaUtils;
0063 import org.apache.spark.network.util.MapConfigProvider;
0064 import org.apache.spark.network.util.TransportConf;
0065 
0066 /**
0067  * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes.
0068  */
0069 public class SparkSaslSuite {
0070 
0071   /** Provides a secret key holder which returns secret key == appId */
0072   private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() {
0073     @Override
0074     public String getSaslUser(String appId) {
0075       return "user";
0076     }
0077 
0078     @Override
0079     public String getSecretKey(String appId) {
0080       return appId;
0081     }
0082   };
0083 
0084   @Test
0085   public void testMatching() {
0086     SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder, false);
0087     SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder, false);
0088 
0089     assertFalse(client.isComplete());
0090     assertFalse(server.isComplete());
0091 
0092     byte[] clientMessage = client.firstToken();
0093 
0094     while (!client.isComplete()) {
0095       clientMessage = client.response(server.response(clientMessage));
0096     }
0097     assertTrue(server.isComplete());
0098 
0099     // Disposal should invalidate
0100     server.dispose();
0101     assertFalse(server.isComplete());
0102     client.dispose();
0103     assertFalse(client.isComplete());
0104   }
0105 
0106   @Test
0107   public void testNonMatching() {
0108     SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder, false);
0109     SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder, false);
0110 
0111     assertFalse(client.isComplete());
0112     assertFalse(server.isComplete());
0113 
0114     byte[] clientMessage = client.firstToken();
0115 
0116     try {
0117       while (!client.isComplete()) {
0118         clientMessage = client.response(server.response(clientMessage));
0119       }
0120       fail("Should not have completed");
0121     } catch (Exception e) {
0122       assertTrue(e.getMessage().contains("Mismatched response"));
0123       assertFalse(client.isComplete());
0124       assertFalse(server.isComplete());
0125     }
0126   }
0127 
0128   @Test
0129   public void testSaslAuthentication() throws Throwable {
0130     testBasicSasl(false);
0131   }
0132 
0133   @Test
0134   public void testSaslEncryption() throws Throwable {
0135     testBasicSasl(true);
0136   }
0137 
0138   private static void testBasicSasl(boolean encrypt) throws Throwable {
0139     RpcHandler rpcHandler = mock(RpcHandler.class);
0140     doAnswer(invocation -> {
0141       ByteBuffer message = (ByteBuffer) invocation.getArguments()[1];
0142       RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2];
0143       assertEquals("Ping", JavaUtils.bytesToString(message));
0144       cb.onSuccess(JavaUtils.stringToBytes("Pong"));
0145       return null;
0146     })
0147       .when(rpcHandler)
0148       .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class));
0149 
0150     SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
0151     try {
0152       ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
0153         TimeUnit.SECONDS.toMillis(10));
0154       assertEquals("Pong", JavaUtils.bytesToString(response));
0155     } finally {
0156       ctx.close();
0157       // There should be 2 terminated events; one for the client, one for the server.
0158       Throwable error = null;
0159       long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
0160       while (deadline > System.nanoTime()) {
0161         try {
0162           verify(rpcHandler, times(2)).channelInactive(any(TransportClient.class));
0163           error = null;
0164           break;
0165         } catch (Throwable t) {
0166           error = t;
0167           TimeUnit.MILLISECONDS.sleep(10);
0168         }
0169       }
0170       if (error != null) {
0171         throw error;
0172       }
0173     }
0174   }
0175 
0176   @Test
0177   public void testEncryptedMessage() throws Exception {
0178     SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
0179     byte[] data = new byte[1024];
0180     new Random().nextBytes(data);
0181     when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
0182 
0183     ByteBuf msg = Unpooled.buffer();
0184     try {
0185       msg.writeBytes(data);
0186 
0187       // Create a channel with a really small buffer compared to the data. This means that on each
0188       // call, the outbound data will not be fully written, so the write() method should return a
0189       // dummy count to keep the channel alive when possible.
0190       ByteArrayWritableChannel channel = new ByteArrayWritableChannel(32);
0191 
0192       SaslEncryption.EncryptedMessage emsg =
0193         new SaslEncryption.EncryptedMessage(backend, msg, 1024);
0194       long count = emsg.transferTo(channel, emsg.transfered());
0195       assertTrue(count < data.length);
0196       assertTrue(count > 0);
0197 
0198       // Here, the output buffer is full so nothing should be transferred.
0199       assertEquals(0, emsg.transferTo(channel, emsg.transfered()));
0200 
0201       // Now there's room in the buffer, but not enough to transfer all the remaining data,
0202       // so the dummy count should be returned.
0203       channel.reset();
0204       assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
0205 
0206       // Eventually, the whole message should be transferred.
0207       for (int i = 0; i < data.length / 32 - 2; i++) {
0208         channel.reset();
0209         assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
0210       }
0211 
0212       channel.reset();
0213       count = emsg.transferTo(channel, emsg.transfered());
0214       assertTrue("Unexpected count: " + count, count > 1 && count < data.length);
0215       assertEquals(data.length, emsg.transfered());
0216     } finally {
0217       msg.release();
0218     }
0219   }
0220 
0221   @Test
0222   public void testEncryptedMessageChunking() throws Exception {
0223     File file = File.createTempFile("sasltest", ".txt");
0224     try {
0225       TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
0226 
0227       byte[] data = new byte[8 * 1024];
0228       new Random().nextBytes(data);
0229       Files.write(data, file);
0230 
0231       SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
0232       // It doesn't really matter what we return here, as long as it's not null.
0233       when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
0234 
0235       FileSegmentManagedBuffer msg = new FileSegmentManagedBuffer(conf, file, 0, file.length());
0236       SaslEncryption.EncryptedMessage emsg =
0237         new SaslEncryption.EncryptedMessage(backend, msg.convertToNetty(), data.length / 8);
0238 
0239       ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length);
0240       while (emsg.transfered() < emsg.count()) {
0241         channel.reset();
0242         emsg.transferTo(channel, emsg.transfered());
0243       }
0244 
0245       verify(backend, times(8)).wrap(any(byte[].class), anyInt(), anyInt());
0246     } finally {
0247       file.delete();
0248     }
0249   }
0250 
0251   @Test
0252   public void testFileRegionEncryption() throws Exception {
0253     Map<String, String> testConf = ImmutableMap.of(
0254       "spark.network.sasl.maxEncryptedBlockSize", "1k");
0255 
0256     AtomicReference<ManagedBuffer> response = new AtomicReference<>();
0257     File file = File.createTempFile("sasltest", ".txt");
0258     SaslTestCtx ctx = null;
0259     try {
0260       TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf));
0261       StreamManager sm = mock(StreamManager.class);
0262       when(sm.getChunk(anyLong(), anyInt())).thenAnswer(invocation ->
0263           new FileSegmentManagedBuffer(conf, file, 0, file.length()));
0264 
0265       RpcHandler rpcHandler = mock(RpcHandler.class);
0266       when(rpcHandler.getStreamManager()).thenReturn(sm);
0267 
0268       byte[] data = new byte[8 * 1024];
0269       new Random().nextBytes(data);
0270       Files.write(data, file);
0271 
0272       ctx = new SaslTestCtx(rpcHandler, true, false, testConf);
0273 
0274       CountDownLatch lock = new CountDownLatch(1);
0275 
0276       ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
0277       doAnswer(invocation -> {
0278         response.set((ManagedBuffer) invocation.getArguments()[1]);
0279         response.get().retain();
0280         lock.countDown();
0281         return null;
0282       }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
0283 
0284       ctx.client.fetchChunk(0, 0, callback);
0285       lock.await(10, TimeUnit.SECONDS);
0286 
0287       verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
0288       verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
0289 
0290       byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
0291       assertTrue(Arrays.equals(data, received));
0292     } finally {
0293       file.delete();
0294       if (ctx != null) {
0295         ctx.close();
0296       }
0297       if (response.get() != null) {
0298         response.get().release();
0299       }
0300     }
0301   }
0302 
0303   @Test
0304   public void testServerAlwaysEncrypt() throws Exception {
0305     SaslTestCtx ctx = null;
0306     try {
0307       ctx = new SaslTestCtx(mock(RpcHandler.class), false, false,
0308         ImmutableMap.of("spark.network.sasl.serverAlwaysEncrypt", "true"));
0309       fail("Should have failed to connect without encryption.");
0310     } catch (Exception e) {
0311       assertTrue(e.getCause() instanceof SaslException);
0312     } finally {
0313       if (ctx != null) {
0314         ctx.close();
0315       }
0316     }
0317   }
0318 
0319   @Test
0320   public void testDataEncryptionIsActuallyEnabled() throws Exception {
0321     // This test sets up an encrypted connection but then, using a client bootstrap, removes
0322     // the encryption handler from the client side. This should cause the server to not be
0323     // able to understand RPCs sent to it and thus close the connection.
0324     SaslTestCtx ctx = null;
0325     try {
0326       ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
0327       ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
0328         TimeUnit.SECONDS.toMillis(10));
0329       fail("Should have failed to send RPC to server.");
0330     } catch (Exception e) {
0331       assertFalse(e.getCause() instanceof TimeoutException);
0332     } finally {
0333       if (ctx != null) {
0334         ctx.close();
0335       }
0336     }
0337   }
0338 
0339   @Test
0340   public void testRpcHandlerDelegate() throws Exception {
0341     // Tests all delegates exception for receive(), which is more complicated and already handled
0342     // by all other tests.
0343     RpcHandler handler = mock(RpcHandler.class);
0344     RpcHandler saslHandler = new SaslRpcHandler(null, null, handler, null);
0345 
0346     saslHandler.getStreamManager();
0347     verify(handler).getStreamManager();
0348 
0349     saslHandler.channelInactive(null);
0350     verify(handler).channelInactive(isNull());
0351 
0352     saslHandler.exceptionCaught(null, null);
0353     verify(handler).exceptionCaught(isNull(), isNull());
0354   }
0355 
0356   @Test
0357   public void testDelegates() throws Exception {
0358     Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods();
0359     for (Method m : rpcHandlerMethods) {
0360       Method delegate = SaslRpcHandler.class.getMethod(m.getName(), m.getParameterTypes());
0361       assertNotEquals(delegate.getDeclaringClass(), RpcHandler.class);
0362     }
0363   }
0364 
0365   private static class SaslTestCtx {
0366 
0367     final TransportClient client;
0368     final TransportServer server;
0369     final TransportContext ctx;
0370 
0371     private final boolean encrypt;
0372     private final boolean disableClientEncryption;
0373     private final EncryptionCheckerBootstrap checker;
0374 
0375     SaslTestCtx(
0376         RpcHandler rpcHandler,
0377         boolean encrypt,
0378         boolean disableClientEncryption)
0379       throws Exception {
0380 
0381       this(rpcHandler, encrypt, disableClientEncryption, Collections.emptyMap());
0382     }
0383 
0384     SaslTestCtx(
0385         RpcHandler rpcHandler,
0386         boolean encrypt,
0387         boolean disableClientEncryption,
0388         Map<String, String> extraConf)
0389       throws Exception {
0390 
0391       Map<String, String> testConf = ImmutableMap.<String, String>builder()
0392         .putAll(extraConf)
0393         .put("spark.authenticate.enableSaslEncryption", String.valueOf(encrypt))
0394         .build();
0395       TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf));
0396 
0397       SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
0398       when(keyHolder.getSaslUser(anyString())).thenReturn("user");
0399       when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
0400 
0401       this.ctx = new TransportContext(conf, rpcHandler);
0402 
0403       this.checker = new EncryptionCheckerBootstrap(SaslEncryption.ENCRYPTION_HANDLER_NAME);
0404 
0405       this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder),
0406         checker));
0407 
0408       try {
0409         List<TransportClientBootstrap> clientBootstraps = new ArrayList<>();
0410         clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder));
0411         if (disableClientEncryption) {
0412           clientBootstraps.add(new EncryptionDisablerBootstrap());
0413         }
0414 
0415         this.client = ctx.createClientFactory(clientBootstraps)
0416           .createClient(TestUtils.getLocalHost(), server.getPort());
0417       } catch (Exception e) {
0418         close();
0419         throw e;
0420       }
0421 
0422       this.encrypt = encrypt;
0423       this.disableClientEncryption = disableClientEncryption;
0424     }
0425 
0426     void close() {
0427       if (!disableClientEncryption) {
0428         assertEquals(encrypt, checker.foundEncryptionHandler);
0429       }
0430       if (client != null) {
0431         client.close();
0432       }
0433       if (server != null) {
0434         server.close();
0435       }
0436       if (ctx != null) {
0437         ctx.close();
0438       }
0439     }
0440 
0441   }
0442 
0443   private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAdapter
0444     implements TransportServerBootstrap {
0445 
0446     boolean foundEncryptionHandler;
0447     String encryptHandlerName;
0448 
0449     EncryptionCheckerBootstrap(String encryptHandlerName) {
0450       this.encryptHandlerName = encryptHandlerName;
0451     }
0452 
0453     @Override
0454     public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
0455       throws Exception {
0456       if (!foundEncryptionHandler) {
0457         foundEncryptionHandler =
0458           ctx.channel().pipeline().get(encryptHandlerName) != null;
0459       }
0460       ctx.write(msg, promise);
0461     }
0462 
0463     @Override
0464     public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
0465       channel.pipeline().addFirst("encryptionChecker", this);
0466       return rpcHandler;
0467     }
0468 
0469   }
0470 
0471   private static class EncryptionDisablerBootstrap implements TransportClientBootstrap {
0472 
0473     @Override
0474     public void doBootstrap(TransportClient client, Channel channel) {
0475       channel.pipeline().remove(SaslEncryption.ENCRYPTION_HANDLER_NAME);
0476     }
0477 
0478   }
0479 
0480 }