0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.crypto;
0019
0020 import java.io.IOException;
0021 import java.nio.ByteBuffer;
0022 import java.nio.channels.ReadableByteChannel;
0023 import java.nio.channels.WritableByteChannel;
0024 import java.util.Properties;
0025 import javax.crypto.spec.SecretKeySpec;
0026 import javax.crypto.spec.IvParameterSpec;
0027
0028 import com.google.common.annotations.VisibleForTesting;
0029 import com.google.common.base.Preconditions;
0030 import io.netty.buffer.ByteBuf;
0031 import io.netty.buffer.Unpooled;
0032 import io.netty.channel.*;
0033 import org.apache.commons.crypto.stream.CryptoInputStream;
0034 import org.apache.commons.crypto.stream.CryptoOutputStream;
0035
0036 import org.apache.spark.network.util.AbstractFileRegion;
0037 import org.apache.spark.network.util.ByteArrayReadableChannel;
0038 import org.apache.spark.network.util.ByteArrayWritableChannel;
0039
0040
0041
0042
/**
 * Encrypts and decrypts the bytes flowing through a Netty {@link Channel} using Apache Commons
 * Crypto streams.
 *
 * <p>An instance holds a cipher transformation, a secret key and one IV per direction:
 * {@code inIv} initializes the decrypting (inbound) stream and {@code outIv} the encrypting
 * (outbound) stream. {@link #addToChannel(Channel)} installs pipeline handlers that use them.
 *
 * <p>NOTE(review): the handlers below treat ciphertext and plaintext as having the same length
 * ({@code DecryptionHandler.channelRead} sizes its output to the ciphertext, and
 * {@code EncryptedMessage.transferTo} compares ciphertext bytes written against the plaintext
 * count). That only holds for a length-preserving transformation such as a stream/CTR mode;
 * confirm which transformations are allowed.
 */
public class TransportCipher {
  /** Pipeline name of the outbound encryption handler. */
  @VisibleForTesting
  static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption";
  /** Pipeline name of the inbound decryption handler. */
  private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption";
  /** Size given to the intermediate ByteArrayWritableChannel buffers on the encryption path. */
  @VisibleForTesting
  static final int STREAM_BUFFER_SIZE = 1024 * 32;

  /** Properties passed to every Commons Crypto stream this instance creates. */
  private final Properties conf;
  /** Cipher transformation passed to every Commons Crypto stream this instance creates. */
  private final String cipher;
  /** Secret key shared by the encrypting and decrypting streams. */
  private final SecretKeySpec key;
  /** IV used by {@link #createInputStream} (decryption of inbound data). */
  private final byte[] inIv;
  /** IV used by {@link #createOutputStream} (encryption of outbound data). */
  private final byte[] outIv;

  /**
   * Creates a cipher from its parts. The arguments are stored as given; the IV arrays are not
   * copied.
   *
   * @param conf properties forwarded to the Commons Crypto streams
   * @param cipher cipher transformation forwarded to the Commons Crypto streams
   * @param key secret key for both directions
   * @param inIv IV for decrypting inbound data
   * @param outIv IV for encrypting outbound data
   */
  public TransportCipher(
      Properties conf,
      String cipher,
      SecretKeySpec key,
      byte[] inIv,
      byte[] outIv) {
    this.conf = conf;
    this.cipher = cipher;
    this.key = key;
    this.inIv = inIv;
    this.outIv = outIv;
  }

  /** Returns the cipher transformation string given at construction. */
  public String getCipherTransformation() {
    return cipher;
  }

  /** Returns the secret key. Exposed for tests only. */
  @VisibleForTesting
  SecretKeySpec getKey() {
    return key;
  }

  /** Returns the IV used to decrypt inbound data (the internal array, not a copy). */
  public byte[] getInputIv() {
    return inIv;
  }

  /** Returns the IV used to encrypt outbound data (the internal array, not a copy). */
  public byte[] getOutputIv() {
    return outIv;
  }

  /**
   * Creates a stream that encrypts what is written to it with {@code key} and {@code outIv} and
   * writes the resulting ciphertext to {@code ch}.
   *
   * @param ch destination of the ciphertext
   * @throws IOException if the Commons Crypto stream cannot be created
   */
  @VisibleForTesting
  CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException {
    return new CryptoOutputStream(cipher, conf, ch, key, new IvParameterSpec(outIv));
  }

  /**
   * Creates a stream that reads ciphertext from {@code ch} and decrypts it with {@code key} and
   * {@code inIv}.
   *
   * @param ch source of the ciphertext
   * @throws IOException if the Commons Crypto stream cannot be created
   */
  @VisibleForTesting
  CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException {
    return new CryptoInputStream(cipher, conf, ch, key, new IvParameterSpec(inIv));
  }

