Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network.sasl;
0019 
0020 import java.io.IOException;
0021 import java.nio.ByteBuffer;
0022 import java.nio.channels.WritableByteChannel;
0023 import java.util.List;
0024 
0025 import com.google.common.annotations.VisibleForTesting;
0026 import com.google.common.base.Preconditions;
0027 import io.netty.buffer.ByteBuf;
0028 import io.netty.buffer.Unpooled;
0029 import io.netty.channel.Channel;
0030 import io.netty.channel.ChannelHandlerContext;
0031 import io.netty.channel.ChannelOutboundHandlerAdapter;
0032 import io.netty.channel.ChannelPromise;
0033 import io.netty.channel.FileRegion;
0034 import io.netty.handler.codec.MessageToMessageDecoder;
0035 
0036 import org.apache.spark.network.util.AbstractFileRegion;
0037 import org.apache.spark.network.util.ByteArrayWritableChannel;
0038 import org.apache.spark.network.util.NettyUtils;
0039 
0040 /**
0041  * Provides SASL-based encryption for transport channels. The single method exposed by this
0042  * class installs the needed channel handlers on a connected channel.
0043  */
0044 class SaslEncryption {
0045 
0046   @VisibleForTesting
0047   static final String ENCRYPTION_HANDLER_NAME = "saslEncryption";
0048 
0049   /**
0050    * Adds channel handlers that perform encryption / decryption of data using SASL.
0051    *
0052    * @param channel The channel.
0053    * @param backend The SASL backend.
0054    * @param maxOutboundBlockSize Max size in bytes of outgoing encrypted blocks, to control
0055    *                             memory usage.
0056    */
0057   static void addToChannel(
0058       Channel channel,
0059       SaslEncryptionBackend backend,
0060       int maxOutboundBlockSize) {
0061     channel.pipeline()
0062       .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize))
0063       .addFirst("saslDecryption", new DecryptionHandler(backend))
0064       .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder());
0065   }
0066 
0067   private static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
0068 
0069     private final int maxOutboundBlockSize;
0070     private final SaslEncryptionBackend backend;
0071 
0072     EncryptionHandler(SaslEncryptionBackend backend, int maxOutboundBlockSize) {
0073       this.backend = backend;
0074       this.maxOutboundBlockSize = maxOutboundBlockSize;
0075     }
0076 
0077     /**
0078      * Wrap the incoming message in an implementation that will perform encryption lazily. This is
0079      * needed to guarantee ordering of the outgoing encrypted packets - they need to be decrypted in
0080      * the same order, and netty doesn't have an atomic ChannelHandlerContext.write() API, so it
0081      * does not guarantee any ordering.
0082      */
0083     @Override
0084     public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
0085       throws Exception {
0086 
0087       ctx.write(new EncryptedMessage(backend, msg, maxOutboundBlockSize), promise);
0088     }
0089 
0090     @Override
0091     public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
0092       try {
0093         backend.dispose();
0094       } finally {
0095         super.handlerRemoved(ctx);
0096       }
0097     }
0098 
0099   }
0100 
0101   private static class DecryptionHandler extends MessageToMessageDecoder<ByteBuf> {
0102 
0103     private final SaslEncryptionBackend backend;
0104 
0105     DecryptionHandler(SaslEncryptionBackend backend) {
0106       this.backend = backend;
0107     }
0108 
0109     @Override
0110     protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out)
0111       throws Exception {
0112 
0113       byte[] data;
0114       int offset;
0115       int length = msg.readableBytes();
0116       if (msg.hasArray()) {
0117         data = msg.array();
0118         offset = msg.arrayOffset();
0119         msg.skipBytes(length);
0120       } else {
0121         data = new byte[length];
0122         msg.readBytes(data);
0123         offset = 0;
0124       }
0125 
0126       out.add(Unpooled.wrappedBuffer(backend.unwrap(data, offset, length)));
0127     }
0128 
0129   }
0130 
0131   @VisibleForTesting
0132   static class EncryptedMessage extends AbstractFileRegion {
0133 
0134     private final SaslEncryptionBackend backend;
0135     private final boolean isByteBuf;
0136     private final ByteBuf buf;
0137     private final FileRegion region;
0138     private final int maxOutboundBlockSize;
0139 
0140     /**
0141      * A channel used to buffer input data for encryption. The channel has an upper size bound
0142      * so that if the input is larger than the allowed buffer, it will be broken into multiple
0143      * chunks. Made non-final to enable lazy initialization, which saves memory.
0144      */
0145     private ByteArrayWritableChannel byteChannel;
0146 
0147     private ByteBuf currentHeader;
0148     private ByteBuffer currentChunk;
0149     private long currentChunkSize;
0150     private long currentReportedBytes;
0151     private long unencryptedChunkSize;
0152     private long transferred;
0153 
0154     EncryptedMessage(SaslEncryptionBackend backend, Object msg, int maxOutboundBlockSize) {
0155       Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion,
0156         "Unrecognized message type: %s", msg.getClass().getName());
0157       this.backend = backend;
0158       this.isByteBuf = msg instanceof ByteBuf;
0159       this.buf = isByteBuf ? (ByteBuf) msg : null;
0160       this.region = isByteBuf ? null : (FileRegion) msg;
0161       this.maxOutboundBlockSize = maxOutboundBlockSize;
0162     }
0163 
0164     /**
0165      * Returns the size of the original (unencrypted) message.
0166      *
0167      * This makes assumptions about how netty treats FileRegion instances, because there's no way
0168      * to know beforehand what will be the size of the encrypted message. Namely, it assumes
0169      * that netty will try to transfer data from this message while
0170      * <code>transferred() < count()</code>. So these two methods return, technically, wrong data,
0171      * but netty doesn't know better.
0172      */
0173     @Override
0174     public long count() {
0175       return isByteBuf ? buf.readableBytes() : region.count();
0176     }
0177 
0178     @Override
0179     public long position() {
0180       return 0;
0181     }
0182 
0183     /**
0184      * Returns an approximation of the amount of data transferred. See {@link #count()}.
0185      */
0186     @Override
0187     public long transferred() {
0188       return transferred;
0189     }
0190 
0191     @Override
0192     public EncryptedMessage touch(Object o) {
0193       super.touch(o);
0194       if (buf != null) {
0195         buf.touch(o);
0196       }
0197       if (region != null) {
0198         region.touch(o);
0199       }
0200       return this;
0201     }
0202 
0203     @Override
0204     public EncryptedMessage retain(int increment) {
0205       super.retain(increment);
0206       if (buf != null) {
0207         buf.retain(increment);
0208       }
0209       if (region != null) {
0210         region.retain(increment);
0211       }
0212       return this;
0213     }
0214 
0215     @Override
0216     public boolean release(int decrement) {
0217       if (region != null) {
0218         region.release(decrement);
0219       }
0220       if (buf != null) {
0221         buf.release(decrement);
0222       }
0223       return super.release(decrement);
0224     }
0225 
0226     /**
0227      * Transfers data from the original message to the channel, encrypting it in the process.
0228      *
0229      * This method also breaks down the original message into smaller chunks when needed. This
0230      * is done to keep memory usage under control. This avoids having to copy the whole message
0231      * data into memory at once, and can avoid ballooning memory usage when transferring large
0232      * messages such as shuffle blocks.
0233      *
0234      * The {@link #transferred()} counter also behaves a little funny, in that it won't go forward
0235      * until a whole chunk has been written. This is done because the code can't use the actual
0236      * number of bytes written to the channel as the transferred count (see {@link #count()}).
0237      * Instead, once an encrypted chunk is written to the output (including its header), the
0238      * size of the original block will be added to the {@link #transferred()} amount.
0239      */
0240     @Override
0241     public long transferTo(final WritableByteChannel target, final long position)
0242       throws IOException {
0243 
0244       Preconditions.checkArgument(position == transferred(), "Invalid position.");
0245 
0246       long reportedWritten = 0L;
0247       long actuallyWritten = 0L;
0248       do {
0249         if (currentChunk == null) {
0250           nextChunk();
0251         }
0252 
0253         if (currentHeader.readableBytes() > 0) {
0254           int bytesWritten = target.write(currentHeader.nioBuffer());
0255           currentHeader.skipBytes(bytesWritten);
0256           actuallyWritten += bytesWritten;
0257           if (currentHeader.readableBytes() > 0) {
0258             // Break out of loop if there are still header bytes left to write.
0259             break;
0260           }
0261         }
0262 
0263         actuallyWritten += target.write(currentChunk);
0264         if (!currentChunk.hasRemaining()) {
0265           // Only update the count of written bytes once a full chunk has been written.
0266           // See method javadoc.
0267           long chunkBytesRemaining = unencryptedChunkSize - currentReportedBytes;
0268           reportedWritten += chunkBytesRemaining;
0269           transferred += chunkBytesRemaining;
0270           currentHeader.release();
0271           currentHeader = null;
0272           currentChunk = null;
0273           currentChunkSize = 0;
0274           currentReportedBytes = 0;
0275         }
0276       } while (currentChunk == null && transferred() + reportedWritten < count());
0277 
0278       // Returning 0 triggers a backoff mechanism in netty which may harm performance. Instead,
0279       // we return 1 until we can (i.e. until the reported count would actually match the size
0280       // of the current chunk), at which point we resort to returning 0 so that the counts still
0281       // match, at the cost of some performance. That situation should be rare, though.
0282       if (reportedWritten != 0L) {
0283         return reportedWritten;
0284       }
0285 
0286       if (actuallyWritten > 0 && currentReportedBytes < currentChunkSize - 1) {
0287         transferred += 1L;
0288         currentReportedBytes += 1L;
0289         return 1L;
0290       }
0291 
0292       return 0L;
0293     }
0294 
0295     private void nextChunk() throws IOException {
0296       if (byteChannel == null) {
0297         byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize);
0298       }
0299       byteChannel.reset();
0300       if (isByteBuf) {
0301         int copied = byteChannel.write(buf.nioBuffer());
0302         buf.skipBytes(copied);
0303       } else {
0304         region.transferTo(byteChannel, region.transferred());
0305       }
0306 
0307       byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length());
0308       this.currentChunk = ByteBuffer.wrap(encrypted);
0309       this.currentChunkSize = encrypted.length;
0310       this.currentHeader = Unpooled.copyLong(8 + currentChunkSize);
0311       this.unencryptedChunkSize = byteChannel.length();
0312     }
0313 
0314     @Override
0315     protected void deallocate() {
0316       if (currentHeader != null) {
0317         currentHeader.release();
0318       }
0319       if (buf != null) {
0320         buf.release();
0321       }
0322       if (region != null) {
0323         region.release();
0324       }
0325     }
0326 
0327   }
0328 
0329 }