Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network.sasl;
0019 
0020 import java.io.IOException;
0021 import java.nio.ByteBuffer;
0022 import java.util.ArrayList;
0023 import java.util.Arrays;
0024 
0025 import org.junit.After;
0026 import org.junit.AfterClass;
0027 import org.junit.BeforeClass;
0028 import org.junit.Test;
0029 
0030 import static org.junit.Assert.*;
0031 import static org.mockito.Mockito.*;
0032 
0033 import org.apache.spark.network.TestUtils;
0034 import org.apache.spark.network.TransportContext;
0035 import org.apache.spark.network.client.RpcResponseCallback;
0036 import org.apache.spark.network.client.TransportClient;
0037 import org.apache.spark.network.client.TransportClientFactory;
0038 import org.apache.spark.network.server.OneForOneStreamManager;
0039 import org.apache.spark.network.server.RpcHandler;
0040 import org.apache.spark.network.server.StreamManager;
0041 import org.apache.spark.network.server.TransportServer;
0042 import org.apache.spark.network.server.TransportServerBootstrap;
0043 import org.apache.spark.network.util.JavaUtils;
0044 import org.apache.spark.network.util.MapConfigProvider;
0045 import org.apache.spark.network.util.TransportConf;
0046 
0047 public class SaslIntegrationSuite {
0048 
0049   // Use a long timeout to account for slow / overloaded build machines. In the normal case,
0050   // tests should finish way before the timeout expires.
0051   private static final long TIMEOUT_MS = 10_000;
0052 
0053   static TransportServer server;
0054   static TransportConf conf;
0055   static TransportContext context;
0056   static SecretKeyHolder secretKeyHolder;
0057 
0058   TransportClientFactory clientFactory;
0059 
0060   @BeforeClass
0061   public static void beforeAll() throws IOException {
0062     conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
0063     context = new TransportContext(conf, new TestRpcHandler());
0064 
0065     secretKeyHolder = mock(SecretKeyHolder.class);
0066     when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
0067     when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
0068     when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
0069     when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
0070     when(secretKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
0071     when(secretKeyHolder.getSecretKey(anyString())).thenReturn("correct-password");
0072 
0073     TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
0074     server = context.createServer(Arrays.asList(bootstrap));
0075   }
0076 
0077 
0078   @AfterClass
0079   public static void afterAll() {
0080     server.close();
0081     context.close();
0082   }
0083 
0084   @After
0085   public void afterEach() {
0086     if (clientFactory != null) {
0087       clientFactory.close();
0088       clientFactory = null;
0089     }
0090   }
0091 
0092   @Test
0093   public void testGoodClient() throws IOException, InterruptedException {
0094     clientFactory = context.createClientFactory(
0095         Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
0096 
0097     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0098     String msg = "Hello, World!";
0099     ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS);
0100     assertEquals(msg, JavaUtils.bytesToString(resp));
0101   }
0102 
0103   @Test
0104   public void testBadClient() {
0105     SecretKeyHolder badKeyHolder = mock(SecretKeyHolder.class);
0106     when(badKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
0107     when(badKeyHolder.getSecretKey(anyString())).thenReturn("wrong-password");
0108     clientFactory = context.createClientFactory(
0109         Arrays.asList(new SaslClientBootstrap(conf, "unknown-app", badKeyHolder)));
0110 
0111     try {
0112       // Bootstrap should fail on startup.
0113       clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0114       fail("Connection should have failed.");
0115     } catch (Exception e) {
0116       assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
0117     }
0118   }
0119 
0120   @Test
0121   public void testNoSaslClient() throws IOException, InterruptedException {
0122     clientFactory = context.createClientFactory(new ArrayList<>());
0123 
0124     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0125     try {
0126       client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS);
0127       fail("Should have failed");
0128     } catch (Exception e) {
0129       assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage"));
0130     }
0131 
0132     try {
0133       // Guessing the right tag byte doesn't magically get you in...
0134       client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS);
0135       fail("Should have failed");
0136     } catch (Exception e) {
0137       assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException"));
0138     }
0139   }
0140 
0141   @Test
0142   public void testNoSaslServer() {
0143     RpcHandler handler = new TestRpcHandler();
0144     try (TransportContext context = new TransportContext(conf, handler)) {
0145       clientFactory = context.createClientFactory(
0146           Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
0147       try (TransportServer server = context.createServer()) {
0148         clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0149       } catch (Exception e) {
0150         assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation"));
0151       }
0152     }
0153   }
0154 
0155   /** RPC handler which simply responds with the message it received. */
0156   public static class TestRpcHandler extends RpcHandler {
0157     @Override
0158     public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
0159       callback.onSuccess(message);
0160     }
0161 
0162     @Override
0163     public StreamManager getStreamManager() {
0164       return new OneForOneStreamManager();
0165     }
0166   }
0167 }