0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network;
0019
0020 import java.io.*;
0021 import java.nio.ByteBuffer;
0022 import java.util.*;
0023 import java.util.concurrent.ConcurrentHashMap;
0024 import java.util.concurrent.Semaphore;
0025 import java.util.concurrent.TimeUnit;
0026
0027 import com.google.common.collect.Sets;
0028 import com.google.common.io.Files;
0029 import org.apache.commons.lang3.tuple.ImmutablePair;
0030 import org.apache.commons.lang3.tuple.Pair;
0031 import org.junit.AfterClass;
0032 import org.junit.BeforeClass;
0033 import org.junit.Test;
0034
0035 import static org.junit.Assert.*;
0036
0037 import org.apache.spark.network.buffer.ManagedBuffer;
0038 import org.apache.spark.network.buffer.NioManagedBuffer;
0039 import org.apache.spark.network.client.*;
0040 import org.apache.spark.network.server.*;
0041 import org.apache.spark.network.util.JavaUtils;
0042 import org.apache.spark.network.util.MapConfigProvider;
0043 import org.apache.spark.network.util.TransportConf;
0044
0045 public class RpcIntegrationSuite {
0046 static TransportConf conf;
0047 static TransportContext context;
0048 static TransportServer server;
0049 static TransportClientFactory clientFactory;
0050 static RpcHandler rpcHandler;
0051 static List<String> oneWayMsgs;
0052 static StreamTestHelper testData;
0053
0054 static ConcurrentHashMap<String, VerifyingStreamCallback> streamCallbacks =
0055 new ConcurrentHashMap<>();
0056
0057 @BeforeClass
0058 public static void setUp() throws Exception {
0059 conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
0060 testData = new StreamTestHelper();
0061 rpcHandler = new RpcHandler() {
0062 @Override
0063 public void receive(
0064 TransportClient client,
0065 ByteBuffer message,
0066 RpcResponseCallback callback) {
0067 String msg = JavaUtils.bytesToString(message);
0068 String[] parts = msg.split("/");
0069 if (parts[0].equals("hello")) {
0070 callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!"));
0071 } else if (parts[0].equals("return error")) {
0072 callback.onFailure(new RuntimeException("Returned: " + parts[1]));
0073 } else if (parts[0].equals("throw error")) {
0074 throw new RuntimeException("Thrown: " + parts[1]);
0075 }
0076 }
0077
0078 @Override
0079 public StreamCallbackWithID receiveStream(
0080 TransportClient client,
0081 ByteBuffer messageHeader,
0082 RpcResponseCallback callback) {
0083 return receiveStreamHelper(JavaUtils.bytesToString(messageHeader));
0084 }
0085
0086 @Override
0087 public void receive(TransportClient client, ByteBuffer message) {
0088 oneWayMsgs.add(JavaUtils.bytesToString(message));
0089 }
0090
0091 @Override
0092 public StreamManager getStreamManager() { return new OneForOneStreamManager(); }
0093 };
0094 context = new TransportContext(conf, rpcHandler);
0095 server = context.createServer();
0096 clientFactory = context.createClientFactory();
0097 oneWayMsgs = new ArrayList<>();
0098 }
0099
0100 private static StreamCallbackWithID receiveStreamHelper(String msg) {
0101 try {
0102 if (msg.startsWith("fail/")) {
0103 String[] parts = msg.split("/");
0104 switch (parts[1]) {
0105 case "exception-ondata":
0106 return new StreamCallbackWithID() {
0107 @Override
0108 public void onData(String streamId, ByteBuffer buf) throws IOException {
0109 throw new IOException("failed to read stream data!");
0110 }
0111
0112 @Override
0113 public void onComplete(String streamId) throws IOException {
0114 }
0115
0116 @Override
0117 public void onFailure(String streamId, Throwable cause) throws IOException {
0118 }
0119
0120 @Override
0121 public String getID() {
0122 return msg;
0123 }
0124 };
0125 case "exception-oncomplete":
0126 return new StreamCallbackWithID() {
0127 @Override
0128 public void onData(String streamId, ByteBuffer buf) throws IOException {
0129 }
0130
0131 @Override
0132 public void onComplete(String streamId) throws IOException {
0133 throw new IOException("exception in onComplete");
0134 }
0135
0136 @Override
0137 public void onFailure(String streamId, Throwable cause) throws IOException {
0138 }
0139
0140 @Override
0141 public String getID() {
0142 return msg;
0143 }
0144 };
0145 case "null":
0146 return null;
0147 default:
0148 throw new IllegalArgumentException("unexpected msg: " + msg);
0149 }
0150 } else {
0151 VerifyingStreamCallback streamCallback = new VerifyingStreamCallback(msg);
0152 streamCallbacks.put(msg, streamCallback);
0153 return streamCallback;
0154 }
0155 } catch (IOException e) {
0156 throw new RuntimeException(e);
0157 }
0158 }
0159
0160 @AfterClass
0161 public static void tearDown() {
0162 server.close();
0163 clientFactory.close();
0164 context.close();
0165 testData.cleanup();
0166 }
0167
0168 static class RpcResult {
0169 public Set<String> successMessages;
0170 public Set<String> errorMessages;
0171 }
0172
0173 private RpcResult sendRPC(String ... commands) throws Exception {
0174 TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0175 final Semaphore sem = new Semaphore(0);
0176
0177 final RpcResult res = new RpcResult();
0178 res.successMessages = Collections.synchronizedSet(new HashSet<>());
0179 res.errorMessages = Collections.synchronizedSet(new HashSet<>());
0180
0181 RpcResponseCallback callback = new RpcResponseCallback() {
0182 @Override
0183 public void onSuccess(ByteBuffer message) {
0184 String response = JavaUtils.bytesToString(message);
0185 res.successMessages.add(response);
0186 sem.release();
0187 }
0188
0189 @Override
0190 public void onFailure(Throwable e) {
0191 res.errorMessages.add(e.getMessage());
0192 sem.release();
0193 }
0194 };
0195
0196 for (String command : commands) {
0197 client.sendRpc(JavaUtils.stringToBytes(command), callback);
0198 }
0199
0200 if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) {
0201 fail("Timeout getting response from the server");
0202 }
0203 client.close();
0204 return res;
0205 }
0206
0207 private RpcResult sendRpcWithStream(String... streams) throws Exception {
0208 TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0209 final Semaphore sem = new Semaphore(0);
0210 RpcResult res = new RpcResult();
0211 res.successMessages = Collections.synchronizedSet(new HashSet<>());
0212 res.errorMessages = Collections.synchronizedSet(new HashSet<>());
0213
0214 for (String stream : streams) {
0215 int idx = stream.lastIndexOf('/');
0216 ManagedBuffer meta = new NioManagedBuffer(JavaUtils.stringToBytes(stream));
0217 String streamName = (idx == -1) ? stream : stream.substring(idx + 1);
0218 ManagedBuffer data = testData.openStream(conf, streamName);
0219 client.uploadStream(meta, data, new RpcStreamCallback(stream, res, sem));
0220 }
0221
0222 if (!sem.tryAcquire(streams.length, 5, TimeUnit.SECONDS)) {
0223 fail("Timeout getting response from the server");
0224 }
0225 streamCallbacks.values().forEach(streamCallback -> {
0226 try {
0227 streamCallback.verify();
0228 } catch (IOException e) {
0229 throw new RuntimeException(e);
0230 }
0231 });
0232 client.close();
0233 return res;
0234 }
0235
0236 private static class RpcStreamCallback implements RpcResponseCallback {
0237 final String streamId;
0238 final RpcResult res;
0239 final Semaphore sem;
0240
0241 RpcStreamCallback(String streamId, RpcResult res, Semaphore sem) {
0242 this.streamId = streamId;
0243 this.res = res;
0244 this.sem = sem;
0245 }
0246
0247 @Override
0248 public void onSuccess(ByteBuffer message) {
0249 res.successMessages.add(streamId);
0250 sem.release();
0251 }
0252
0253 @Override
0254 public void onFailure(Throwable e) {
0255 res.errorMessages.add(e.getMessage());
0256 sem.release();
0257 }
0258 }
0259
0260 @Test
0261 public void singleRPC() throws Exception {
0262 RpcResult res = sendRPC("hello/Aaron");
0263 assertEquals(Sets.newHashSet("Hello, Aaron!"), res.successMessages);
0264 assertTrue(res.errorMessages.isEmpty());
0265 }
0266
0267 @Test
0268 public void doubleRPC() throws Exception {
0269 RpcResult res = sendRPC("hello/Aaron", "hello/Reynold");
0270 assertEquals(Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!"), res.successMessages);
0271 assertTrue(res.errorMessages.isEmpty());
0272 }
0273
0274 @Test
0275 public void returnErrorRPC() throws Exception {
0276 RpcResult res = sendRPC("return error/OK");
0277 assertTrue(res.successMessages.isEmpty());
0278 assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK"));
0279 }
0280
0281 @Test
0282 public void throwErrorRPC() throws Exception {
0283 RpcResult res = sendRPC("throw error/uh-oh");
0284 assertTrue(res.successMessages.isEmpty());
0285 assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: uh-oh"));
0286 }
0287
0288 @Test
0289 public void doubleTrouble() throws Exception {
0290 RpcResult res = sendRPC("return error/OK", "throw error/uh-oh");
0291 assertTrue(res.successMessages.isEmpty());
0292 assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK", "Thrown: uh-oh"));
0293 }
0294
0295 @Test
0296 public void sendSuccessAndFailure() throws Exception {
0297 RpcResult res = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!");
0298 assertEquals(Sets.newHashSet("Hello, Bob!", "Hello, Builder!"), res.successMessages);
0299 assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !"));
0300 }
0301
0302 @Test
0303 public void sendOneWayMessage() throws Exception {
0304 final String message = "no reply";
0305 TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
0306 try {
0307 client.send(JavaUtils.stringToBytes(message));
0308 assertEquals(0, client.getHandler().numOutstandingRequests());
0309
0310
0311 long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
0312 while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) {
0313 TimeUnit.MILLISECONDS.sleep(10);
0314 }
0315
0316 assertEquals(1, oneWayMsgs.size());
0317 assertEquals(message, oneWayMsgs.get(0));
0318 } finally {
0319 client.close();
0320 }
0321 }
0322
0323 @Test
0324 public void sendRpcWithStreamOneAtATime() throws Exception {
0325 for (String stream : StreamTestHelper.STREAMS) {
0326 RpcResult res = sendRpcWithStream(stream);
0327 assertTrue("there were error messages!" + res.errorMessages, res.errorMessages.isEmpty());
0328 assertEquals(Sets.newHashSet(stream), res.successMessages);
0329 }
0330 }
0331
0332 @Test
0333 public void sendRpcWithStreamConcurrently() throws Exception {
0334 String[] streams = new String[10];
0335 for (int i = 0; i < 10; i++) {
0336 streams[i] = StreamTestHelper.STREAMS[i % StreamTestHelper.STREAMS.length];
0337 }
0338 RpcResult res = sendRpcWithStream(streams);
0339 assertEquals(Sets.newHashSet(StreamTestHelper.STREAMS), res.successMessages);
0340 assertTrue(res.errorMessages.isEmpty());
0341 }
0342
0343 @Test
0344 public void sendRpcWithStreamFailures() throws Exception {
0345
0346
0347 RpcResult exceptionInCallbackResult =
0348 sendRpcWithStream("fail/exception-ondata/smallBuffer", "smallBuffer");
0349 assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream");
0350
0351 RpcResult nullStreamHandler =
0352 sendRpcWithStream("fail/null/smallBuffer", "smallBuffer");
0353 assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream");
0354
0355
0356 RpcResult exceptionInOnComplete =
0357 sendRpcWithStream("fail/exception-oncomplete/smallBuffer", "smallBuffer");
0358 assertErrorsContain(exceptionInOnComplete.errorMessages,
0359 Sets.newHashSet("Failure post-processing"));
0360 assertEquals(Sets.newHashSet("smallBuffer"), exceptionInOnComplete.successMessages);
0361 }
0362
0363 private void assertErrorsContain(Set<String> errors, Set<String> contains) {
0364 assertEquals("Expected " + contains.size() + " errors, got " + errors.size() + "errors: " +
0365 errors, contains.size(), errors.size());
0366
0367 Pair<Set<String>, Set<String>> r = checkErrorsContain(errors, contains);
0368 assertTrue("Could not find error containing " + r.getRight() + "; errors: " + errors,
0369 r.getRight().isEmpty());
0370
0371 assertTrue(r.getLeft().isEmpty());
0372 }
0373
0374 private void assertErrorAndClosed(RpcResult result, String expectedError) {
0375 assertTrue("unexpected success: " + result.successMessages, result.successMessages.isEmpty());
0376 Set<String> errors = result.errorMessages;
0377 assertEquals("Expected 2 errors, got " + errors.size() + "errors: " +
0378 errors, 2, errors.size());
0379
0380
0381
0382 Set<String> possibleClosedErrors = Sets.newHashSet(
0383 "closed",
0384 "Connection reset",
0385 "java.nio.channels.ClosedChannelException",
0386 "java.io.IOException: Broken pipe"
0387 );
0388 Set<String> containsAndClosed = Sets.newHashSet(expectedError);
0389 containsAndClosed.addAll(possibleClosedErrors);
0390
0391 Pair<Set<String>, Set<String>> r = checkErrorsContain(errors, containsAndClosed);
0392
0393 assertTrue("Got a non-empty set " + r.getLeft(), r.getLeft().isEmpty());
0394
0395 Set<String> errorsNotFound = r.getRight();
0396 assertEquals(
0397 "The size of " + errorsNotFound + " was not " + (possibleClosedErrors.size() - 1),
0398 possibleClosedErrors.size() - 1,
0399 errorsNotFound.size());
0400 for (String err: errorsNotFound) {
0401 assertTrue("Found a wrong error " + err, containsAndClosed.contains(err));
0402 }
0403 }
0404
0405 private Pair<Set<String>, Set<String>> checkErrorsContain(
0406 Set<String> errors,
0407 Set<String> contains) {
0408 Set<String> remainingErrors = Sets.newHashSet(errors);
0409 Set<String> notFound = Sets.newHashSet();
0410 for (String contain : contains) {
0411 Iterator<String> it = remainingErrors.iterator();
0412 boolean foundMatch = false;
0413 while (it.hasNext()) {
0414 if (it.next().contains(contain)) {
0415 it.remove();
0416 foundMatch = true;
0417 break;
0418 }
0419 }
0420 if (!foundMatch) {
0421 notFound.add(contain);
0422 }
0423 }
0424 return new ImmutablePair<>(remainingErrors, notFound);
0425 }
0426
0427 private static class VerifyingStreamCallback implements StreamCallbackWithID {
0428 final String streamId;
0429 final StreamSuite.TestCallback helper;
0430 final OutputStream out;
0431 final File outFile;
0432
0433 VerifyingStreamCallback(String streamId) throws IOException {
0434 if (streamId.equals("file")) {
0435 outFile = File.createTempFile("data", ".tmp", testData.tempDir);
0436 out = new FileOutputStream(outFile);
0437 } else {
0438 out = new ByteArrayOutputStream();
0439 outFile = null;
0440 }
0441 this.streamId = streamId;
0442 helper = new StreamSuite.TestCallback(out);
0443 }
0444
0445 void verify() throws IOException {
0446 if (streamId.equals("file")) {
0447 assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile));
0448 } else {
0449 byte[] result = ((ByteArrayOutputStream)out).toByteArray();
0450 ByteBuffer srcBuffer = testData.srcBuffer(streamId);
0451 ByteBuffer base;
0452 synchronized (srcBuffer) {
0453 base = srcBuffer.duplicate();
0454 }
0455 byte[] expected = new byte[base.remaining()];
0456 base.get(expected);
0457 assertEquals(expected.length, result.length);
0458 assertTrue("buffers don't match", Arrays.equals(expected, result));
0459 }
0460 }
0461
0462 @Override
0463 public void onData(String streamId, ByteBuffer buf) throws IOException {
0464 helper.onData(streamId, buf);
0465 }
0466
0467 @Override
0468 public void onComplete(String streamId) throws IOException {
0469 helper.onComplete(streamId);
0470 }
0471
0472 @Override
0473 public void onFailure(String streamId, Throwable cause) throws IOException {
0474 helper.onFailure(streamId, cause);
0475 }
0476
0477 @Override
0478 public String getID() {
0479 return streamId;
0480 }
0481 }
0482 }