0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network;
0019
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertTrue;

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.ConfigProvider;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
0047
0048 public class TransportClientFactorySuite {
0049 private TransportConf conf;
0050 private TransportContext context;
0051 private TransportServer server1;
0052 private TransportServer server2;
0053
0054 @Before
0055 public void setUp() {
0056 conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
0057 RpcHandler rpcHandler = new NoOpRpcHandler();
0058 context = new TransportContext(conf, rpcHandler);
0059 server1 = context.createServer();
0060 server2 = context.createServer();
0061 }
0062
0063 @After
0064 public void tearDown() {
0065 JavaUtils.closeQuietly(server1);
0066 JavaUtils.closeQuietly(server2);
0067 JavaUtils.closeQuietly(context);
0068 }
0069
0070
0071
0072
0073
0074
0075
0076 private void testClientReuse(int maxConnections, boolean concurrent)
0077 throws IOException, InterruptedException {
0078
0079 Map<String, String> configMap = new HashMap<>();
0080 configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections));
0081 TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap));
0082
0083 RpcHandler rpcHandler = new NoOpRpcHandler();
0084 try (TransportContext context = new TransportContext(conf, rpcHandler)) {
0085 TransportClientFactory factory = context.createClientFactory();
0086 Set<TransportClient> clients = Collections.synchronizedSet(
0087 new HashSet<>());
0088
0089 AtomicInteger failed = new AtomicInteger();
0090 Thread[] attempts = new Thread[maxConnections * 10];
0091
0092
0093 for (int i = 0; i < attempts.length; i++) {
0094 attempts[i] = new Thread(() -> {
0095 try {
0096 TransportClient client =
0097 factory.createClient(TestUtils.getLocalHost(), server1.getPort());
0098 assertTrue(client.isActive());
0099 clients.add(client);
0100 } catch (IOException e) {
0101 failed.incrementAndGet();
0102 } catch (InterruptedException e) {
0103 throw new RuntimeException(e);
0104 }
0105 });
0106
0107 if (concurrent) {
0108 attempts[i].start();
0109 } else {
0110 attempts[i].run();
0111 }
0112 }
0113
0114
0115 for (Thread attempt : attempts) {
0116 attempt.join();
0117 }
0118
0119 Assert.assertEquals(0, failed.get());
0120 Assert.assertTrue(clients.size() <= maxConnections);
0121
0122 for (TransportClient client : clients) {
0123 client.close();
0124 }
0125
0126 factory.close();
0127 }
0128 }
0129
0130 @Test
0131 public void reuseClientsUpToConfigVariable() throws Exception {
0132 testClientReuse(1, false);
0133 testClientReuse(2, false);
0134 testClientReuse(3, false);
0135 testClientReuse(4, false);
0136 }
0137
0138 @Test
0139 public void reuseClientsUpToConfigVariableConcurrent() throws Exception {
0140 testClientReuse(1, true);
0141 testClientReuse(2, true);
0142 testClientReuse(3, true);
0143 testClientReuse(4, true);
0144 }
0145
0146 @Test
0147 public void returnDifferentClientsForDifferentServers() throws IOException, InterruptedException {
0148 TransportClientFactory factory = context.createClientFactory();
0149 TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
0150 TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
0151 assertTrue(c1.isActive());
0152 assertTrue(c2.isActive());
0153 assertNotSame(c1, c2);
0154 factory.close();
0155 }
0156
0157 @Test
0158 public void neverReturnInactiveClients() throws IOException, InterruptedException {
0159 TransportClientFactory factory = context.createClientFactory();
0160 TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
0161 c1.close();
0162
0163 long start = System.currentTimeMillis();
0164 while (c1.isActive() && (System.currentTimeMillis() - start) < 3000) {
0165 Thread.sleep(10);
0166 }
0167 assertFalse(c1.isActive());
0168
0169 TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
0170 assertNotSame(c1, c2);
0171 assertTrue(c2.isActive());
0172 factory.close();
0173 }
0174
0175 @Test
0176 public void closeBlockClientsWithFactory() throws IOException, InterruptedException {
0177 TransportClientFactory factory = context.createClientFactory();
0178 TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
0179 TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
0180 assertTrue(c1.isActive());
0181 assertTrue(c2.isActive());
0182 factory.close();
0183 assertFalse(c1.isActive());
0184 assertFalse(c2.isActive());
0185 }
0186
0187 @Test
0188 public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException {
0189 TransportConf conf = new TransportConf("shuffle", new ConfigProvider() {
0190
0191 @Override
0192 public String get(String name) {
0193 if ("spark.shuffle.io.connectionTimeout".equals(name)) {
0194
0195 return "1s";
0196 }
0197 String value = System.getProperty(name);
0198 if (value == null) {
0199 throw new NoSuchElementException(name);
0200 }
0201 return value;
0202 }
0203
0204 @Override
0205 public Iterable<Map.Entry<String, String>> getAll() {
0206 throw new UnsupportedOperationException();
0207 }
0208 });
0209 try (TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
0210 TransportClientFactory factory = context.createClientFactory()) {
0211 TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
0212 assertTrue(c1.isActive());
0213 long expiredTime = System.currentTimeMillis() + 10000;
0214 while (c1.isActive() && System.currentTimeMillis() < expiredTime) {
0215 Thread.sleep(10);
0216 }
0217 assertFalse(c1.isActive());
0218 }
0219 }
0220
0221 @Test(expected = IOException.class)
0222 public void closeFactoryBeforeCreateClient() throws IOException, InterruptedException {
0223 TransportClientFactory factory = context.createClientFactory();
0224 factory.close();
0225 factory.createClient(TestUtils.getLocalHost(), server1.getPort());
0226 }
0227 }