0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network;
0019
0020 import java.io.ByteArrayOutputStream;
0021 import java.io.File;
0022 import java.io.FileOutputStream;
0023 import java.io.IOException;
0024 import java.io.OutputStream;
0025 import java.nio.ByteBuffer;
0026 import java.util.ArrayList;
0027 import java.util.Arrays;
0028 import java.util.List;
0029 import java.util.concurrent.Executors;
0030 import java.util.concurrent.ExecutorService;
0031 import java.util.concurrent.TimeUnit;
0032
0033 import com.google.common.io.Files;
0034 import org.junit.AfterClass;
0035 import org.junit.BeforeClass;
0036 import org.junit.Test;
0037 import static org.junit.Assert.*;
0038
0039 import org.apache.spark.network.buffer.ManagedBuffer;
0040 import org.apache.spark.network.client.RpcResponseCallback;
0041 import org.apache.spark.network.client.StreamCallback;
0042 import org.apache.spark.network.client.TransportClient;
0043 import org.apache.spark.network.client.TransportClientFactory;
0044 import org.apache.spark.network.server.RpcHandler;
0045 import org.apache.spark.network.server.StreamManager;
0046 import org.apache.spark.network.server.TransportServer;
0047 import org.apache.spark.network.util.MapConfigProvider;
0048 import org.apache.spark.network.util.TransportConf;
0049
0050 public class StreamSuite {
0051 private static final String[] STREAMS = StreamTestHelper.STREAMS;
0052 private static StreamTestHelper testData;
0053
0054 private static TransportContext context;
0055 private static TransportServer server;
0056 private static TransportClientFactory clientFactory;
0057
0058 private static ByteBuffer createBuffer(int bufSize) {
0059 ByteBuffer buf = ByteBuffer.allocate(bufSize);
0060 for (int i = 0; i < bufSize; i ++) {
0061 buf.put((byte) i);
0062 }
0063 buf.flip();
0064 return buf;
0065 }
0066
0067 @BeforeClass
0068 public static void setUp() throws Exception {
0069 testData = new StreamTestHelper();
0070
0071 final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
0072 final StreamManager streamManager = new StreamManager() {
0073 @Override
0074 public ManagedBuffer getChunk(long streamId, int chunkIndex) {
0075 throw new UnsupportedOperationException();
0076 }
0077
0078 @Override
0079 public ManagedBuffer openStream(String streamId) {
0080 return testData.openStream(conf, streamId);
0081 }
0082 };
0083 RpcHandler handler = new RpcHandler() {
0084 @Override
0085 public void receive(
0086 TransportClient client,
0087 ByteBuffer message,
0088 RpcResponseCallback callback) {
0089 throw new UnsupportedOperationException();
0090 }
0091
0092 @Override
0093 public StreamManager getStreamManager() {
0094 return streamManager;
0095 }
0096 };
0097 context = new TransportContext(conf, handler);
0098 server = context.createServer();
0099 clientFactory = context.createClientFactory();
0100 }
0101
0102 @AfterClass
0103 public static void tearDown() {
0104 server.close();
0105 clientFactory.close();
0106 testData.cleanup();
0107 context.close();
0108 }
0109
0110 @Test
0111 public void testZeroLengthStream() throws Throwable {
0112 TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0113 try {
0114 StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5));
0115 task.run();
0116 task.check();
0117 } finally {
0118 client.close();
0119 }
0120 }
0121
0122 @Test
0123 public void testSingleStream() throws Throwable {
0124 TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0125 try {
0126 StreamTask task = new StreamTask(client, "largeBuffer", TimeUnit.SECONDS.toMillis(5));
0127 task.run();
0128 task.check();
0129 } finally {
0130 client.close();
0131 }
0132 }
0133
0134 @Test
0135 public void testMultipleStreams() throws Throwable {
0136 TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0137 try {
0138 for (int i = 0; i < 20; i++) {
0139 StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
0140 TimeUnit.SECONDS.toMillis(5));
0141 task.run();
0142 task.check();
0143 }
0144 } finally {
0145 client.close();
0146 }
0147 }
0148
0149 @Test
0150 public void testConcurrentStreams() throws Throwable {
0151 ExecutorService executor = Executors.newFixedThreadPool(20);
0152 TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0153
0154 try {
0155 List<StreamTask> tasks = new ArrayList<>();
0156 for (int i = 0; i < 20; i++) {
0157 StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
0158 TimeUnit.SECONDS.toMillis(20));
0159 tasks.add(task);
0160 executor.submit(task);
0161 }
0162
0163 executor.shutdown();
0164 assertTrue("Timed out waiting for tasks.", executor.awaitTermination(30, TimeUnit.SECONDS));
0165 for (StreamTask task : tasks) {
0166 task.check();
0167 }
0168 } finally {
0169 executor.shutdownNow();
0170 client.close();
0171 }
0172 }
0173
0174 private static class StreamTask implements Runnable {
0175
0176 private final TransportClient client;
0177 private final String streamId;
0178 private final long timeoutMs;
0179 private Throwable error;
0180
0181 StreamTask(TransportClient client, String streamId, long timeoutMs) {
0182 this.client = client;
0183 this.streamId = streamId;
0184 this.timeoutMs = timeoutMs;
0185 }
0186
0187 @Override
0188 public void run() {
0189 ByteBuffer srcBuffer = null;
0190 OutputStream out = null;
0191 File outFile = null;
0192 try {
0193 ByteArrayOutputStream baos = null;
0194
0195 switch (streamId) {
0196 case "largeBuffer":
0197 baos = new ByteArrayOutputStream();
0198 out = baos;
0199 srcBuffer = testData.largeBuffer;
0200 break;
0201 case "smallBuffer":
0202 baos = new ByteArrayOutputStream();
0203 out = baos;
0204 srcBuffer = testData.smallBuffer;
0205 break;
0206 case "file":
0207 outFile = File.createTempFile("data", ".tmp", testData.tempDir);
0208 out = new FileOutputStream(outFile);
0209 break;
0210 case "emptyBuffer":
0211 baos = new ByteArrayOutputStream();
0212 out = baos;
0213 srcBuffer = testData.emptyBuffer;
0214 break;
0215 default:
0216 throw new IllegalArgumentException(streamId);
0217 }
0218
0219 TestCallback callback = new TestCallback(out);
0220 client.stream(streamId, callback);
0221 callback.waitForCompletion(timeoutMs);
0222
0223 if (srcBuffer == null) {
0224 assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile));
0225 } else {
0226 ByteBuffer base;
0227 synchronized (srcBuffer) {
0228 base = srcBuffer.duplicate();
0229 }
0230 byte[] result = baos.toByteArray();
0231 byte[] expected = new byte[base.remaining()];
0232 base.get(expected);
0233 assertEquals(expected.length, result.length);
0234 assertTrue("buffers don't match", Arrays.equals(expected, result));
0235 }
0236 } catch (Throwable t) {
0237 error = t;
0238 } finally {
0239 if (out != null) {
0240 try {
0241 out.close();
0242 } catch (Exception e) {
0243
0244 }
0245 }
0246 if (outFile != null) {
0247 outFile.delete();
0248 }
0249 }
0250 }
0251
0252 public void check() throws Throwable {
0253 if (error != null) {
0254 throw error;
0255 }
0256 }
0257 }
0258
0259 static class TestCallback implements StreamCallback {
0260
0261 private final OutputStream out;
0262 public volatile boolean completed;
0263 public volatile Throwable error;
0264
0265 TestCallback(OutputStream out) {
0266 this.out = out;
0267 this.completed = false;
0268 }
0269
0270 @Override
0271 public void onData(String streamId, ByteBuffer buf) throws IOException {
0272 byte[] tmp = new byte[buf.remaining()];
0273 buf.get(tmp);
0274 out.write(tmp);
0275 }
0276
0277 @Override
0278 public void onComplete(String streamId) throws IOException {
0279 out.close();
0280 synchronized (this) {
0281 completed = true;
0282 notifyAll();
0283 }
0284 }
0285
0286 @Override
0287 public void onFailure(String streamId, Throwable cause) {
0288 error = cause;
0289 synchronized (this) {
0290 completed = true;
0291 notifyAll();
0292 }
0293 }
0294
0295 void waitForCompletion(long timeoutMs) {
0296 long now = System.currentTimeMillis();
0297 long deadline = now + timeoutMs;
0298 synchronized (this) {
0299 while (!completed && now < deadline) {
0300 try {
0301 wait(deadline - now);
0302 } catch (InterruptedException ie) {
0303 throw new RuntimeException(ie);
0304 }
0305 now = System.currentTimeMillis();
0306 }
0307 }
0308 assertTrue("Timed out waiting for stream.", completed);
0309 assertNull(error);
0310 }
0311 }
0312
0313 }