0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.sasl;
0019
0020 import java.io.IOException;
0021 import java.nio.ByteBuffer;
0022 import java.util.ArrayList;
0023 import java.util.Arrays;
0024
0025 import org.junit.After;
0026 import org.junit.AfterClass;
0027 import org.junit.BeforeClass;
0028 import org.junit.Test;
0029
0030 import static org.junit.Assert.*;
0031 import static org.mockito.Mockito.*;
0032
0033 import org.apache.spark.network.TestUtils;
0034 import org.apache.spark.network.TransportContext;
0035 import org.apache.spark.network.client.RpcResponseCallback;
0036 import org.apache.spark.network.client.TransportClient;
0037 import org.apache.spark.network.client.TransportClientFactory;
0038 import org.apache.spark.network.server.OneForOneStreamManager;
0039 import org.apache.spark.network.server.RpcHandler;
0040 import org.apache.spark.network.server.StreamManager;
0041 import org.apache.spark.network.server.TransportServer;
0042 import org.apache.spark.network.server.TransportServerBootstrap;
0043 import org.apache.spark.network.util.JavaUtils;
0044 import org.apache.spark.network.util.MapConfigProvider;
0045 import org.apache.spark.network.util.TransportConf;
0046
0047 public class SaslIntegrationSuite {
0048
0049
0050
0051 private static final long TIMEOUT_MS = 10_000;
0052
0053 static TransportServer server;
0054 static TransportConf conf;
0055 static TransportContext context;
0056 static SecretKeyHolder secretKeyHolder;
0057
0058 TransportClientFactory clientFactory;
0059
0060 @BeforeClass
0061 public static void beforeAll() throws IOException {
0062 conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
0063 context = new TransportContext(conf, new TestRpcHandler());
0064
0065 secretKeyHolder = mock(SecretKeyHolder.class);
0066 when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
0067 when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
0068 when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
0069 when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
0070 when(secretKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
0071 when(secretKeyHolder.getSecretKey(anyString())).thenReturn("correct-password");
0072
0073 TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
0074 server = context.createServer(Arrays.asList(bootstrap));
0075 }
0076
0077
0078 @AfterClass
0079 public static void afterAll() {
0080 server.close();
0081 context.close();
0082 }
0083
0084 @After
0085 public void afterEach() {
0086 if (clientFactory != null) {
0087 clientFactory.close();
0088 clientFactory = null;
0089 }
0090 }
0091
0092 @Test
0093 public void testGoodClient() throws IOException, InterruptedException {
0094 clientFactory = context.createClientFactory(
0095 Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
0096
0097 TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0098 String msg = "Hello, World!";
0099 ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS);
0100 assertEquals(msg, JavaUtils.bytesToString(resp));
0101 }
0102
0103 @Test
0104 public void testBadClient() {
0105 SecretKeyHolder badKeyHolder = mock(SecretKeyHolder.class);
0106 when(badKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
0107 when(badKeyHolder.getSecretKey(anyString())).thenReturn("wrong-password");
0108 clientFactory = context.createClientFactory(
0109 Arrays.asList(new SaslClientBootstrap(conf, "unknown-app", badKeyHolder)));
0110
0111 try {
0112
0113 clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0114 fail("Connection should have failed.");
0115 } catch (Exception e) {
0116 assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
0117 }
0118 }
0119
0120 @Test
0121 public void testNoSaslClient() throws IOException, InterruptedException {
0122 clientFactory = context.createClientFactory(new ArrayList<>());
0123
0124 TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0125 try {
0126 client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS);
0127 fail("Should have failed");
0128 } catch (Exception e) {
0129 assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage"));
0130 }
0131
0132 try {
0133
0134 client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS);
0135 fail("Should have failed");
0136 } catch (Exception e) {
0137 assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException"));
0138 }
0139 }
0140
0141 @Test
0142 public void testNoSaslServer() {
0143 RpcHandler handler = new TestRpcHandler();
0144 try (TransportContext context = new TransportContext(conf, handler)) {
0145 clientFactory = context.createClientFactory(
0146 Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
0147 try (TransportServer server = context.createServer()) {
0148 clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0149 } catch (Exception e) {
0150 assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation"));
0151 }
0152 }
0153 }
0154
0155
0156 public static class TestRpcHandler extends RpcHandler {
0157 @Override
0158 public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
0159 callback.onSuccess(message);
0160 }
0161
0162 @Override
0163 public StreamManager getStreamManager() {
0164 return new OneForOneStreamManager();
0165 }
0166 }
0167 }