0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.sasl;
0019
0020 import java.io.IOException;
0021 import java.nio.ByteBuffer;
0022 import javax.security.sasl.Sasl;
0023
0024 import io.netty.buffer.ByteBuf;
0025 import io.netty.buffer.Unpooled;
0026 import io.netty.channel.Channel;
0027 import org.slf4j.Logger;
0028 import org.slf4j.LoggerFactory;
0029
0030 import org.apache.spark.network.client.RpcResponseCallback;
0031 import org.apache.spark.network.client.TransportClient;
0032 import org.apache.spark.network.server.AbstractAuthRpcHandler;
0033 import org.apache.spark.network.server.RpcHandler;
0034 import org.apache.spark.network.util.JavaUtils;
0035 import org.apache.spark.network.util.TransportConf;
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045 public class SaslRpcHandler extends AbstractAuthRpcHandler {
0046 private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
0047
0048
0049 private final TransportConf conf;
0050
0051
0052 private final Channel channel;
0053
0054
0055 private final SecretKeyHolder secretKeyHolder;
0056
0057 private SparkSaslServer saslServer;
0058
0059 public SaslRpcHandler(
0060 TransportConf conf,
0061 Channel channel,
0062 RpcHandler delegate,
0063 SecretKeyHolder secretKeyHolder) {
0064 super(delegate);
0065 this.conf = conf;
0066 this.channel = channel;
0067 this.secretKeyHolder = secretKeyHolder;
0068 this.saslServer = null;
0069 }
0070
0071 @Override
0072 public boolean doAuthChallenge(
0073 TransportClient client,
0074 ByteBuffer message,
0075 RpcResponseCallback callback) {
0076 if (saslServer == null || !saslServer.isComplete()) {
0077 ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
0078 SaslMessage saslMessage;
0079 try {
0080 saslMessage = SaslMessage.decode(nettyBuf);
0081 } finally {
0082 nettyBuf.release();
0083 }
0084
0085 if (saslServer == null) {
0086
0087 client.setClientId(saslMessage.appId);
0088 saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
0089 conf.saslServerAlwaysEncrypt());
0090 }
0091
0092 byte[] response;
0093 try {
0094 response = saslServer.response(JavaUtils.bufferToArray(
0095 saslMessage.body().nioByteBuffer()));
0096 } catch (IOException ioe) {
0097 throw new RuntimeException(ioe);
0098 }
0099 callback.onSuccess(ByteBuffer.wrap(response));
0100 }
0101
0102
0103
0104
0105
0106
0107 if (saslServer.isComplete()) {
0108 if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
0109 logger.debug("SASL authentication successful for channel {}", client);
0110 complete(true);
0111 return true;
0112 }
0113
0114 logger.debug("Enabling encryption for channel {}", client);
0115 SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
0116 complete(false);
0117 return true;
0118 }
0119 return false;
0120 }
0121
0122 @Override
0123 public void channelInactive(TransportClient client) {
0124 try {
0125 super.channelInactive(client);
0126 } finally {
0127 if (saslServer != null) {
0128 saslServer.dispose();
0129 }
0130 }
0131 }
0132
0133 private void complete(boolean dispose) {
0134 if (dispose) {
0135 try {
0136 saslServer.dispose();
0137 } catch (RuntimeException e) {
0138 logger.error("Error while disposing SASL server", e);
0139 }
0140 }
0141
0142 saslServer = null;
0143 }
0144
0145 }