Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 package org.apache.spark.network.crypto;
0018 
0019 import javax.crypto.spec.SecretKeySpec;
0020 import java.io.IOException;
0021 import java.nio.channels.ReadableByteChannel;
0022 import java.nio.channels.WritableByteChannel;
0023 
0024 import io.netty.buffer.ByteBuf;
0025 import io.netty.buffer.Unpooled;
0026 import io.netty.channel.embedded.EmbeddedChannel;
0027 import org.apache.commons.crypto.stream.CryptoInputStream;
0028 import org.apache.commons.crypto.stream.CryptoOutputStream;
0029 import org.apache.spark.network.util.MapConfigProvider;
0030 import org.apache.spark.network.util.TransportConf;
0031 import org.hamcrest.CoreMatchers;
0032 import org.junit.Test;
0033 
0034 import static org.junit.Assert.assertEquals;
0035 import static org.junit.Assert.assertFalse;
0036 import static org.junit.Assert.assertThat;
0037 import static org.junit.Assert.fail;
0038 import static org.mockito.ArgumentMatchers.any;
0039 import static org.mockito.ArgumentMatchers.anyInt;
0040 import static org.mockito.Mockito.mock;
0041 import static org.mockito.Mockito.when;
0042 
0043 public class TransportCipherSuite {
0044 
0045   @Test
0046   public void testBufferNotLeaksOnInternalError() throws IOException {
0047     String algorithm = "TestAlgorithm";
0048     TransportConf conf = new TransportConf("Test", MapConfigProvider.EMPTY);
0049     TransportCipher cipher = new TransportCipher(conf.cryptoConf(), conf.cipherTransformation(),
0050       new SecretKeySpec(new byte[256], algorithm), new byte[0], new byte[0]) {
0051 
0052       @Override
0053       CryptoOutputStream createOutputStream(WritableByteChannel ch) {
0054         return null;
0055       }
0056 
0057       @Override
0058       CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException {
0059         CryptoInputStream mockInputStream = mock(CryptoInputStream.class);
0060         when(mockInputStream.read(any(byte[].class), anyInt(), anyInt()))
0061           .thenThrow(new InternalError());
0062         return mockInputStream;
0063       }
0064     };
0065 
0066     EmbeddedChannel channel = new EmbeddedChannel();
0067     cipher.addToChannel(channel);
0068 
0069     ByteBuf buffer = Unpooled.wrappedBuffer(new byte[] { 1, 2 });
0070     ByteBuf buffer2 = Unpooled.wrappedBuffer(new byte[] { 1, 2 });
0071 
0072     try {
0073       channel.writeInbound(buffer);
0074       fail("Should have raised InternalError");
0075     } catch (InternalError expected) {
0076       // expected
0077       assertEquals(0, buffer.refCnt());
0078     }
0079 
0080     try {
0081       channel.writeInbound(buffer2);
0082       fail("Should have raised an exception");
0083     } catch (Throwable expected) {
0084       assertThat(expected, CoreMatchers.instanceOf(IOException.class));
0085       assertEquals(0, buffer2.refCnt());
0086     }
0087 
0088     // Simulate closing the connection
0089     assertFalse(channel.finish());
0090   }
0091 }