0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.sasl;
0019
0020 import javax.security.auth.callback.Callback;
0021 import javax.security.auth.callback.CallbackHandler;
0022 import javax.security.auth.callback.NameCallback;
0023 import javax.security.auth.callback.PasswordCallback;
0024 import javax.security.auth.callback.UnsupportedCallbackException;
0025 import javax.security.sasl.AuthorizeCallback;
0026 import javax.security.sasl.RealmCallback;
0027 import javax.security.sasl.Sasl;
0028 import javax.security.sasl.SaslException;
0029 import javax.security.sasl.SaslServer;
0030 import java.nio.charset.StandardCharsets;
0031 import java.util.Map;
0032
0033 import com.google.common.base.Preconditions;
0034 import com.google.common.base.Throwables;
0035 import com.google.common.collect.ImmutableMap;
0036 import io.netty.buffer.ByteBuf;
0037 import io.netty.buffer.Unpooled;
0038 import io.netty.handler.codec.base64.Base64;
0039 import org.slf4j.Logger;
0040 import org.slf4j.LoggerFactory;
0041
0042
0043
0044
0045
0046
0047 public class SparkSaslServer implements SaslEncryptionBackend {
0048 private static final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class);
0049
0050
0051
0052
0053
0054 static final String DEFAULT_REALM = "default";
0055
0056
0057
0058
0059
0060 static final String DIGEST = "DIGEST-MD5";
0061
0062
0063
0064
0065 static final String QOP_AUTH_CONF = "auth-conf";
0066
0067
0068
0069
0070 static final String QOP_AUTH = "auth";
0071
0072
0073 private final String secretKeyId;
0074 private final SecretKeyHolder secretKeyHolder;
0075 private SaslServer saslServer;
0076
0077 public SparkSaslServer(
0078 String secretKeyId,
0079 SecretKeyHolder secretKeyHolder,
0080 boolean alwaysEncrypt) {
0081 this.secretKeyId = secretKeyId;
0082 this.secretKeyHolder = secretKeyHolder;
0083
0084
0085
0086
0087 String qop = alwaysEncrypt ? QOP_AUTH_CONF : String.format("%s,%s", QOP_AUTH_CONF, QOP_AUTH);
0088 Map<String, String> saslProps = ImmutableMap.<String, String>builder()
0089 .put(Sasl.SERVER_AUTH, "true")
0090 .put(Sasl.QOP, qop)
0091 .build();
0092 try {
0093 this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, saslProps,
0094 new DigestCallbackHandler());
0095 } catch (SaslException e) {
0096 throw Throwables.propagate(e);
0097 }
0098 }
0099
0100
0101
0102
0103 public synchronized boolean isComplete() {
0104 return saslServer != null && saslServer.isComplete();
0105 }
0106
0107
0108 public Object getNegotiatedProperty(String name) {
0109 return saslServer.getNegotiatedProperty(name);
0110 }
0111
0112
0113
0114
0115
0116
0117 public synchronized byte[] response(byte[] token) {
0118 try {
0119 return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0];
0120 } catch (SaslException e) {
0121 throw Throwables.propagate(e);
0122 }
0123 }
0124
0125
0126
0127
0128
0129 @Override
0130 public synchronized void dispose() {
0131 if (saslServer != null) {
0132 try {
0133 saslServer.dispose();
0134 } catch (SaslException e) {
0135
0136 } finally {
0137 saslServer = null;
0138 }
0139 }
0140 }
0141
0142 @Override
0143 public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
0144 return saslServer.wrap(data, offset, len);
0145 }
0146
0147 @Override
0148 public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
0149 return saslServer.unwrap(data, offset, len);
0150 }
0151
0152
0153
0154
0155 private class DigestCallbackHandler implements CallbackHandler {
0156 @Override
0157 public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
0158 for (Callback callback : callbacks) {
0159 if (callback instanceof NameCallback) {
0160 logger.trace("SASL server callback: setting username");
0161 NameCallback nc = (NameCallback) callback;
0162 nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
0163 } else if (callback instanceof PasswordCallback) {
0164 logger.trace("SASL server callback: setting password");
0165 PasswordCallback pc = (PasswordCallback) callback;
0166 pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
0167 } else if (callback instanceof RealmCallback) {
0168 logger.trace("SASL server callback: setting realm");
0169 RealmCallback rc = (RealmCallback) callback;
0170 rc.setText(rc.getDefaultText());
0171 } else if (callback instanceof AuthorizeCallback) {
0172 AuthorizeCallback ac = (AuthorizeCallback) callback;
0173 String authId = ac.getAuthenticationID();
0174 String authzId = ac.getAuthorizationID();
0175 ac.setAuthorized(authId.equals(authzId));
0176 if (ac.isAuthorized()) {
0177 ac.setAuthorizedID(authzId);
0178 }
0179 logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized());
0180 } else {
0181 throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
0182 }
0183 }
0184 }
0185 }
0186
0187
0188 public static String encodeIdentifier(String identifier) {
0189 Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
0190 return getBase64EncodedString(identifier);
0191 }
0192
0193
0194 public static char[] encodePassword(String password) {
0195 Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled");
0196 return getBase64EncodedString(password).toCharArray();
0197 }
0198
0199
0200 private static String getBase64EncodedString(String str) {
0201 ByteBuf byteBuf = null;
0202 ByteBuf encodedByteBuf = null;
0203 try {
0204 byteBuf = Unpooled.wrappedBuffer(str.getBytes(StandardCharsets.UTF_8));
0205 encodedByteBuf = Base64.encode(byteBuf);
0206 return encodedByteBuf.toString(StandardCharsets.UTF_8);
0207 } finally {
0208
0209 if (byteBuf != null) {
0210 byteBuf.release();
0211 if (encodedByteBuf != null) {
0212 encodedByteBuf.release();
0213 }
0214 }
0215 }
0216 }
0217 }