0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.crypto;
0019
import java.io.Closeable;
import java.io.IOException;
import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Properties;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.ShortBufferException;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;
import static java.nio.charset.StandardCharsets.UTF_8;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Bytes;
import org.apache.commons.crypto.cipher.CryptoCipher;
import org.apache.commons.crypto.cipher.CryptoCipherFactory;
import org.apache.commons.crypto.random.CryptoRandom;
import org.apache.commons.crypto.random.CryptoRandomFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.util.TransportConf;
0046
0047
0048
0049
0050
0051 class AuthEngine implements Closeable {
0052
0053 private static final Logger LOG = LoggerFactory.getLogger(AuthEngine.class);
0054 private static final BigInteger ONE = new BigInteger(new byte[] { 0x1 });
0055
0056 private final byte[] appId;
0057 private final char[] secret;
0058 private final TransportConf conf;
0059 private final Properties cryptoConf;
0060 private final CryptoRandom random;
0061
0062 private byte[] authNonce;
0063
0064 @VisibleForTesting
0065 byte[] challenge;
0066
0067 private TransportCipher sessionCipher;
0068 private CryptoCipher encryptor;
0069 private CryptoCipher decryptor;
0070
0071 AuthEngine(String appId, String secret, TransportConf conf) throws GeneralSecurityException {
0072 this.appId = appId.getBytes(UTF_8);
0073 this.conf = conf;
0074 this.cryptoConf = conf.cryptoConf();
0075 this.secret = secret.toCharArray();
0076 this.random = CryptoRandomFactory.getCryptoRandom(cryptoConf);
0077 }
0078
0079
0080
0081
0082
0083
0084 ClientChallenge challenge() throws GeneralSecurityException {
0085 this.authNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
0086 SecretKeySpec authKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(),
0087 authNonce, conf.encryptionKeyLength());
0088 initializeForAuth(conf.cipherTransformation(), authNonce, authKey);
0089
0090 this.challenge = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
0091 return new ClientChallenge(new String(appId, UTF_8),
0092 conf.keyFactoryAlgorithm(),
0093 conf.keyFactoryIterations(),
0094 conf.cipherTransformation(),
0095 conf.encryptionKeyLength(),
0096 authNonce,
0097 challenge(appId, authNonce, challenge));
0098 }
0099
0100
0101
0102
0103
0104
0105
0106
0107 ServerResponse respond(ClientChallenge clientChallenge)
0108 throws GeneralSecurityException {
0109
0110 SecretKeySpec authKey = generateKey(clientChallenge.kdf, clientChallenge.iterations,
0111 clientChallenge.nonce, clientChallenge.keyLength);
0112 initializeForAuth(clientChallenge.cipher, clientChallenge.nonce, authKey);
0113
0114 byte[] challenge = validateChallenge(clientChallenge.nonce, clientChallenge.challenge);
0115 byte[] response = challenge(appId, clientChallenge.nonce, rawResponse(challenge));
0116 byte[] sessionNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE);
0117 byte[] inputIv = randomBytes(conf.ivLength());
0118 byte[] outputIv = randomBytes(conf.ivLength());
0119
0120 SecretKeySpec sessionKey = generateKey(clientChallenge.kdf, clientChallenge.iterations,
0121 sessionNonce, clientChallenge.keyLength);
0122 this.sessionCipher = new TransportCipher(cryptoConf, clientChallenge.cipher, sessionKey,
0123 inputIv, outputIv);
0124
0125
0126 return new ServerResponse(response, encrypt(sessionNonce), encrypt(outputIv), encrypt(inputIv));
0127 }
0128
0129
0130
0131
0132
0133
0134 void validate(ServerResponse serverResponse) throws GeneralSecurityException {
0135 byte[] response = validateChallenge(authNonce, serverResponse.response);
0136
0137 byte[] expected = rawResponse(challenge);
0138 Preconditions.checkArgument(Arrays.equals(expected, response));
0139
0140 byte[] nonce = decrypt(serverResponse.nonce);
0141 byte[] inputIv = decrypt(serverResponse.inputIv);
0142 byte[] outputIv = decrypt(serverResponse.outputIv);
0143
0144 SecretKeySpec sessionKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(),
0145 nonce, conf.encryptionKeyLength());
0146 this.sessionCipher = new TransportCipher(cryptoConf, conf.cipherTransformation(), sessionKey,
0147 inputIv, outputIv);
0148 }
0149
0150 TransportCipher sessionCipher() {
0151 Preconditions.checkState(sessionCipher != null);
0152 return sessionCipher;
0153 }
0154
0155 @Override
0156 public void close() throws IOException {
0157
0158
0159
0160 RuntimeException error = null;
0161 byte[] dummy = new byte[8];
0162 if (encryptor != null) {
0163 try {
0164 doCipherOp(Cipher.ENCRYPT_MODE, dummy, true);
0165 } catch (Exception e) {
0166 error = new RuntimeException(e);
0167 }
0168 encryptor = null;
0169 }
0170 if (decryptor != null) {
0171 try {
0172 doCipherOp(Cipher.DECRYPT_MODE, dummy, true);
0173 } catch (Exception e) {
0174 error = new RuntimeException(e);
0175 }
0176 decryptor = null;
0177 }
0178 random.close();
0179
0180 if (error != null) {
0181 throw error;
0182 }
0183 }
0184
0185 @VisibleForTesting
0186 byte[] challenge(byte[] appId, byte[] nonce, byte[] challenge) throws GeneralSecurityException {
0187 return encrypt(Bytes.concat(appId, nonce, challenge));
0188 }
0189
0190 @VisibleForTesting
0191 byte[] rawResponse(byte[] challenge) {
0192 BigInteger orig = new BigInteger(challenge);
0193 BigInteger response = orig.add(ONE);
0194 return response.toByteArray();
0195 }
0196
0197 private byte[] decrypt(byte[] in) throws GeneralSecurityException {
0198 return doCipherOp(Cipher.DECRYPT_MODE, in, false);
0199 }
0200
0201 private byte[] encrypt(byte[] in) throws GeneralSecurityException {
0202 return doCipherOp(Cipher.ENCRYPT_MODE, in, false);
0203 }
0204
0205 private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key)
0206 throws GeneralSecurityException {
0207
0208
0209
0210
0211 byte[] iv = new byte[conf.ivLength()];
0212 System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length));
0213
0214 CryptoCipher _encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
0215 _encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv));
0216 this.encryptor = _encryptor;
0217
0218 CryptoCipher _decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
0219 _decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv));
0220 this.decryptor = _decryptor;
0221 }
0222
0223
0224
0225
0226
0227 private byte[] validateChallenge(byte[] nonce, byte[] encryptedChallenge)
0228 throws GeneralSecurityException {
0229
0230 byte[] challenge = decrypt(encryptedChallenge);
0231 checkSubArray(appId, challenge, 0);
0232 checkSubArray(nonce, challenge, appId.length);
0233 return Arrays.copyOfRange(challenge, appId.length + nonce.length, challenge.length);
0234 }
0235
0236 private SecretKeySpec generateKey(String kdf, int iterations, byte[] salt, int keyLength)
0237 throws GeneralSecurityException {
0238
0239 SecretKeyFactory factory = SecretKeyFactory.getInstance(kdf);
0240 PBEKeySpec spec = new PBEKeySpec(secret, salt, iterations, keyLength);
0241
0242 long start = System.nanoTime();
0243 SecretKey key = factory.generateSecret(spec);
0244 long end = System.nanoTime();
0245
0246 LOG.debug("Generated key with {} iterations in {} us.", conf.keyFactoryIterations(),
0247 (end - start) / 1000);
0248
0249 return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm());
0250 }
0251
0252 private byte[] doCipherOp(int mode, byte[] in, boolean isFinal)
0253 throws GeneralSecurityException {
0254
0255 CryptoCipher cipher;
0256 switch (mode) {
0257 case Cipher.ENCRYPT_MODE:
0258 cipher = encryptor;
0259 break;
0260 case Cipher.DECRYPT_MODE:
0261 cipher = decryptor;
0262 break;
0263 default:
0264 throw new IllegalArgumentException(String.valueOf(mode));
0265 }
0266
0267 Preconditions.checkState(cipher != null, "Cipher is invalid because of previous error.");
0268
0269 try {
0270 int scale = 1;
0271 while (true) {
0272 int size = in.length * scale;
0273 byte[] buffer = new byte[size];
0274 try {
0275 int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0)
0276 : cipher.update(in, 0, in.length, buffer, 0);
0277 if (outSize != buffer.length) {
0278 byte[] output = new byte[outSize];
0279 System.arraycopy(buffer, 0, output, 0, output.length);
0280 return output;
0281 } else {
0282 return buffer;
0283 }
0284 } catch (ShortBufferException e) {
0285
0286 scale *= 2;
0287 }
0288 }
0289 } catch (InternalError ie) {
0290
0291
0292 if (mode == Cipher.ENCRYPT_MODE) {
0293 this.encryptor = null;
0294 } else {
0295 this.decryptor = null;
0296 }
0297 throw ie;
0298 }
0299 }
0300
0301 private byte[] randomBytes(int count) {
0302 byte[] bytes = new byte[count];
0303 random.nextBytes(bytes);
0304 return bytes;
0305 }
0306
0307
0308 private void checkSubArray(byte[] test, byte[] data, int offset) {
0309 Preconditions.checkArgument(data.length >= test.length + offset);
0310 for (int i = 0; i < test.length; i++) {
0311 Preconditions.checkArgument(test[i] == data[i + offset]);
0312 }
0313 }
0314
0315 }