0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.crypto;
0019
0020 import java.io.IOException;
0021 import java.nio.ByteBuffer;
0022 import java.security.GeneralSecurityException;
0023 import java.util.concurrent.TimeoutException;
0024
0025 import com.google.common.base.Throwables;
0026 import io.netty.buffer.ByteBuf;
0027 import io.netty.buffer.Unpooled;
0028 import io.netty.channel.Channel;
0029 import org.slf4j.Logger;
0030 import org.slf4j.LoggerFactory;
0031
0032 import org.apache.spark.network.client.TransportClient;
0033 import org.apache.spark.network.client.TransportClientBootstrap;
0034 import org.apache.spark.network.sasl.SaslClientBootstrap;
0035 import org.apache.spark.network.sasl.SecretKeyHolder;
0036 import org.apache.spark.network.util.TransportConf;
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048 public class AuthClientBootstrap implements TransportClientBootstrap {
0049
0050 private static final Logger LOG = LoggerFactory.getLogger(AuthClientBootstrap.class);
0051
0052 private final TransportConf conf;
0053 private final String appId;
0054 private final SecretKeyHolder secretKeyHolder;
0055
0056 public AuthClientBootstrap(
0057 TransportConf conf,
0058 String appId,
0059 SecretKeyHolder secretKeyHolder) {
0060 this.conf = conf;
0061
0062
0063
0064
0065
0066
0067 this.appId = appId;
0068 this.secretKeyHolder = secretKeyHolder;
0069 }
0070
0071 @Override
0072 public void doBootstrap(TransportClient client, Channel channel) {
0073 if (!conf.encryptionEnabled()) {
0074 LOG.debug("AES encryption disabled, using old auth protocol.");
0075 doSaslAuth(client, channel);
0076 return;
0077 }
0078
0079 try {
0080 doSparkAuth(client, channel);
0081 client.setClientId(appId);
0082 } catch (GeneralSecurityException | IOException e) {
0083 throw Throwables.propagate(e);
0084 } catch (RuntimeException e) {
0085
0086
0087
0088
0089 if (!conf.saslFallback() || e.getCause() instanceof TimeoutException) {
0090 throw e;
0091 }
0092
0093 if (LOG.isDebugEnabled()) {
0094 Throwable cause = e.getCause() != null ? e.getCause() : e;
0095 LOG.debug("New auth protocol failed, trying SASL.", cause);
0096 } else {
0097 LOG.info("New auth protocol failed, trying SASL.");
0098 }
0099 doSaslAuth(client, channel);
0100 }
0101 }
0102
0103 private void doSparkAuth(TransportClient client, Channel channel)
0104 throws GeneralSecurityException, IOException {
0105
0106 String secretKey = secretKeyHolder.getSecretKey(appId);
0107 try (AuthEngine engine = new AuthEngine(appId, secretKey, conf)) {
0108 ClientChallenge challenge = engine.challenge();
0109 ByteBuf challengeData = Unpooled.buffer(challenge.encodedLength());
0110 challenge.encode(challengeData);
0111
0112 ByteBuffer responseData =
0113 client.sendRpcSync(challengeData.nioBuffer(), conf.authRTTimeoutMs());
0114 ServerResponse response = ServerResponse.decodeMessage(responseData);
0115
0116 engine.validate(response);
0117 engine.sessionCipher().addToChannel(channel);
0118 }
0119 }
0120
0121 private void doSaslAuth(TransportClient client, Channel channel) {
0122 SaslClientBootstrap sasl = new SaslClientBootstrap(conf, appId, secretKeyHolder);
0123 sasl.doBootstrap(client, channel);
0124 }
0125
0126 }