Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network;
0019 
0020 import java.io.ByteArrayOutputStream;
0021 import java.io.File;
0022 import java.io.FileOutputStream;
0023 import java.io.IOException;
0024 import java.io.OutputStream;
0025 import java.nio.ByteBuffer;
0026 import java.util.ArrayList;
0027 import java.util.Arrays;
0028 import java.util.List;
0029 import java.util.concurrent.Executors;
0030 import java.util.concurrent.ExecutorService;
0031 import java.util.concurrent.TimeUnit;
0032 
0033 import com.google.common.io.Files;
0034 import org.junit.AfterClass;
0035 import org.junit.BeforeClass;
0036 import org.junit.Test;
0037 import static org.junit.Assert.*;
0038 
0039 import org.apache.spark.network.buffer.ManagedBuffer;
0040 import org.apache.spark.network.client.RpcResponseCallback;
0041 import org.apache.spark.network.client.StreamCallback;
0042 import org.apache.spark.network.client.TransportClient;
0043 import org.apache.spark.network.client.TransportClientFactory;
0044 import org.apache.spark.network.server.RpcHandler;
0045 import org.apache.spark.network.server.StreamManager;
0046 import org.apache.spark.network.server.TransportServer;
0047 import org.apache.spark.network.util.MapConfigProvider;
0048 import org.apache.spark.network.util.TransportConf;
0049 
0050 public class StreamSuite {
0051   private static final String[] STREAMS = StreamTestHelper.STREAMS;
0052   private static StreamTestHelper testData;
0053 
0054   private static TransportContext context;
0055   private static TransportServer server;
0056   private static TransportClientFactory clientFactory;
0057 
0058   private static ByteBuffer createBuffer(int bufSize) {
0059     ByteBuffer buf = ByteBuffer.allocate(bufSize);
0060     for (int i = 0; i < bufSize; i ++) {
0061       buf.put((byte) i);
0062     }
0063     buf.flip();
0064     return buf;
0065   }
0066 
0067   @BeforeClass
0068   public static void setUp() throws Exception {
0069     testData = new StreamTestHelper();
0070 
0071     final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
0072     final StreamManager streamManager = new StreamManager() {
0073       @Override
0074       public ManagedBuffer getChunk(long streamId, int chunkIndex) {
0075         throw new UnsupportedOperationException();
0076       }
0077 
0078       @Override
0079       public ManagedBuffer openStream(String streamId) {
0080         return testData.openStream(conf, streamId);
0081       }
0082     };
0083     RpcHandler handler = new RpcHandler() {
0084       @Override
0085       public void receive(
0086           TransportClient client,
0087           ByteBuffer message,
0088           RpcResponseCallback callback) {
0089         throw new UnsupportedOperationException();
0090       }
0091 
0092       @Override
0093       public StreamManager getStreamManager() {
0094         return streamManager;
0095       }
0096     };
0097     context = new TransportContext(conf, handler);
0098     server = context.createServer();
0099     clientFactory = context.createClientFactory();
0100   }
0101 
0102   @AfterClass
0103   public static void tearDown() {
0104     server.close();
0105     clientFactory.close();
0106     testData.cleanup();
0107     context.close();
0108   }
0109 
0110   @Test
0111   public void testZeroLengthStream() throws Throwable {
0112     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0113     try {
0114       StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5));
0115       task.run();
0116       task.check();
0117     } finally {
0118       client.close();
0119     }
0120   }
0121 
0122   @Test
0123   public void testSingleStream() throws Throwable {
0124     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0125     try {
0126       StreamTask task = new StreamTask(client, "largeBuffer", TimeUnit.SECONDS.toMillis(5));
0127       task.run();
0128       task.check();
0129     } finally {
0130       client.close();
0131     }
0132   }
0133 
0134   @Test
0135   public void testMultipleStreams() throws Throwable {
0136     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0137     try {
0138       for (int i = 0; i < 20; i++) {
0139         StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
0140           TimeUnit.SECONDS.toMillis(5));
0141         task.run();
0142         task.check();
0143       }
0144     } finally {
0145       client.close();
0146     }
0147   }
0148 
0149   @Test
0150   public void testConcurrentStreams() throws Throwable {
0151     ExecutorService executor = Executors.newFixedThreadPool(20);
0152     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0153 
0154     try {
0155       List<StreamTask> tasks = new ArrayList<>();
0156       for (int i = 0; i < 20; i++) {
0157         StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
0158           TimeUnit.SECONDS.toMillis(20));
0159         tasks.add(task);
0160         executor.submit(task);
0161       }
0162 
0163       executor.shutdown();
0164       assertTrue("Timed out waiting for tasks.", executor.awaitTermination(30, TimeUnit.SECONDS));
0165       for (StreamTask task : tasks) {
0166         task.check();
0167       }
0168     } finally {
0169       executor.shutdownNow();
0170       client.close();
0171     }
0172   }
0173 
0174   private static class StreamTask implements Runnable {
0175 
0176     private final TransportClient client;
0177     private final String streamId;
0178     private final long timeoutMs;
0179     private Throwable error;
0180 
0181     StreamTask(TransportClient client, String streamId, long timeoutMs) {
0182       this.client = client;
0183       this.streamId = streamId;
0184       this.timeoutMs = timeoutMs;
0185     }
0186 
0187     @Override
0188     public void run() {
0189       ByteBuffer srcBuffer = null;
0190       OutputStream out = null;
0191       File outFile = null;
0192       try {
0193         ByteArrayOutputStream baos = null;
0194 
0195         switch (streamId) {
0196           case "largeBuffer":
0197             baos = new ByteArrayOutputStream();
0198             out = baos;
0199             srcBuffer = testData.largeBuffer;
0200             break;
0201           case "smallBuffer":
0202             baos = new ByteArrayOutputStream();
0203             out = baos;
0204             srcBuffer = testData.smallBuffer;
0205             break;
0206           case "file":
0207             outFile = File.createTempFile("data", ".tmp", testData.tempDir);
0208             out = new FileOutputStream(outFile);
0209             break;
0210           case "emptyBuffer":
0211             baos = new ByteArrayOutputStream();
0212             out = baos;
0213             srcBuffer = testData.emptyBuffer;
0214             break;
0215           default:
0216             throw new IllegalArgumentException(streamId);
0217         }
0218 
0219         TestCallback callback = new TestCallback(out);
0220         client.stream(streamId, callback);
0221         callback.waitForCompletion(timeoutMs);
0222 
0223         if (srcBuffer == null) {
0224           assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile));
0225         } else {
0226           ByteBuffer base;
0227           synchronized (srcBuffer) {
0228             base = srcBuffer.duplicate();
0229           }
0230           byte[] result = baos.toByteArray();
0231           byte[] expected = new byte[base.remaining()];
0232           base.get(expected);
0233           assertEquals(expected.length, result.length);
0234           assertTrue("buffers don't match", Arrays.equals(expected, result));
0235         }
0236       } catch (Throwable t) {
0237         error = t;
0238       } finally {
0239         if (out != null) {
0240           try {
0241             out.close();
0242           } catch (Exception e) {
0243             // ignore.
0244           }
0245         }
0246         if (outFile != null) {
0247           outFile.delete();
0248         }
0249       }
0250     }
0251 
0252     public void check() throws Throwable {
0253       if (error != null) {
0254         throw error;
0255       }
0256     }
0257   }
0258 
0259   static class TestCallback implements StreamCallback {
0260 
0261     private final OutputStream out;
0262     public volatile boolean completed;
0263     public volatile Throwable error;
0264 
0265     TestCallback(OutputStream out) {
0266       this.out = out;
0267       this.completed = false;
0268     }
0269 
0270     @Override
0271     public void onData(String streamId, ByteBuffer buf) throws IOException {
0272       byte[] tmp = new byte[buf.remaining()];
0273       buf.get(tmp);
0274       out.write(tmp);
0275     }
0276 
0277     @Override
0278     public void onComplete(String streamId) throws IOException {
0279       out.close();
0280       synchronized (this) {
0281         completed = true;
0282         notifyAll();
0283       }
0284     }
0285 
0286     @Override
0287     public void onFailure(String streamId, Throwable cause) {
0288       error = cause;
0289       synchronized (this) {
0290         completed = true;
0291         notifyAll();
0292       }
0293     }
0294 
0295     void waitForCompletion(long timeoutMs) {
0296       long now = System.currentTimeMillis();
0297       long deadline = now + timeoutMs;
0298       synchronized (this) {
0299         while (!completed && now < deadline) {
0300           try {
0301             wait(deadline - now);
0302           } catch (InterruptedException ie) {
0303             throw new RuntimeException(ie);
0304           }
0305           now = System.currentTimeMillis();
0306         }
0307       }
0308       assertTrue("Timed out waiting for stream.", completed);
0309       assertNull(error);
0310     }
0311   }
0312 
0313 }