0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.sasl;
0019
0020 import java.io.IOException;
0021 import java.nio.ByteBuffer;
0022 import java.nio.channels.WritableByteChannel;
0023 import java.util.List;
0024
0025 import com.google.common.annotations.VisibleForTesting;
0026 import com.google.common.base.Preconditions;
0027 import io.netty.buffer.ByteBuf;
0028 import io.netty.buffer.Unpooled;
0029 import io.netty.channel.Channel;
0030 import io.netty.channel.ChannelHandlerContext;
0031 import io.netty.channel.ChannelOutboundHandlerAdapter;
0032 import io.netty.channel.ChannelPromise;
0033 import io.netty.channel.FileRegion;
0034 import io.netty.handler.codec.MessageToMessageDecoder;
0035
0036 import org.apache.spark.network.util.AbstractFileRegion;
0037 import org.apache.spark.network.util.ByteArrayWritableChannel;
0038 import org.apache.spark.network.util.NettyUtils;
0039
0040
0041
0042
0043
0044 class SaslEncryption {
0045
0046 @VisibleForTesting
0047 static final String ENCRYPTION_HANDLER_NAME = "saslEncryption";
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057 static void addToChannel(
0058 Channel channel,
0059 SaslEncryptionBackend backend,
0060 int maxOutboundBlockSize) {
0061 channel.pipeline()
0062 .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize))
0063 .addFirst("saslDecryption", new DecryptionHandler(backend))
0064 .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder());
0065 }
0066
0067 private static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
0068
0069 private final int maxOutboundBlockSize;
0070 private final SaslEncryptionBackend backend;
0071
0072 EncryptionHandler(SaslEncryptionBackend backend, int maxOutboundBlockSize) {
0073 this.backend = backend;
0074 this.maxOutboundBlockSize = maxOutboundBlockSize;
0075 }
0076
0077
0078
0079
0080
0081
0082
0083 @Override
0084 public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
0085 throws Exception {
0086
0087 ctx.write(new EncryptedMessage(backend, msg, maxOutboundBlockSize), promise);
0088 }
0089
0090 @Override
0091 public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
0092 try {
0093 backend.dispose();
0094 } finally {
0095 super.handlerRemoved(ctx);
0096 }
0097 }
0098
0099 }
0100
0101 private static class DecryptionHandler extends MessageToMessageDecoder<ByteBuf> {
0102
0103 private final SaslEncryptionBackend backend;
0104
0105 DecryptionHandler(SaslEncryptionBackend backend) {
0106 this.backend = backend;
0107 }
0108
0109 @Override
0110 protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out)
0111 throws Exception {
0112
0113 byte[] data;
0114 int offset;
0115 int length = msg.readableBytes();
0116 if (msg.hasArray()) {
0117 data = msg.array();
0118 offset = msg.arrayOffset();
0119 msg.skipBytes(length);
0120 } else {
0121 data = new byte[length];
0122 msg.readBytes(data);
0123 offset = 0;
0124 }
0125
0126 out.add(Unpooled.wrappedBuffer(backend.unwrap(data, offset, length)));
0127 }
0128
0129 }
0130
0131 @VisibleForTesting
0132 static class EncryptedMessage extends AbstractFileRegion {
0133
0134 private final SaslEncryptionBackend backend;
0135 private final boolean isByteBuf;
0136 private final ByteBuf buf;
0137 private final FileRegion region;
0138 private final int maxOutboundBlockSize;
0139
0140
0141
0142
0143
0144
0145 private ByteArrayWritableChannel byteChannel;
0146
0147 private ByteBuf currentHeader;
0148 private ByteBuffer currentChunk;
0149 private long currentChunkSize;
0150 private long currentReportedBytes;
0151 private long unencryptedChunkSize;
0152 private long transferred;
0153
0154 EncryptedMessage(SaslEncryptionBackend backend, Object msg, int maxOutboundBlockSize) {
0155 Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion,
0156 "Unrecognized message type: %s", msg.getClass().getName());
0157 this.backend = backend;
0158 this.isByteBuf = msg instanceof ByteBuf;
0159 this.buf = isByteBuf ? (ByteBuf) msg : null;
0160 this.region = isByteBuf ? null : (FileRegion) msg;
0161 this.maxOutboundBlockSize = maxOutboundBlockSize;
0162 }
0163
0164
0165
0166
0167
0168
0169
0170
0171
0172
0173 @Override
0174 public long count() {
0175 return isByteBuf ? buf.readableBytes() : region.count();
0176 }
0177
0178 @Override
0179 public long position() {
0180 return 0;
0181 }
0182
0183
0184
0185
0186 @Override
0187 public long transferred() {
0188 return transferred;
0189 }
0190
0191 @Override
0192 public EncryptedMessage touch(Object o) {
0193 super.touch(o);
0194 if (buf != null) {
0195 buf.touch(o);
0196 }
0197 if (region != null) {
0198 region.touch(o);
0199 }
0200 return this;
0201 }
0202
0203 @Override
0204 public EncryptedMessage retain(int increment) {
0205 super.retain(increment);
0206 if (buf != null) {
0207 buf.retain(increment);
0208 }
0209 if (region != null) {
0210 region.retain(increment);
0211 }
0212 return this;
0213 }
0214
0215 @Override
0216 public boolean release(int decrement) {
0217 if (region != null) {
0218 region.release(decrement);
0219 }
0220 if (buf != null) {
0221 buf.release(decrement);
0222 }
0223 return super.release(decrement);
0224 }
0225
0226
0227
0228
0229
0230
0231
0232
0233
0234
0235
0236
0237
0238
0239
0240 @Override
0241 public long transferTo(final WritableByteChannel target, final long position)
0242 throws IOException {
0243
0244 Preconditions.checkArgument(position == transferred(), "Invalid position.");
0245
0246 long reportedWritten = 0L;
0247 long actuallyWritten = 0L;
0248 do {
0249 if (currentChunk == null) {
0250 nextChunk();
0251 }
0252
0253 if (currentHeader.readableBytes() > 0) {
0254 int bytesWritten = target.write(currentHeader.nioBuffer());
0255 currentHeader.skipBytes(bytesWritten);
0256 actuallyWritten += bytesWritten;
0257 if (currentHeader.readableBytes() > 0) {
0258
0259 break;
0260 }
0261 }
0262
0263 actuallyWritten += target.write(currentChunk);
0264 if (!currentChunk.hasRemaining()) {
0265
0266
0267 long chunkBytesRemaining = unencryptedChunkSize - currentReportedBytes;
0268 reportedWritten += chunkBytesRemaining;
0269 transferred += chunkBytesRemaining;
0270 currentHeader.release();
0271 currentHeader = null;
0272 currentChunk = null;
0273 currentChunkSize = 0;
0274 currentReportedBytes = 0;
0275 }
0276 } while (currentChunk == null && transferred() + reportedWritten < count());
0277
0278
0279
0280
0281
0282 if (reportedWritten != 0L) {
0283 return reportedWritten;
0284 }
0285
0286 if (actuallyWritten > 0 && currentReportedBytes < currentChunkSize - 1) {
0287 transferred += 1L;
0288 currentReportedBytes += 1L;
0289 return 1L;
0290 }
0291
0292 return 0L;
0293 }
0294
0295 private void nextChunk() throws IOException {
0296 if (byteChannel == null) {
0297 byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize);
0298 }
0299 byteChannel.reset();
0300 if (isByteBuf) {
0301 int copied = byteChannel.write(buf.nioBuffer());
0302 buf.skipBytes(copied);
0303 } else {
0304 region.transferTo(byteChannel, region.transferred());
0305 }
0306
0307 byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length());
0308 this.currentChunk = ByteBuffer.wrap(encrypted);
0309 this.currentChunkSize = encrypted.length;
0310 this.currentHeader = Unpooled.copyLong(8 + currentChunkSize);
0311 this.unencryptedChunkSize = byteChannel.length();
0312 }
0313
0314 @Override
0315 protected void deallocate() {
0316 if (currentHeader != null) {
0317 currentHeader.release();
0318 }
0319 if (buf != null) {
0320 buf.release();
0321 }
0322 if (region != null) {
0323 region.release();
0324 }
0325 }
0326
0327 }
0328
0329 }