Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network;
0019 
0020 import java.io.*;
0021 import java.nio.ByteBuffer;
0022 import java.util.*;
0023 import java.util.concurrent.ConcurrentHashMap;
0024 import java.util.concurrent.Semaphore;
0025 import java.util.concurrent.TimeUnit;
0026 
0027 import com.google.common.collect.Sets;
0028 import com.google.common.io.Files;
0029 import org.apache.commons.lang3.tuple.ImmutablePair;
0030 import org.apache.commons.lang3.tuple.Pair;
0031 import org.junit.AfterClass;
0032 import org.junit.BeforeClass;
0033 import org.junit.Test;
0034 
0035 import static org.junit.Assert.*;
0036 
0037 import org.apache.spark.network.buffer.ManagedBuffer;
0038 import org.apache.spark.network.buffer.NioManagedBuffer;
0039 import org.apache.spark.network.client.*;
0040 import org.apache.spark.network.server.*;
0041 import org.apache.spark.network.util.JavaUtils;
0042 import org.apache.spark.network.util.MapConfigProvider;
0043 import org.apache.spark.network.util.TransportConf;
0044 
0045 public class RpcIntegrationSuite {
0046   static TransportConf conf;
0047   static TransportContext context;
0048   static TransportServer server;
0049   static TransportClientFactory clientFactory;
0050   static RpcHandler rpcHandler;
0051   static List<String> oneWayMsgs;
0052   static StreamTestHelper testData;
0053 
0054   static ConcurrentHashMap<String, VerifyingStreamCallback> streamCallbacks =
0055       new ConcurrentHashMap<>();
0056 
0057   @BeforeClass
0058   public static void setUp() throws Exception {
0059     conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
0060     testData = new StreamTestHelper();
0061     rpcHandler = new RpcHandler() {
0062       @Override
0063       public void receive(
0064           TransportClient client,
0065           ByteBuffer message,
0066           RpcResponseCallback callback) {
0067         String msg = JavaUtils.bytesToString(message);
0068         String[] parts = msg.split("/");
0069         if (parts[0].equals("hello")) {
0070           callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!"));
0071         } else if (parts[0].equals("return error")) {
0072           callback.onFailure(new RuntimeException("Returned: " + parts[1]));
0073         } else if (parts[0].equals("throw error")) {
0074           throw new RuntimeException("Thrown: " + parts[1]);
0075         }
0076       }
0077 
0078       @Override
0079       public StreamCallbackWithID receiveStream(
0080           TransportClient client,
0081           ByteBuffer messageHeader,
0082           RpcResponseCallback callback) {
0083         return receiveStreamHelper(JavaUtils.bytesToString(messageHeader));
0084       }
0085 
0086       @Override
0087       public void receive(TransportClient client, ByteBuffer message) {
0088         oneWayMsgs.add(JavaUtils.bytesToString(message));
0089       }
0090 
0091       @Override
0092       public StreamManager getStreamManager() { return new OneForOneStreamManager(); }
0093     };
0094     context = new TransportContext(conf, rpcHandler);
0095     server = context.createServer();
0096     clientFactory = context.createClientFactory();
0097     oneWayMsgs = new ArrayList<>();
0098   }
0099 
0100   private static StreamCallbackWithID receiveStreamHelper(String msg) {
0101     try {
0102       if (msg.startsWith("fail/")) {
0103         String[] parts = msg.split("/");
0104         switch (parts[1]) {
0105           case "exception-ondata":
0106             return new StreamCallbackWithID() {
0107               @Override
0108               public void onData(String streamId, ByteBuffer buf) throws IOException {
0109                 throw new IOException("failed to read stream data!");
0110               }
0111 
0112               @Override
0113               public void onComplete(String streamId) throws IOException {
0114               }
0115 
0116               @Override
0117               public void onFailure(String streamId, Throwable cause) throws IOException {
0118               }
0119 
0120               @Override
0121               public String getID() {
0122                 return msg;
0123               }
0124             };
0125           case "exception-oncomplete":
0126             return new StreamCallbackWithID() {
0127               @Override
0128               public void onData(String streamId, ByteBuffer buf) throws IOException {
0129               }
0130 
0131               @Override
0132               public void onComplete(String streamId) throws IOException {
0133                 throw new IOException("exception in onComplete");
0134               }
0135 
0136               @Override
0137               public void onFailure(String streamId, Throwable cause) throws IOException {
0138               }
0139 
0140               @Override
0141               public String getID() {
0142                 return msg;
0143               }
0144             };
0145           case "null":
0146             return null;
0147           default:
0148             throw new IllegalArgumentException("unexpected msg: " + msg);
0149         }
0150       } else {
0151         VerifyingStreamCallback streamCallback = new VerifyingStreamCallback(msg);
0152         streamCallbacks.put(msg, streamCallback);
0153         return streamCallback;
0154       }
0155     } catch (IOException e) {
0156       throw new RuntimeException(e);
0157     }
0158   }
0159 
0160   @AfterClass
0161   public static void tearDown() {
0162     server.close();
0163     clientFactory.close();
0164     context.close();
0165     testData.cleanup();
0166   }
0167 
0168   static class RpcResult {
0169     public Set<String> successMessages;
0170     public Set<String> errorMessages;
0171   }
0172 
0173   private RpcResult sendRPC(String ... commands) throws Exception {
0174     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0175     final Semaphore sem = new Semaphore(0);
0176 
0177     final RpcResult res = new RpcResult();
0178     res.successMessages = Collections.synchronizedSet(new HashSet<>());
0179     res.errorMessages = Collections.synchronizedSet(new HashSet<>());
0180 
0181     RpcResponseCallback callback = new RpcResponseCallback() {
0182       @Override
0183       public void onSuccess(ByteBuffer message) {
0184         String response = JavaUtils.bytesToString(message);
0185         res.successMessages.add(response);
0186         sem.release();
0187       }
0188 
0189       @Override
0190       public void onFailure(Throwable e) {
0191         res.errorMessages.add(e.getMessage());
0192         sem.release();
0193       }
0194     };
0195 
0196     for (String command : commands) {
0197       client.sendRpc(JavaUtils.stringToBytes(command), callback);
0198     }
0199 
0200     if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) {
0201       fail("Timeout getting response from the server");
0202     }
0203     client.close();
0204     return res;
0205   }
0206 
0207   private RpcResult sendRpcWithStream(String... streams) throws Exception {
0208     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0209     final Semaphore sem = new Semaphore(0);
0210     RpcResult res = new RpcResult();
0211     res.successMessages = Collections.synchronizedSet(new HashSet<>());
0212     res.errorMessages = Collections.synchronizedSet(new HashSet<>());
0213 
0214     for (String stream : streams) {
0215       int idx = stream.lastIndexOf('/');
0216       ManagedBuffer meta = new NioManagedBuffer(JavaUtils.stringToBytes(stream));
0217       String streamName = (idx == -1) ? stream : stream.substring(idx + 1);
0218       ManagedBuffer data = testData.openStream(conf, streamName);
0219       client.uploadStream(meta, data, new RpcStreamCallback(stream, res, sem));
0220     }
0221 
0222     if (!sem.tryAcquire(streams.length, 5, TimeUnit.SECONDS)) {
0223       fail("Timeout getting response from the server");
0224     }
0225     streamCallbacks.values().forEach(streamCallback -> {
0226       try {
0227         streamCallback.verify();
0228       } catch (IOException e) {
0229         throw new RuntimeException(e);
0230       }
0231     });
0232     client.close();
0233     return res;
0234   }
0235 
0236   private static class RpcStreamCallback implements RpcResponseCallback {
0237     final String streamId;
0238     final RpcResult res;
0239     final Semaphore sem;
0240 
0241     RpcStreamCallback(String streamId, RpcResult res, Semaphore sem) {
0242       this.streamId = streamId;
0243       this.res = res;
0244       this.sem = sem;
0245     }
0246 
0247     @Override
0248     public void onSuccess(ByteBuffer message) {
0249       res.successMessages.add(streamId);
0250       sem.release();
0251     }
0252 
0253     @Override
0254     public void onFailure(Throwable e) {
0255       res.errorMessages.add(e.getMessage());
0256       sem.release();
0257     }
0258   }
0259 
0260   @Test
0261   public void singleRPC() throws Exception {
0262     RpcResult res = sendRPC("hello/Aaron");
0263     assertEquals(Sets.newHashSet("Hello, Aaron!"), res.successMessages);
0264     assertTrue(res.errorMessages.isEmpty());
0265   }
0266 
0267   @Test
0268   public void doubleRPC() throws Exception {
0269     RpcResult res = sendRPC("hello/Aaron", "hello/Reynold");
0270     assertEquals(Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!"), res.successMessages);
0271     assertTrue(res.errorMessages.isEmpty());
0272   }
0273 
0274   @Test
0275   public void returnErrorRPC() throws Exception {
0276     RpcResult res = sendRPC("return error/OK");
0277     assertTrue(res.successMessages.isEmpty());
0278     assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK"));
0279   }
0280 
0281   @Test
0282   public void throwErrorRPC() throws Exception {
0283     RpcResult res = sendRPC("throw error/uh-oh");
0284     assertTrue(res.successMessages.isEmpty());
0285     assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: uh-oh"));
0286   }
0287 
0288   @Test
0289   public void doubleTrouble() throws Exception {
0290     RpcResult res = sendRPC("return error/OK", "throw error/uh-oh");
0291     assertTrue(res.successMessages.isEmpty());
0292     assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK", "Thrown: uh-oh"));
0293   }
0294 
0295   @Test
0296   public void sendSuccessAndFailure() throws Exception {
0297     RpcResult res = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!");
0298     assertEquals(Sets.newHashSet("Hello, Bob!", "Hello, Builder!"), res.successMessages);
0299     assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !"));
0300   }
0301 
0302   @Test
0303   public void sendOneWayMessage() throws Exception {
0304     final String message = "no reply";
0305     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0306     try {
0307       client.send(JavaUtils.stringToBytes(message));
0308       assertEquals(0, client.getHandler().numOutstandingRequests());
0309 
0310       // Make sure the message arrives.
0311       long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
0312       while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) {
0313         TimeUnit.MILLISECONDS.sleep(10);
0314       }
0315 
0316       assertEquals(1, oneWayMsgs.size());
0317       assertEquals(message, oneWayMsgs.get(0));
0318     } finally {
0319       client.close();
0320     }
0321   }
0322 
0323   @Test
0324   public void sendRpcWithStreamOneAtATime() throws Exception {
0325     for (String stream : StreamTestHelper.STREAMS) {
0326       RpcResult res = sendRpcWithStream(stream);
0327       assertTrue("there were error messages!" + res.errorMessages, res.errorMessages.isEmpty());
0328       assertEquals(Sets.newHashSet(stream), res.successMessages);
0329     }
0330   }
0331 
0332   @Test
0333   public void sendRpcWithStreamConcurrently() throws Exception {
0334     String[] streams = new String[10];
0335     for (int i = 0; i < 10; i++) {
0336       streams[i] = StreamTestHelper.STREAMS[i % StreamTestHelper.STREAMS.length];
0337     }
0338     RpcResult res = sendRpcWithStream(streams);
0339     assertEquals(Sets.newHashSet(StreamTestHelper.STREAMS), res.successMessages);
0340     assertTrue(res.errorMessages.isEmpty());
0341   }
0342 
0343   @Test
0344   public void sendRpcWithStreamFailures() throws Exception {
0345     // when there is a failure reading stream data, we don't try to keep the channel usable,
0346     // just send back a decent error msg.
0347     RpcResult exceptionInCallbackResult =
0348         sendRpcWithStream("fail/exception-ondata/smallBuffer", "smallBuffer");
0349     assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream");
0350 
0351     RpcResult nullStreamHandler =
0352         sendRpcWithStream("fail/null/smallBuffer", "smallBuffer");
0353     assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream");
0354 
0355     // OTOH, if there is a failure during onComplete, the channel should still be fine
0356     RpcResult exceptionInOnComplete =
0357         sendRpcWithStream("fail/exception-oncomplete/smallBuffer", "smallBuffer");
0358     assertErrorsContain(exceptionInOnComplete.errorMessages,
0359         Sets.newHashSet("Failure post-processing"));
0360     assertEquals(Sets.newHashSet("smallBuffer"), exceptionInOnComplete.successMessages);
0361   }
0362 
0363   private void assertErrorsContain(Set<String> errors, Set<String> contains) {
0364     assertEquals("Expected " + contains.size() + " errors, got " + errors.size() + "errors: " +
0365         errors, contains.size(), errors.size());
0366 
0367     Pair<Set<String>, Set<String>> r = checkErrorsContain(errors, contains);
0368     assertTrue("Could not find error containing " + r.getRight() + "; errors: " + errors,
0369         r.getRight().isEmpty());
0370 
0371     assertTrue(r.getLeft().isEmpty());
0372   }
0373 
0374   private void assertErrorAndClosed(RpcResult result, String expectedError) {
0375     assertTrue("unexpected success: " + result.successMessages, result.successMessages.isEmpty());
0376     Set<String> errors = result.errorMessages;
0377     assertEquals("Expected 2 errors, got " + errors.size() + "errors: " +
0378         errors, 2, errors.size());
0379 
0380     // We expect 1 additional error due to closed connection and here are possible keywords in the
0381     // error message.
0382     Set<String> possibleClosedErrors = Sets.newHashSet(
0383         "closed",
0384         "Connection reset",
0385         "java.nio.channels.ClosedChannelException",
0386         "java.io.IOException: Broken pipe"
0387     );
0388     Set<String> containsAndClosed = Sets.newHashSet(expectedError);
0389     containsAndClosed.addAll(possibleClosedErrors);
0390 
0391     Pair<Set<String>, Set<String>> r = checkErrorsContain(errors, containsAndClosed);
0392 
0393     assertTrue("Got a non-empty set " + r.getLeft(), r.getLeft().isEmpty());
0394 
0395     Set<String> errorsNotFound = r.getRight();
0396     assertEquals(
0397         "The size of " + errorsNotFound + " was not " + (possibleClosedErrors.size() - 1),
0398         possibleClosedErrors.size() - 1,
0399         errorsNotFound.size());
0400     for (String err: errorsNotFound) {
0401       assertTrue("Found a wrong error " + err, containsAndClosed.contains(err));
0402     }
0403   }
0404 
0405   private Pair<Set<String>, Set<String>> checkErrorsContain(
0406       Set<String> errors,
0407       Set<String> contains) {
0408     Set<String> remainingErrors = Sets.newHashSet(errors);
0409     Set<String> notFound = Sets.newHashSet();
0410     for (String contain : contains) {
0411       Iterator<String> it = remainingErrors.iterator();
0412       boolean foundMatch = false;
0413       while (it.hasNext()) {
0414         if (it.next().contains(contain)) {
0415           it.remove();
0416           foundMatch = true;
0417           break;
0418         }
0419       }
0420       if (!foundMatch) {
0421         notFound.add(contain);
0422       }
0423     }
0424     return new ImmutablePair<>(remainingErrors, notFound);
0425   }
0426 
0427   private static class VerifyingStreamCallback implements StreamCallbackWithID {
0428     final String streamId;
0429     final StreamSuite.TestCallback helper;
0430     final OutputStream out;
0431     final File outFile;
0432 
0433     VerifyingStreamCallback(String streamId) throws IOException {
0434       if (streamId.equals("file")) {
0435         outFile = File.createTempFile("data", ".tmp", testData.tempDir);
0436         out = new FileOutputStream(outFile);
0437       } else {
0438         out = new ByteArrayOutputStream();
0439         outFile = null;
0440       }
0441       this.streamId = streamId;
0442       helper = new StreamSuite.TestCallback(out);
0443     }
0444 
0445     void verify() throws IOException {
0446       if (streamId.equals("file")) {
0447         assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile));
0448       } else {
0449         byte[] result = ((ByteArrayOutputStream)out).toByteArray();
0450         ByteBuffer srcBuffer = testData.srcBuffer(streamId);
0451         ByteBuffer base;
0452         synchronized (srcBuffer) {
0453           base = srcBuffer.duplicate();
0454         }
0455         byte[] expected = new byte[base.remaining()];
0456         base.get(expected);
0457         assertEquals(expected.length, result.length);
0458         assertTrue("buffers don't match", Arrays.equals(expected, result));
0459       }
0460     }
0461 
0462     @Override
0463     public void onData(String streamId, ByteBuffer buf) throws IOException {
0464       helper.onData(streamId, buf);
0465     }
0466 
0467     @Override
0468     public void onComplete(String streamId) throws IOException {
0469       helper.onComplete(streamId);
0470     }
0471 
0472     @Override
0473     public void onFailure(String streamId, Throwable cause) throws IOException {
0474       helper.onFailure(streamId, cause);
0475     }
0476 
0477     @Override
0478     public String getID() {
0479       return streamId;
0480     }
0481   }
0482 }