0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.crypto;
0019
0020 import java.nio.ByteBuffer;
0021 import java.util.Arrays;
0022 import java.util.List;
0023 import java.util.Map;
0024
0025 import com.google.common.collect.ImmutableMap;
0026 import io.netty.channel.Channel;
0027 import org.junit.After;
0028 import org.junit.Test;
0029 import static org.junit.Assert.*;
0030 import static org.mockito.Mockito.*;
0031
0032 import org.apache.spark.network.TestUtils;
0033 import org.apache.spark.network.TransportContext;
0034 import org.apache.spark.network.client.RpcResponseCallback;
0035 import org.apache.spark.network.client.TransportClient;
0036 import org.apache.spark.network.client.TransportClientBootstrap;
0037 import org.apache.spark.network.sasl.SaslServerBootstrap;
0038 import org.apache.spark.network.sasl.SecretKeyHolder;
0039 import org.apache.spark.network.server.RpcHandler;
0040 import org.apache.spark.network.server.StreamManager;
0041 import org.apache.spark.network.server.TransportServer;
0042 import org.apache.spark.network.server.TransportServerBootstrap;
0043 import org.apache.spark.network.util.JavaUtils;
0044 import org.apache.spark.network.util.MapConfigProvider;
0045 import org.apache.spark.network.util.TransportConf;
0046
0047 public class AuthIntegrationSuite {
0048
0049 private AuthTestCtx ctx;
0050
0051 @After
0052 public void cleanUp() throws Exception {
0053 if (ctx != null) {
0054 ctx.close();
0055 }
0056 ctx = null;
0057 }
0058
0059 @Test
0060 public void testNewAuth() throws Exception {
0061 ctx = new AuthTestCtx();
0062 ctx.createServer("secret");
0063 ctx.createClient("secret");
0064
0065 ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
0066 assertEquals("Pong", JavaUtils.bytesToString(reply));
0067 assertNull(ctx.authRpcHandler.saslHandler);
0068 }
0069
0070 @Test
0071 public void testAuthFailure() throws Exception {
0072 ctx = new AuthTestCtx();
0073 ctx.createServer("server");
0074
0075 try {
0076 ctx.createClient("client");
0077 fail("Should have failed to create client.");
0078 } catch (Exception e) {
0079 assertFalse(ctx.authRpcHandler.isAuthenticated());
0080 assertFalse(ctx.serverChannel.isActive());
0081 }
0082 }
0083
0084 @Test
0085 public void testSaslServerFallback() throws Exception {
0086 ctx = new AuthTestCtx();
0087 ctx.createServer("secret", true);
0088 ctx.createClient("secret", false);
0089
0090 ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
0091 assertEquals("Pong", JavaUtils.bytesToString(reply));
0092 assertNotNull(ctx.authRpcHandler.saslHandler);
0093 assertTrue(ctx.authRpcHandler.isAuthenticated());
0094 }
0095
0096 @Test
0097 public void testSaslClientFallback() throws Exception {
0098 ctx = new AuthTestCtx();
0099 ctx.createServer("secret", false);
0100 ctx.createClient("secret", true);
0101
0102 ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
0103 assertEquals("Pong", JavaUtils.bytesToString(reply));
0104 }
0105
0106 @Test
0107 public void testAuthReplay() throws Exception {
0108
0109
0110
0111
0112 ctx = new AuthTestCtx();
0113 ctx.createServer("secret");
0114 ctx.createClient("secret");
0115
0116 assertNotNull(ctx.client.getChannel().pipeline()
0117 .remove(TransportCipher.ENCRYPTION_HANDLER_NAME));
0118
0119 try {
0120 ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
0121 fail("Should have failed unencrypted RPC.");
0122 } catch (Exception e) {
0123 assertTrue(ctx.authRpcHandler.isAuthenticated());
0124 }
0125 }
0126
0127 @Test
0128 public void testLargeMessageEncryption() throws Exception {
0129
0130 final int testErrorMessageLength = TransportCipher.STREAM_BUFFER_SIZE;
0131 ctx = new AuthTestCtx(new RpcHandler() {
0132 @Override
0133 public void receive(
0134 TransportClient client,
0135 ByteBuffer message,
0136 RpcResponseCallback callback) {
0137 char[] longMessage = new char[testErrorMessageLength];
0138 Arrays.fill(longMessage, 'D');
0139 callback.onFailure(new RuntimeException(new String(longMessage)));
0140 }
0141
0142 @Override
0143 public StreamManager getStreamManager() {
0144 return null;
0145 }
0146 });
0147 ctx.createServer("secret");
0148 ctx.createClient("secret");
0149
0150 try {
0151 ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
0152 fail("Should have failed unencrypted RPC.");
0153 } catch (Exception e) {
0154 assertTrue(ctx.authRpcHandler.isAuthenticated());
0155 assertTrue(e.getMessage() + " is not an expected error", e.getMessage().contains("DDDDD"));
0156
0157 int messageStart = e.getMessage().indexOf("DDDDD");
0158 int messageEnd = e.getMessage().lastIndexOf("DDDDD") + 5;
0159 assertEquals(testErrorMessageLength, messageEnd - messageStart);
0160 }
0161 }
0162
0163 private class AuthTestCtx {
0164
0165 private final String appId = "testAppId";
0166 private final TransportConf conf;
0167 private final TransportContext ctx;
0168
0169 TransportClient client;
0170 TransportServer server;
0171 volatile Channel serverChannel;
0172 volatile AuthRpcHandler authRpcHandler;
0173
0174 AuthTestCtx() throws Exception {
0175 this(new RpcHandler() {
0176 @Override
0177 public void receive(
0178 TransportClient client,
0179 ByteBuffer message,
0180 RpcResponseCallback callback) {
0181 assertEquals("Ping", JavaUtils.bytesToString(message));
0182 callback.onSuccess(JavaUtils.stringToBytes("Pong"));
0183 }
0184
0185 @Override
0186 public StreamManager getStreamManager() {
0187 return null;
0188 }
0189 });
0190 }
0191
0192 AuthTestCtx(RpcHandler rpcHandler) throws Exception {
0193 Map<String, String> testConf = ImmutableMap.of("spark.network.crypto.enabled", "true");
0194 this.conf = new TransportConf("rpc", new MapConfigProvider(testConf));
0195 this.ctx = new TransportContext(conf, rpcHandler);
0196 }
0197
0198 void createServer(String secret) throws Exception {
0199 createServer(secret, true);
0200 }
0201
0202 void createServer(String secret, boolean enableAes) throws Exception {
0203 TransportServerBootstrap introspector = (channel, rpcHandler) -> {
0204 this.serverChannel = channel;
0205 if (rpcHandler instanceof AuthRpcHandler) {
0206 this.authRpcHandler = (AuthRpcHandler) rpcHandler;
0207 }
0208 return rpcHandler;
0209 };
0210 SecretKeyHolder keyHolder = createKeyHolder(secret);
0211 TransportServerBootstrap auth = enableAes ? new AuthServerBootstrap(conf, keyHolder)
0212 : new SaslServerBootstrap(conf, keyHolder);
0213 this.server = ctx.createServer(Arrays.asList(auth, introspector));
0214 }
0215
0216 void createClient(String secret) throws Exception {
0217 createClient(secret, true);
0218 }
0219
0220 void createClient(String secret, boolean enableAes) throws Exception {
0221 TransportConf clientConf = enableAes ? conf
0222 : new TransportConf("rpc", MapConfigProvider.EMPTY);
0223 List<TransportClientBootstrap> bootstraps = Arrays.asList(
0224 new AuthClientBootstrap(clientConf, appId, createKeyHolder(secret)));
0225 this.client = ctx.createClientFactory(bootstraps)
0226 .createClient(TestUtils.getLocalHost(), server.getPort());
0227 }
0228
0229 void close() {
0230 if (client != null) {
0231 client.close();
0232 }
0233 if (server != null) {
0234 server.close();
0235 }
0236 if (ctx != null) {
0237 ctx.close();
0238 }
0239 }
0240
0241 private SecretKeyHolder createKeyHolder(String secret) {
0242 SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
0243 when(keyHolder.getSaslUser(anyString())).thenReturn(appId);
0244 when(keyHolder.getSecretKey(anyString())).thenReturn(secret);
0245 return keyHolder;
0246 }
0247
0248 }
0249
0250 }