0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 package org.apache.spark.network.crypto;
0018
0019 import javax.crypto.spec.SecretKeySpec;
0020 import java.io.IOException;
0021 import java.nio.channels.ReadableByteChannel;
0022 import java.nio.channels.WritableByteChannel;
0023
0024 import io.netty.buffer.ByteBuf;
0025 import io.netty.buffer.Unpooled;
0026 import io.netty.channel.embedded.EmbeddedChannel;
0027 import org.apache.commons.crypto.stream.CryptoInputStream;
0028 import org.apache.commons.crypto.stream.CryptoOutputStream;
0029 import org.apache.spark.network.util.MapConfigProvider;
0030 import org.apache.spark.network.util.TransportConf;
0031 import org.hamcrest.CoreMatchers;
0032 import org.junit.Test;
0033
0034 import static org.junit.Assert.assertEquals;
0035 import static org.junit.Assert.assertFalse;
0036 import static org.junit.Assert.assertThat;
0037 import static org.junit.Assert.fail;
0038 import static org.mockito.ArgumentMatchers.any;
0039 import static org.mockito.ArgumentMatchers.anyInt;
0040 import static org.mockito.Mockito.mock;
0041 import static org.mockito.Mockito.when;
0042
0043 public class TransportCipherSuite {
0044
0045 @Test
0046 public void testBufferNotLeaksOnInternalError() throws IOException {
0047 String algorithm = "TestAlgorithm";
0048 TransportConf conf = new TransportConf("Test", MapConfigProvider.EMPTY);
0049 TransportCipher cipher = new TransportCipher(conf.cryptoConf(), conf.cipherTransformation(),
0050 new SecretKeySpec(new byte[256], algorithm), new byte[0], new byte[0]) {
0051
0052 @Override
0053 CryptoOutputStream createOutputStream(WritableByteChannel ch) {
0054 return null;
0055 }
0056
0057 @Override
0058 CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException {
0059 CryptoInputStream mockInputStream = mock(CryptoInputStream.class);
0060 when(mockInputStream.read(any(byte[].class), anyInt(), anyInt()))
0061 .thenThrow(new InternalError());
0062 return mockInputStream;
0063 }
0064 };
0065
0066 EmbeddedChannel channel = new EmbeddedChannel();
0067 cipher.addToChannel(channel);
0068
0069 ByteBuf buffer = Unpooled.wrappedBuffer(new byte[] { 1, 2 });
0070 ByteBuf buffer2 = Unpooled.wrappedBuffer(new byte[] { 1, 2 });
0071
0072 try {
0073 channel.writeInbound(buffer);
0074 fail("Should have raised InternalError");
0075 } catch (InternalError expected) {
0076
0077 assertEquals(0, buffer.refCnt());
0078 }
0079
0080 try {
0081 channel.writeInbound(buffer2);
0082 fail("Should have raised an exception");
0083 } catch (Throwable expected) {
0084 assertThat(expected, CoreMatchers.instanceOf(IOException.class));
0085 assertEquals(0, buffer2.refCnt());
0086 }
0087
0088
0089 assertFalse(channel.finish());
0090 }
0091 }