Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network.crypto;
0019 
0020 import java.nio.ByteBuffer;
0021 import java.util.Arrays;
0022 import java.util.List;
0023 import java.util.Map;
0024 
0025 import com.google.common.collect.ImmutableMap;
0026 import io.netty.channel.Channel;
0027 import org.junit.After;
0028 import org.junit.Test;
0029 import static org.junit.Assert.*;
0030 import static org.mockito.Mockito.*;
0031 
0032 import org.apache.spark.network.TestUtils;
0033 import org.apache.spark.network.TransportContext;
0034 import org.apache.spark.network.client.RpcResponseCallback;
0035 import org.apache.spark.network.client.TransportClient;
0036 import org.apache.spark.network.client.TransportClientBootstrap;
0037 import org.apache.spark.network.sasl.SaslServerBootstrap;
0038 import org.apache.spark.network.sasl.SecretKeyHolder;
0039 import org.apache.spark.network.server.RpcHandler;
0040 import org.apache.spark.network.server.StreamManager;
0041 import org.apache.spark.network.server.TransportServer;
0042 import org.apache.spark.network.server.TransportServerBootstrap;
0043 import org.apache.spark.network.util.JavaUtils;
0044 import org.apache.spark.network.util.MapConfigProvider;
0045 import org.apache.spark.network.util.TransportConf;
0046 
0047 public class AuthIntegrationSuite {
0048 
0049   private AuthTestCtx ctx;
0050 
0051   @After
0052   public void cleanUp() throws Exception {
0053     if (ctx != null) {
0054       ctx.close();
0055     }
0056     ctx = null;
0057   }
0058 
0059   @Test
0060   public void testNewAuth() throws Exception {
0061     ctx = new AuthTestCtx();
0062     ctx.createServer("secret");
0063     ctx.createClient("secret");
0064 
0065     ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
0066     assertEquals("Pong", JavaUtils.bytesToString(reply));
0067     assertNull(ctx.authRpcHandler.saslHandler);
0068   }
0069 
0070   @Test
0071   public void testAuthFailure() throws Exception {
0072     ctx = new AuthTestCtx();
0073     ctx.createServer("server");
0074 
0075     try {
0076       ctx.createClient("client");
0077       fail("Should have failed to create client.");
0078     } catch (Exception e) {
0079       assertFalse(ctx.authRpcHandler.isAuthenticated());
0080       assertFalse(ctx.serverChannel.isActive());
0081     }
0082   }
0083 
0084   @Test
0085   public void testSaslServerFallback() throws Exception {
0086     ctx = new AuthTestCtx();
0087     ctx.createServer("secret", true);
0088     ctx.createClient("secret", false);
0089 
0090     ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
0091     assertEquals("Pong", JavaUtils.bytesToString(reply));
0092     assertNotNull(ctx.authRpcHandler.saslHandler);
0093     assertTrue(ctx.authRpcHandler.isAuthenticated());
0094   }
0095 
0096   @Test
0097   public void testSaslClientFallback() throws Exception {
0098     ctx = new AuthTestCtx();
0099     ctx.createServer("secret", false);
0100     ctx.createClient("secret", true);
0101 
0102     ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
0103     assertEquals("Pong", JavaUtils.bytesToString(reply));
0104   }
0105 
0106   @Test
0107   public void testAuthReplay() throws Exception {
0108     // This test covers the case where an attacker replays a challenge message sniffed from the
0109     // network, but doesn't know the actual secret. The server should close the connection as
0110     // soon as a message is sent after authentication is performed. This is emulated by removing
0111     // the client encryption handler after authentication.
0112     ctx = new AuthTestCtx();
0113     ctx.createServer("secret");
0114     ctx.createClient("secret");
0115 
0116     assertNotNull(ctx.client.getChannel().pipeline()
0117       .remove(TransportCipher.ENCRYPTION_HANDLER_NAME));
0118 
0119     try {
0120       ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
0121       fail("Should have failed unencrypted RPC.");
0122     } catch (Exception e) {
0123       assertTrue(ctx.authRpcHandler.isAuthenticated());
0124     }
0125   }
0126 
0127   @Test
0128   public void testLargeMessageEncryption() throws Exception {
0129     // Use a big length to create a message that cannot be put into the encryption buffer completely
0130     final int testErrorMessageLength = TransportCipher.STREAM_BUFFER_SIZE;
0131     ctx = new AuthTestCtx(new RpcHandler() {
0132       @Override
0133       public void receive(
0134           TransportClient client,
0135           ByteBuffer message,
0136           RpcResponseCallback callback) {
0137         char[] longMessage = new char[testErrorMessageLength];
0138         Arrays.fill(longMessage, 'D');
0139         callback.onFailure(new RuntimeException(new String(longMessage)));
0140       }
0141 
0142       @Override
0143       public StreamManager getStreamManager() {
0144         return null;
0145       }
0146     });
0147     ctx.createServer("secret");
0148     ctx.createClient("secret");
0149 
0150     try {
0151       ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
0152       fail("Should have failed unencrypted RPC.");
0153     } catch (Exception e) {
0154       assertTrue(ctx.authRpcHandler.isAuthenticated());
0155       assertTrue(e.getMessage() + " is not an expected error", e.getMessage().contains("DDDDD"));
0156       // Verify we receive the complete error message
0157       int messageStart = e.getMessage().indexOf("DDDDD");
0158       int messageEnd = e.getMessage().lastIndexOf("DDDDD") + 5;
0159       assertEquals(testErrorMessageLength, messageEnd - messageStart);
0160     }
0161   }
0162 
0163   private class AuthTestCtx {
0164 
0165     private final String appId = "testAppId";
0166     private final TransportConf conf;
0167     private final TransportContext ctx;
0168 
0169     TransportClient client;
0170     TransportServer server;
0171     volatile Channel serverChannel;
0172     volatile AuthRpcHandler authRpcHandler;
0173 
0174     AuthTestCtx() throws Exception {
0175       this(new RpcHandler() {
0176         @Override
0177         public void receive(
0178             TransportClient client,
0179             ByteBuffer message,
0180             RpcResponseCallback callback) {
0181           assertEquals("Ping", JavaUtils.bytesToString(message));
0182           callback.onSuccess(JavaUtils.stringToBytes("Pong"));
0183         }
0184 
0185         @Override
0186         public StreamManager getStreamManager() {
0187           return null;
0188         }
0189       });
0190     }
0191 
0192     AuthTestCtx(RpcHandler rpcHandler) throws Exception {
0193       Map<String, String> testConf = ImmutableMap.of("spark.network.crypto.enabled", "true");
0194       this.conf = new TransportConf("rpc", new MapConfigProvider(testConf));
0195       this.ctx = new TransportContext(conf, rpcHandler);
0196     }
0197 
0198     void createServer(String secret) throws Exception {
0199       createServer(secret, true);
0200     }
0201 
0202     void createServer(String secret, boolean enableAes) throws Exception {
0203       TransportServerBootstrap introspector = (channel, rpcHandler) -> {
0204         this.serverChannel = channel;
0205         if (rpcHandler instanceof AuthRpcHandler) {
0206           this.authRpcHandler = (AuthRpcHandler) rpcHandler;
0207         }
0208         return rpcHandler;
0209       };
0210       SecretKeyHolder keyHolder = createKeyHolder(secret);
0211       TransportServerBootstrap auth = enableAes ? new AuthServerBootstrap(conf, keyHolder)
0212         : new SaslServerBootstrap(conf, keyHolder);
0213       this.server = ctx.createServer(Arrays.asList(auth, introspector));
0214     }
0215 
0216     void createClient(String secret) throws Exception {
0217       createClient(secret, true);
0218     }
0219 
0220     void createClient(String secret, boolean enableAes) throws Exception {
0221       TransportConf clientConf = enableAes ? conf
0222         : new TransportConf("rpc", MapConfigProvider.EMPTY);
0223       List<TransportClientBootstrap> bootstraps = Arrays.asList(
0224         new AuthClientBootstrap(clientConf, appId, createKeyHolder(secret)));
0225       this.client = ctx.createClientFactory(bootstraps)
0226         .createClient(TestUtils.getLocalHost(), server.getPort());
0227     }
0228 
0229     void close() {
0230       if (client != null) {
0231         client.close();
0232       }
0233       if (server != null) {
0234         server.close();
0235       }
0236       if (ctx != null) {
0237         ctx.close();
0238       }
0239     }
0240 
0241     private SecretKeyHolder createKeyHolder(String secret) {
0242       SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
0243       when(keyHolder.getSaslUser(anyString())).thenReturn(appId);
0244       when(keyHolder.getSecretKey(anyString())).thenReturn(secret);
0245       return keyHolder;
0246     }
0247 
0248   }
0249 
0250 }