0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.sasl;
0019
0020 import java.util.Map;
0021 import javax.security.auth.callback.Callback;
0022 import javax.security.auth.callback.CallbackHandler;
0023 import javax.security.auth.callback.NameCallback;
0024 import javax.security.auth.callback.PasswordCallback;
0025 import javax.security.auth.callback.UnsupportedCallbackException;
0026 import javax.security.sasl.RealmCallback;
0027 import javax.security.sasl.RealmChoiceCallback;
0028 import javax.security.sasl.Sasl;
0029 import javax.security.sasl.SaslClient;
0030 import javax.security.sasl.SaslException;
0031
0032 import com.google.common.base.Throwables;
0033 import com.google.common.collect.ImmutableMap;
0034 import org.slf4j.Logger;
0035 import org.slf4j.LoggerFactory;
0036
0037 import static org.apache.spark.network.sasl.SparkSaslServer.*;
0038
0039
0040
0041
0042
0043
0044 public class SparkSaslClient implements SaslEncryptionBackend {
0045 private static final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class);
0046
0047 private final String secretKeyId;
0048 private final SecretKeyHolder secretKeyHolder;
0049 private final String expectedQop;
0050 private SaslClient saslClient;
0051
0052 public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, boolean encrypt) {
0053 this.secretKeyId = secretKeyId;
0054 this.secretKeyHolder = secretKeyHolder;
0055 this.expectedQop = encrypt ? QOP_AUTH_CONF : QOP_AUTH;
0056
0057 Map<String, String> saslProps = ImmutableMap.<String, String>builder()
0058 .put(Sasl.QOP, expectedQop)
0059 .build();
0060 try {
0061 this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM,
0062 saslProps, new ClientCallbackHandler());
0063 } catch (SaslException e) {
0064 throw Throwables.propagate(e);
0065 }
0066 }
0067
0068
0069 public synchronized byte[] firstToken() {
0070 if (saslClient != null && saslClient.hasInitialResponse()) {
0071 try {
0072 return saslClient.evaluateChallenge(new byte[0]);
0073 } catch (SaslException e) {
0074 throw Throwables.propagate(e);
0075 }
0076 } else {
0077 return new byte[0];
0078 }
0079 }
0080
0081
0082 public synchronized boolean isComplete() {
0083 return saslClient != null && saslClient.isComplete();
0084 }
0085
0086
0087 public Object getNegotiatedProperty(String name) {
0088 return saslClient.getNegotiatedProperty(name);
0089 }
0090
0091
0092
0093
0094
0095
0096 public synchronized byte[] response(byte[] token) {
0097 try {
0098 return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0];
0099 } catch (SaslException e) {
0100 throw Throwables.propagate(e);
0101 }
0102 }
0103
0104
0105
0106
0107
0108 @Override
0109 public synchronized void dispose() {
0110 if (saslClient != null) {
0111 try {
0112 saslClient.dispose();
0113 } catch (SaslException e) {
0114
0115 } finally {
0116 saslClient = null;
0117 }
0118 }
0119 }
0120
0121
0122
0123
0124
0125 private class ClientCallbackHandler implements CallbackHandler {
0126 @Override
0127 public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
0128
0129 for (Callback callback : callbacks) {
0130 if (callback instanceof NameCallback) {
0131 logger.trace("SASL client callback: setting username");
0132 NameCallback nc = (NameCallback) callback;
0133 nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
0134 } else if (callback instanceof PasswordCallback) {
0135 logger.trace("SASL client callback: setting password");
0136 PasswordCallback pc = (PasswordCallback) callback;
0137 pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
0138 } else if (callback instanceof RealmCallback) {
0139 logger.trace("SASL client callback: setting realm");
0140 RealmCallback rc = (RealmCallback) callback;
0141 rc.setText(rc.getDefaultText());
0142 } else if (callback instanceof RealmChoiceCallback) {
0143
0144 } else {
0145 throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
0146 }
0147 }
0148 }
0149 }
0150
0151 @Override
0152 public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
0153 return saslClient.wrap(data, offset, len);
0154 }
0155
0156 @Override
0157 public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
0158 return saslClient.unwrap(data, offset, len);
0159 }
0160
0161 }