  /**
   * Adds an encryption handler and a decryption handler to the head of the channel's pipeline.
   *
   * <p>Both handlers are inserted with {@code addFirst}, the decryption handler last. The
   * resulting order from the head is therefore: decryption handler, encryption handler, then the
   * handlers that were already present. Inbound bytes are decrypted before the pre-existing
   * handlers see them. Outbound writes from those handlers pass through the encryption handler
   * just before reaching the head.
   *
   * <p>Each call creates a fresh pair of handlers, each with its own cipher stream.
   *
   * @param ch the channel to secure
   * @throws IOException if creating either cipher stream fails
   */
  public void addToChannel(Channel ch) throws IOException {
    ch.pipeline()
      .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(this))
      .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this));
  }

  /**
   * Outbound handler that wraps every written message in an {@link EncryptedMessage}.
   *
   * <p>No encryption happens in {@link #write}. The wrapped message is encrypted chunk by chunk
   * when its {@code transferTo} is called. All messages created by one handler share the
   * handler's {@code cos} stream and its ciphertext buffer {@code byteChannel}.
   */
  @VisibleForTesting
  static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
    // Receives the ciphertext produced by `cos`; handed to every EncryptedMessage created here.
    private final ByteArrayWritableChannel byteChannel;
    // Encrypting stream that writes into `byteChannel`.
    private final CryptoOutputStream cos;
    // Cleared by reportError() after `cos` fails; once false, `cos` is neither used nor closed.
    private boolean isCipherValid;

    /**
     * Creates the ciphertext buffer and the encrypting stream on top of it.
     *
     * @throws IOException if the encrypting stream cannot be created
     */
    EncryptionHandler(TransportCipher cipher) throws IOException {
      byteChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
      cos = cipher.createOutputStream(byteChannel);
      isCipherValid = true;
    }

    /**
     * Forwards {@code msg} wrapped in an {@link EncryptedMessage}, keeping the same promise.
     *
     * <p>NOTE(review): if wrapping throws (unsupported message type), {@code msg} is not
     * released here; confirm whether the caller releases it.
     */
    @Override
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
      throws Exception {
      ctx.write(createEncryptedMessage(msg), promise);
    }

    /**
     * Wraps {@code msg}, which must be a ByteBuf or a FileRegion, in an EncryptedMessage that
     * encrypts through this handler's stream and buffer.
     *
     * @throws IllegalArgumentException if {@code msg} is of another type
     */
    @VisibleForTesting
    EncryptedMessage createEncryptedMessage(Object msg) {
      return new EncryptedMessage(this, cos, msg, byteChannel);
    }

    /**
     * Closes the encrypting stream, unless it has been marked invalid, and then propagates the
     * close request. The close request is propagated even if closing the stream throws.
     */
    @Override
    public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
      try {
        if (isCipherValid) {
          cos.close();
        }
      } finally {
        super.close(ctx, promise);
      }
    }

    /**
     * Marks the encrypting stream as unusable after it raised an error. Afterwards no further
     * data is encrypted with it (EncryptedMessage checks {@link #isCipherValid()}) and
     * {@link #close} skips closing it.
     */
    void reportError() {
      this.isCipherValid = false;
    }

    /** Returns whether the encrypting stream may still be used. */
    boolean isCipherValid() {
      return isCipherValid;
    }
  }

  /**
   * Inbound handler that decrypts every received {@link ByteBuf} and forwards the plaintext
   * downstream as a new buffer.
   */
  private static class DecryptionHandler extends ChannelInboundHandlerAdapter {
    // Decrypting stream that reads ciphertext from `byteChannel`.
    private final CryptoInputStream cis;
    // Source channel for `cis`; fed with each inbound buffer before it is decrypted.
    private final ByteArrayReadableChannel byteChannel;
    // Cleared when `cis` throws InternalError; once false, `cis` is neither used nor closed.
    private boolean isCipherValid;

    /**
     * Creates the source channel and the decrypting stream on top of it.
     *
     * @throws IOException if the decrypting stream cannot be created
     */
    DecryptionHandler(TransportCipher cipher) throws IOException {
      byteChannel = new ByteArrayReadableChannel();
      cis = cipher.createInputStream(byteChannel);
      isCipherValid = true;
    }

    /**
     * Decrypts all readable bytes of {@code data} and fires the plaintext downstream.
     *
     * <p>The input buffer is always released, whether or not decryption succeeds.
     *
     * @throws IOException if the cipher was previously marked invalid
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
      // Unchecked cast: every inbound message at this pipeline position is assumed to be a
      // ByteBuf.
      ByteBuf buffer = (ByteBuf) data;

      try {
        if (!isCipherValid) {
          throw new IOException("Cipher is in invalid state.");
        }
        // The output array is sized to the ciphertext length, so the cipher is assumed to be
        // length-preserving. NOTE(review): confirm.
        byte[] decryptedData = new byte[buffer.readableBytes()];
        byteChannel.feedData(buffer);

        int offset = 0;
        // cis.read may return fewer bytes than requested, so keep reading until the output
        // array is full.
        // NOTE(review): a -1 (end-of-stream) return would move `offset` backwards. Presumably
        // ByteArrayReadableChannel never reports end-of-stream; confirm.
        while (offset < decryptedData.length) {
          // An InternalError from the cipher is treated as leaving it unusable. Mark it invalid
          // so it is never read from or closed again, then rethrow.
          try {
            offset += cis.read(decryptedData, offset, decryptedData.length - offset);
          } catch (InternalError ie) {
            isCipherValid = false;
            throw ie;
          }
        }

        ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length));
      } finally {
        buffer.release();
      }
    }

    /**
     * Closes the decrypting stream, unless it has been marked invalid, then delegates to the
     * superclass. The delegation happens even if closing the stream throws.
     *
     * <p>Cleanup lives here rather than in a close hook. Presumably this is because Netty invokes
     * handlerRemoved both on explicit removal and when the channel is torn down; confirm.
     */
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
      try {
        if (isCipherValid) {
          cis.close();
        }
      } finally {
        super.handlerRemoved(ctx);
      }
    }
  }

  /**
   * A file region (via {@link AbstractFileRegion}) that wraps a {@link ByteBuf} or
   * {@link FileRegion}. It encrypts the wrapped message lazily, one chunk at a time, as the
   * message is transferred to a target channel.
   *
   * <p>{@link #count()} is the plaintext length captured at construction, while
   * {@link #transferred()} counts ciphertext bytes written to the target. The two are compared
   * directly, so the cipher is assumed to be length-preserving (NOTE(review): confirm).
   *
   * <p>Reference counting is forwarded to the wrapped message. {@code touch}, {@code retain(int)}
   * and {@code release(int)} are applied to it as well, and {@link #deallocate()} releases it once
   * more.
   */
  @VisibleForTesting
  static class EncryptedMessage extends AbstractFileRegion {
    // True if the wrapped message is a ByteBuf, false if it is a FileRegion.
    private final boolean isByteBuf;
    // Wrapped ByteBuf, or null when wrapping a FileRegion.
    private final ByteBuf buf;
    // Wrapped FileRegion, or null when wrapping a ByteBuf.
    private final FileRegion region;
    // The owning handler's encrypting stream; it writes ciphertext into byteEncChannel.
    private final CryptoOutputStream cos;
    // Owning handler; consulted for, and notified of, cipher failures.
    private final EncryptionHandler handler;
    // Plaintext length of the wrapped message at construction time.
    private final long count;
    // Ciphertext bytes written to the target so far.
    private long transferred;

    // Buffers for the encryption pipeline:
    // - byteRawChannel receives each plaintext chunk read from the wrapped message.
    // - byteEncChannel receives the ciphertext that `cos` produces for that chunk. It is the
    //   owning EncryptionHandler's buffer, shared by every message that handler creates.
    private ByteArrayWritableChannel byteEncChannel;
    private ByteArrayWritableChannel byteRawChannel;

    // Ciphertext of the current chunk not yet written to the target, or null if none is pending.
    private ByteBuffer currentEncrypted;

    /**
     * Wraps {@code msg} for encryption.
     *
     * @param handler owning handler, used to check and report cipher failures
     * @param cos the owning handler's encrypting stream
     * @param msg message to encrypt; must be a ByteBuf or a FileRegion
     * @param ch the buffer {@code cos} writes ciphertext into
     * @throws IllegalArgumentException if {@code msg} is neither a ByteBuf nor a FileRegion
     */
    EncryptedMessage(
        EncryptionHandler handler,
        CryptoOutputStream cos,
        Object msg,
        ByteArrayWritableChannel ch) {
      Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion,
        "Unrecognized message type: %s", msg.getClass().getName());
      this.handler = handler;
      this.isByteBuf = msg instanceof ByteBuf;
      this.buf = isByteBuf ? (ByteBuf) msg : null;
      this.region = isByteBuf ? null : (FileRegion) msg;
      this.transferred = 0;
      this.byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
      this.cos = cos;
      this.byteEncChannel = ch;
      this.count = isByteBuf ? buf.readableBytes() : region.count();
    }

    /** Returns the plaintext length of the wrapped message, captured at construction. */
    @Override
    public long count() {
      return count;
    }

    /** Always 0. */
    @Override
    public long position() {
      return 0;
    }

    /** Returns the number of ciphertext bytes written to targets so far. */
    @Override
    public long transferred() {
      return transferred;
    }

    /** Records a touch hint on this message and on the wrapped message. */
    @Override
    public EncryptedMessage touch(Object o) {
      super.touch(o);
      if (region != null) {
        region.touch(o);
      }
      if (buf != null) {
        buf.touch(o);
      }
      return this;
    }

    /** Retains this message and the wrapped message by {@code increment}. */
    @Override
    public EncryptedMessage retain(int increment) {
      super.retain(increment);
      if (region != null) {
        region.retain(increment);
      }
      if (buf != null) {
        buf.retain(increment);
      }
      return this;
    }

    /**
     * Releases the wrapped message, and then this message, by {@code decrement}.
     *
     * <p>NOTE(review): suppose this call drops this message's count to zero. Then
     * {@link #deallocate()} (presumably invoked by {@code super.release}) releases the wrapped
     * message once more, on top of the {@code decrement} applied here. Confirm this cannot
     * over-release the wrapped message.
     */
    @Override
    public boolean release(int decrement) {
      if (region != null) {
        region.release(decrement);
      }
      if (buf != null) {
        buf.release(decrement);
      }
      return super.release(decrement);
    }

    /**
     * Encrypts and writes as much of the wrapped message to {@code target} as the target accepts.
     *
     * <p>Each iteration encrypts the next plaintext chunk (if none is pending) and writes its
     * ciphertext. A partial write leaves the rest of the chunk pending for the next call.
     *
     * @param target channel receiving the ciphertext
     * @param position must equal {@link #transferred()}
     * @return number of ciphertext bytes written by this call (0 if the message was already fully
     *     transferred)
     * @throws IllegalArgumentException if {@code position} differs from {@link #transferred()}
     * @throws IOException if the cipher has been marked invalid or an I/O operation fails
     */
    @Override
    public long transferTo(WritableByteChannel target, long position) throws IOException {
      Preconditions.checkArgument(position == transferred(), "Invalid position.");

      if (transferred == count) {
        return 0;
      }

      long totalBytesWritten = 0L;
      do {
        // Only encrypt a new chunk once the previous chunk's ciphertext has been fully written.
        if (currentEncrypted == null) {
          encryptMore();
        }

        long remaining = currentEncrypted.remaining();
        if (remaining == 0) {
          // Defensive: encryptMore() produced no ciphertext, e.g. cos.flush() emitted nothing or
          // no plaintext was available. Drop the empty chunk and return. Continuing would retry
          // encryptMore() without progress and could loop forever.
          currentEncrypted = null;
          byteEncChannel.reset();
          return totalBytesWritten;
        }

        long bytesWritten = target.write(currentEncrypted);
        totalBytesWritten += bytesWritten;
        transferred += bytesWritten;
        if (bytesWritten < remaining) {
          // The target accepted only part of the chunk (e.g. its buffer is full). Keep the rest
          // in currentEncrypted for the next call.
          break;
        }
        // Chunk fully written: release the shared ciphertext buffer for the next chunk.
        currentEncrypted = null;
        byteEncChannel.reset();
      } while (transferred < count);

      return totalBytesWritten;
    }

    /**
     * Encrypts the next plaintext chunk of the wrapped message. It then points
     * {@code currentEncrypted} at the resulting ciphertext in {@code byteEncChannel}.
     *
     * @throws IOException if the owning handler's cipher has been marked invalid, or if reading
     *     the wrapped message or encrypting fails
     */
    private void encryptMore() throws IOException {
      if (!handler.isCipherValid()) {
        throw new IOException("Cipher is in invalid state.");
      }
      byteRawChannel.reset();

      // Copy the next plaintext chunk into byteRawChannel. Presumably the chunk is bounded by the
      // STREAM_BUFFER_SIZE the channel was created with.
      if (isByteBuf) {
        // Advance the reader index by the number of bytes actually copied.
        int copied = byteRawChannel.write(buf.nioBuffer());
        buf.skipBytes(copied);
      } else {
        // Presumably the region tracks its own progress, so transferred() is the next offset.
        region.transferTo(byteRawChannel, region.transferred());
      }

      // Flush so this chunk's ciphertext is pushed into byteEncChannel now. An InternalError from
      // the cipher is reported to the handler, which marks the cipher invalid, and then rethrown.
      try {
        cos.write(byteRawChannel.getData(), 0, byteRawChannel.length());
        cos.flush();
      } catch (InternalError ie) {
        handler.reportError();
        throw ie;
      }

      currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(),
        0, byteEncChannel.length());
    }

    /**
     * Resets both intermediate buffers and releases the wrapped message once.
     *
     * <p>Note that byteEncChannel is the owning handler's shared buffer.
     */
    @Override
    protected void deallocate() {
      byteRawChannel.reset();
      byteEncChannel.reset();
      if (region != null) {
        region.release();
      }
      if (buf != null) {
        buf.release();
      }
    }
  }

}