Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network.crypto;
0019 
0020 import java.nio.ByteBuffer;
0021 import java.nio.channels.WritableByteChannel;
0022 import java.util.Arrays;
0023 import java.util.Map;
0024 import java.security.InvalidKeyException;
0025 import java.util.Random;
0026 
0027 import static java.nio.charset.StandardCharsets.UTF_8;
0028 
0029 import com.google.common.collect.ImmutableMap;
0030 import io.netty.buffer.ByteBuf;
0031 import io.netty.buffer.Unpooled;
0032 import io.netty.channel.FileRegion;
0033 import org.junit.BeforeClass;
0034 import org.junit.Test;
0035 import org.mockito.invocation.InvocationOnMock;
0036 import org.mockito.stubbing.Answer;
0037 import static org.junit.Assert.*;
0038 import static org.mockito.Mockito.*;
0039 
0040 import org.apache.spark.network.util.ByteArrayWritableChannel;
0041 import org.apache.spark.network.util.MapConfigProvider;
0042 import org.apache.spark.network.util.TransportConf;
0043 
0044 public class AuthEngineSuite {
0045 
0046   private static TransportConf conf;
0047 
0048   @BeforeClass
0049   public static void setUp() {
0050     conf = new TransportConf("rpc", MapConfigProvider.EMPTY);
0051   }
0052 
0053   @Test
0054   public void testAuthEngine() throws Exception {
0055     AuthEngine client = new AuthEngine("appId", "secret", conf);
0056     AuthEngine server = new AuthEngine("appId", "secret", conf);
0057 
0058     try {
0059       ClientChallenge clientChallenge = client.challenge();
0060       ServerResponse serverResponse = server.respond(clientChallenge);
0061       client.validate(serverResponse);
0062 
0063       TransportCipher serverCipher = server.sessionCipher();
0064       TransportCipher clientCipher = client.sessionCipher();
0065 
0066       assertTrue(Arrays.equals(serverCipher.getInputIv(), clientCipher.getOutputIv()));
0067       assertTrue(Arrays.equals(serverCipher.getOutputIv(), clientCipher.getInputIv()));
0068       assertEquals(serverCipher.getKey(), clientCipher.getKey());
0069     } finally {
0070       client.close();
0071       server.close();
0072     }
0073   }
0074 
0075   @Test
0076   public void testMismatchedSecret() throws Exception {
0077     AuthEngine client = new AuthEngine("appId", "secret", conf);
0078     AuthEngine server = new AuthEngine("appId", "different_secret", conf);
0079 
0080     ClientChallenge clientChallenge = client.challenge();
0081     try {
0082       server.respond(clientChallenge);
0083       fail("Should have failed to validate response.");
0084     } catch (IllegalArgumentException e) {
0085       // Expected.
0086     }
0087   }
0088 
0089   @Test(expected = IllegalArgumentException.class)
0090   public void testWrongAppId() throws Exception {
0091     AuthEngine engine = new AuthEngine("appId", "secret", conf);
0092     ClientChallenge challenge = engine.challenge();
0093 
0094     byte[] badChallenge = engine.challenge(new byte[] { 0x00 }, challenge.nonce,
0095       engine.rawResponse(engine.challenge));
0096     engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
0097       challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
0098   }
0099 
0100   @Test(expected = IllegalArgumentException.class)
0101   public void testWrongNonce() throws Exception {
0102     AuthEngine engine = new AuthEngine("appId", "secret", conf);
0103     ClientChallenge challenge = engine.challenge();
0104 
0105     byte[] badChallenge = engine.challenge(challenge.appId.getBytes(UTF_8), new byte[] { 0x00 },
0106       engine.rawResponse(engine.challenge));
0107     engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
0108       challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
0109   }
0110 
0111   @Test(expected = IllegalArgumentException.class)
0112   public void testBadChallenge() throws Exception {
0113     AuthEngine engine = new AuthEngine("appId", "secret", conf);
0114     ClientChallenge challenge = engine.challenge();
0115 
0116     byte[] badChallenge = new byte[challenge.challenge.length];
0117     engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
0118       challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
0119   }
0120 
0121   @Test(expected = InvalidKeyException.class)
0122   public void testBadKeySize() throws Exception {
0123     Map<String, String> mconf = ImmutableMap.of("spark.network.crypto.keyLength", "42");
0124     TransportConf conf = new TransportConf("rpc", new MapConfigProvider(mconf));
0125 
0126     try (AuthEngine engine = new AuthEngine("appId", "secret", conf)) {
0127       engine.challenge();
0128       fail("Should have failed to create challenge message.");
0129 
0130       // Call close explicitly to make sure it's idempotent.
0131       engine.close();
0132     }
0133   }
0134 
0135   @Test
0136   public void testEncryptedMessage() throws Exception {
0137     AuthEngine client = new AuthEngine("appId", "secret", conf);
0138     AuthEngine server = new AuthEngine("appId", "secret", conf);
0139     try {
0140       ClientChallenge clientChallenge = client.challenge();
0141       ServerResponse serverResponse = server.respond(clientChallenge);
0142       client.validate(serverResponse);
0143 
0144       TransportCipher cipher = server.sessionCipher();
0145       TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher);
0146 
0147       byte[] data = new byte[TransportCipher.STREAM_BUFFER_SIZE + 1];
0148       new Random().nextBytes(data);
0149       ByteBuf buf = Unpooled.wrappedBuffer(data);
0150 
0151       ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length);
0152       TransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(buf);
0153       while (emsg.transfered() < emsg.count()) {
0154         emsg.transferTo(channel, emsg.transfered());
0155       }
0156       assertEquals(data.length, channel.length());
0157     } finally {
0158       client.close();
0159       server.close();
0160     }
0161   }
0162 
0163   @Test
0164   public void testEncryptedMessageWhenTransferringZeroBytes() throws Exception {
0165     AuthEngine client = new AuthEngine("appId", "secret", conf);
0166     AuthEngine server = new AuthEngine("appId", "secret", conf);
0167     try {
0168       ClientChallenge clientChallenge = client.challenge();
0169       ServerResponse serverResponse = server.respond(clientChallenge);
0170       client.validate(serverResponse);
0171 
0172       TransportCipher cipher = server.sessionCipher();
0173       TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher);
0174 
0175       int testDataLength = 4;
0176       FileRegion region = mock(FileRegion.class);
0177       when(region.count()).thenReturn((long) testDataLength);
0178       // Make `region.transferTo` do nothing in first call and transfer 4 bytes in the second one.
0179       when(region.transferTo(any(), anyLong())).thenAnswer(new Answer<Long>() {
0180 
0181         private boolean firstTime = true;
0182 
0183         @Override
0184         public Long answer(InvocationOnMock invocationOnMock) throws Throwable {
0185           if (firstTime) {
0186             firstTime = false;
0187             return 0L;
0188           } else {
0189             WritableByteChannel channel = invocationOnMock.getArgument(0);
0190             channel.write(ByteBuffer.wrap(new byte[testDataLength]));
0191             return (long) testDataLength;
0192           }
0193         }
0194       });
0195 
0196       TransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(region);
0197       ByteArrayWritableChannel channel = new ByteArrayWritableChannel(testDataLength);
0198       // "transferTo" should act correctly when the underlying FileRegion transfers 0 bytes.
0199       assertEquals(0L, emsg.transferTo(channel, emsg.transfered()));
0200       assertEquals(testDataLength, emsg.transferTo(channel, emsg.transfered()));
0201       assertEquals(emsg.transfered(), emsg.count());
0202       assertEquals(4, channel.length());
0203     } finally {
0204       client.close();
0205       server.close();
0206     }
0207   }
0208 }