0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.crypto;
0019
0020 import java.nio.ByteBuffer;
0021
0022 import com.google.common.annotations.VisibleForTesting;
0023 import com.google.common.base.Preconditions;
0024 import com.google.common.base.Throwables;
0025 import io.netty.buffer.ByteBuf;
0026 import io.netty.buffer.Unpooled;
0027 import io.netty.channel.Channel;
0028 import org.slf4j.Logger;
0029 import org.slf4j.LoggerFactory;
0030
0031 import org.apache.spark.network.client.RpcResponseCallback;
0032 import org.apache.spark.network.client.TransportClient;
0033 import org.apache.spark.network.sasl.SecretKeyHolder;
0034 import org.apache.spark.network.sasl.SaslRpcHandler;
0035 import org.apache.spark.network.server.AbstractAuthRpcHandler;
0036 import org.apache.spark.network.server.RpcHandler;
0037 import org.apache.spark.network.util.TransportConf;
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048 class AuthRpcHandler extends AbstractAuthRpcHandler {
0049 private static final Logger LOG = LoggerFactory.getLogger(AuthRpcHandler.class);
0050
0051
0052 private final TransportConf conf;
0053
0054
0055 private final Channel channel;
0056
0057
0058 private final SecretKeyHolder secretKeyHolder;
0059
0060
0061 @VisibleForTesting
0062 SaslRpcHandler saslHandler;
0063
0064 AuthRpcHandler(
0065 TransportConf conf,
0066 Channel channel,
0067 RpcHandler delegate,
0068 SecretKeyHolder secretKeyHolder) {
0069 super(delegate);
0070 this.conf = conf;
0071 this.channel = channel;
0072 this.secretKeyHolder = secretKeyHolder;
0073 }
0074
0075 @Override
0076 protected boolean doAuthChallenge(
0077 TransportClient client,
0078 ByteBuffer message,
0079 RpcResponseCallback callback) {
0080 if (saslHandler != null) {
0081 return saslHandler.doAuthChallenge(client, message, callback);
0082 }
0083
0084 int position = message.position();
0085 int limit = message.limit();
0086
0087 ClientChallenge challenge;
0088 try {
0089 challenge = ClientChallenge.decodeMessage(message);
0090 LOG.debug("Received new auth challenge for client {}.", channel.remoteAddress());
0091 } catch (RuntimeException e) {
0092 if (conf.saslFallback()) {
0093 LOG.warn("Failed to parse new auth challenge, reverting to SASL for client {}.",
0094 channel.remoteAddress());
0095 saslHandler = new SaslRpcHandler(conf, channel, null, secretKeyHolder);
0096 message.position(position);
0097 message.limit(limit);
0098 return saslHandler.doAuthChallenge(client, message, callback);
0099 } else {
0100 LOG.debug("Unexpected challenge message from client {}, closing channel.",
0101 channel.remoteAddress());
0102 callback.onFailure(new IllegalArgumentException("Unknown challenge message."));
0103 channel.close();
0104 }
0105 return false;
0106 }
0107
0108
0109 AuthEngine engine = null;
0110 try {
0111 String secret = secretKeyHolder.getSecretKey(challenge.appId);
0112 Preconditions.checkState(secret != null,
0113 "Trying to authenticate non-registered app %s.", challenge.appId);
0114 LOG.debug("Authenticating challenge for app {}.", challenge.appId);
0115 engine = new AuthEngine(challenge.appId, secret, conf);
0116 ServerResponse response = engine.respond(challenge);
0117 ByteBuf responseData = Unpooled.buffer(response.encodedLength());
0118 response.encode(responseData);
0119 callback.onSuccess(responseData.nioBuffer());
0120 engine.sessionCipher().addToChannel(channel);
0121 client.setClientId(challenge.appId);
0122 } catch (Exception e) {
0123
0124 LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());
0125 callback.onFailure(new IllegalArgumentException("Authentication failed."));
0126 channel.close();
0127 return false;
0128 } finally {
0129 if (engine != null) {
0130 try {
0131 engine.close();
0132 } catch (Exception e) {
0133 throw Throwables.propagate(e);
0134 }
0135 }
0136 }
0137
0138 LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
0139 return true;
0140 }
0141 }