0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.crypto;
0019
0020 import java.nio.ByteBuffer;
0021 import java.nio.channels.WritableByteChannel;
0022 import java.util.Arrays;
0023 import java.util.Map;
0024 import java.security.InvalidKeyException;
0025 import java.util.Random;
0026
0027 import static java.nio.charset.StandardCharsets.UTF_8;
0028
0029 import com.google.common.collect.ImmutableMap;
0030 import io.netty.buffer.ByteBuf;
0031 import io.netty.buffer.Unpooled;
0032 import io.netty.channel.FileRegion;
0033 import org.junit.BeforeClass;
0034 import org.junit.Test;
0035 import org.mockito.invocation.InvocationOnMock;
0036 import org.mockito.stubbing.Answer;
0037 import static org.junit.Assert.*;
0038 import static org.mockito.Mockito.*;
0039
0040 import org.apache.spark.network.util.ByteArrayWritableChannel;
0041 import org.apache.spark.network.util.MapConfigProvider;
0042 import org.apache.spark.network.util.TransportConf;
0043
0044 public class AuthEngineSuite {
0045
0046 private static TransportConf conf;
0047
0048 @BeforeClass
0049 public static void setUp() {
0050 conf = new TransportConf("rpc", MapConfigProvider.EMPTY);
0051 }
0052
0053 @Test
0054 public void testAuthEngine() throws Exception {
0055 AuthEngine client = new AuthEngine("appId", "secret", conf);
0056 AuthEngine server = new AuthEngine("appId", "secret", conf);
0057
0058 try {
0059 ClientChallenge clientChallenge = client.challenge();
0060 ServerResponse serverResponse = server.respond(clientChallenge);
0061 client.validate(serverResponse);
0062
0063 TransportCipher serverCipher = server.sessionCipher();
0064 TransportCipher clientCipher = client.sessionCipher();
0065
0066 assertTrue(Arrays.equals(serverCipher.getInputIv(), clientCipher.getOutputIv()));
0067 assertTrue(Arrays.equals(serverCipher.getOutputIv(), clientCipher.getInputIv()));
0068 assertEquals(serverCipher.getKey(), clientCipher.getKey());
0069 } finally {
0070 client.close();
0071 server.close();
0072 }
0073 }
0074
0075 @Test
0076 public void testMismatchedSecret() throws Exception {
0077 AuthEngine client = new AuthEngine("appId", "secret", conf);
0078 AuthEngine server = new AuthEngine("appId", "different_secret", conf);
0079
0080 ClientChallenge clientChallenge = client.challenge();
0081 try {
0082 server.respond(clientChallenge);
0083 fail("Should have failed to validate response.");
0084 } catch (IllegalArgumentException e) {
0085
0086 }
0087 }
0088
0089 @Test(expected = IllegalArgumentException.class)
0090 public void testWrongAppId() throws Exception {
0091 AuthEngine engine = new AuthEngine("appId", "secret", conf);
0092 ClientChallenge challenge = engine.challenge();
0093
0094 byte[] badChallenge = engine.challenge(new byte[] { 0x00 }, challenge.nonce,
0095 engine.rawResponse(engine.challenge));
0096 engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
0097 challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
0098 }
0099
0100 @Test(expected = IllegalArgumentException.class)
0101 public void testWrongNonce() throws Exception {
0102 AuthEngine engine = new AuthEngine("appId", "secret", conf);
0103 ClientChallenge challenge = engine.challenge();
0104
0105 byte[] badChallenge = engine.challenge(challenge.appId.getBytes(UTF_8), new byte[] { 0x00 },
0106 engine.rawResponse(engine.challenge));
0107 engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
0108 challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
0109 }
0110
0111 @Test(expected = IllegalArgumentException.class)
0112 public void testBadChallenge() throws Exception {
0113 AuthEngine engine = new AuthEngine("appId", "secret", conf);
0114 ClientChallenge challenge = engine.challenge();
0115
0116 byte[] badChallenge = new byte[challenge.challenge.length];
0117 engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations,
0118 challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
0119 }
0120
0121 @Test(expected = InvalidKeyException.class)
0122 public void testBadKeySize() throws Exception {
0123 Map<String, String> mconf = ImmutableMap.of("spark.network.crypto.keyLength", "42");
0124 TransportConf conf = new TransportConf("rpc", new MapConfigProvider(mconf));
0125
0126 try (AuthEngine engine = new AuthEngine("appId", "secret", conf)) {
0127 engine.challenge();
0128 fail("Should have failed to create challenge message.");
0129
0130
0131 engine.close();
0132 }
0133 }
0134
0135 @Test
0136 public void testEncryptedMessage() throws Exception {
0137 AuthEngine client = new AuthEngine("appId", "secret", conf);
0138 AuthEngine server = new AuthEngine("appId", "secret", conf);
0139 try {
0140 ClientChallenge clientChallenge = client.challenge();
0141 ServerResponse serverResponse = server.respond(clientChallenge);
0142 client.validate(serverResponse);
0143
0144 TransportCipher cipher = server.sessionCipher();
0145 TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher);
0146
0147 byte[] data = new byte[TransportCipher.STREAM_BUFFER_SIZE + 1];
0148 new Random().nextBytes(data);
0149 ByteBuf buf = Unpooled.wrappedBuffer(data);
0150
0151 ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length);
0152 TransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(buf);
0153 while (emsg.transfered() < emsg.count()) {
0154 emsg.transferTo(channel, emsg.transfered());
0155 }
0156 assertEquals(data.length, channel.length());
0157 } finally {
0158 client.close();
0159 server.close();
0160 }
0161 }
0162
0163 @Test
0164 public void testEncryptedMessageWhenTransferringZeroBytes() throws Exception {
0165 AuthEngine client = new AuthEngine("appId", "secret", conf);
0166 AuthEngine server = new AuthEngine("appId", "secret", conf);
0167 try {
0168 ClientChallenge clientChallenge = client.challenge();
0169 ServerResponse serverResponse = server.respond(clientChallenge);
0170 client.validate(serverResponse);
0171
0172 TransportCipher cipher = server.sessionCipher();
0173 TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher);
0174
0175 int testDataLength = 4;
0176 FileRegion region = mock(FileRegion.class);
0177 when(region.count()).thenReturn((long) testDataLength);
0178
0179 when(region.transferTo(any(), anyLong())).thenAnswer(new Answer<Long>() {
0180
0181 private boolean firstTime = true;
0182
0183 @Override
0184 public Long answer(InvocationOnMock invocationOnMock) throws Throwable {
0185 if (firstTime) {
0186 firstTime = false;
0187 return 0L;
0188 } else {
0189 WritableByteChannel channel = invocationOnMock.getArgument(0);
0190 channel.write(ByteBuffer.wrap(new byte[testDataLength]));
0191 return (long) testDataLength;
0192 }
0193 }
0194 });
0195
0196 TransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(region);
0197 ByteArrayWritableChannel channel = new ByteArrayWritableChannel(testDataLength);
0198
0199 assertEquals(0L, emsg.transferTo(channel, emsg.transfered()));
0200 assertEquals(testDataLength, emsg.transferTo(channel, emsg.transfered()));
0201 assertEquals(emsg.transfered(), emsg.count());
0202 assertEquals(4, channel.length());
0203 } finally {
0204 client.close();
0205 server.close();
0206 }
0207 }
0208 }