0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.launcher;
0019
0020 import java.io.Closeable;
0021 import java.io.IOException;
0022 import java.io.ObjectInputStream;
0023 import java.net.InetAddress;
0024 import java.net.Socket;
0025 import java.net.SocketException;
0026 import java.time.Duration;
0027 import java.util.Arrays;
0028 import java.util.List;
0029 import java.util.concurrent.BlockingQueue;
0030 import java.util.concurrent.LinkedBlockingQueue;
0031 import java.util.concurrent.Semaphore;
0032 import java.util.concurrent.TimeUnit;
0033 import java.util.concurrent.atomic.AtomicBoolean;
0034
0035 import org.junit.Test;
0036 import static org.junit.Assert.*;
0037
0038 import static org.apache.spark.launcher.LauncherProtocol.*;
0039
0040 public class LauncherServerSuite extends BaseSuite {
0041
0042 @Test
0043 public void testLauncherServerReuse() throws Exception {
0044 LauncherServer server1 = LauncherServer.getOrCreateServer();
0045 ChildProcAppHandle handle = new ChildProcAppHandle(server1);
0046 handle.kill();
0047
0048 LauncherServer server2 = LauncherServer.getOrCreateServer();
0049 try {
0050 assertNotSame(server1, server2);
0051 } finally {
0052 server2.unref();
0053 }
0054 }
0055
0056 @Test
0057 public void testCommunication() throws Exception {
0058 LauncherServer server = LauncherServer.getOrCreateServer();
0059 ChildProcAppHandle handle = new ChildProcAppHandle(server);
0060 String secret = server.registerHandle(handle);
0061
0062 TestClient client = null;
0063 try {
0064 Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort());
0065
0066 final Semaphore semaphore = new Semaphore(0);
0067 handle.addListener(new SparkAppHandle.Listener() {
0068 @Override
0069 public void stateChanged(SparkAppHandle handle) {
0070 semaphore.release();
0071 }
0072 @Override
0073 public void infoChanged(SparkAppHandle handle) {
0074 semaphore.release();
0075 }
0076 });
0077
0078 client = new TestClient(s);
0079 client.send(new Hello(secret, "1.4.0"));
0080 assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS));
0081
0082
0083 assertNotNull(handle.getConnection());
0084
0085 client.send(new SetAppId("app-id"));
0086 assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS));
0087 assertEquals("app-id", handle.getAppId());
0088
0089 client.send(new SetState(SparkAppHandle.State.RUNNING));
0090 assertTrue(semaphore.tryAcquire(1, TimeUnit.SECONDS));
0091 assertEquals(SparkAppHandle.State.RUNNING, handle.getState());
0092
0093 handle.stop();
0094 Message stopMsg = client.inbound.poll(30, TimeUnit.SECONDS);
0095 assertTrue(stopMsg instanceof Stop);
0096 } finally {
0097 close(client);
0098 handle.kill();
0099 client.clientThread.join();
0100 }
0101 }
0102
0103 @Test
0104 public void testTimeout() throws Exception {
0105 LauncherServer server = LauncherServer.getOrCreateServer();
0106 ChildProcAppHandle handle = new ChildProcAppHandle(server);
0107 String secret = server.registerHandle(handle);
0108
0109 TestClient client = null;
0110 try {
0111
0112
0113 SparkLauncher.setConfig(SparkLauncher.CHILD_CONNECTION_TIMEOUT, "0");
0114
0115 Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort());
0116 client = new TestClient(s);
0117 waitForError(client, secret);
0118 } finally {
0119 SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT);
0120 handle.kill();
0121 close(client);
0122 }
0123 }
0124
0125 @Test
0126 public void testSparkSubmitVmShutsDown() throws Exception {
0127 LauncherServer server = LauncherServer.getOrCreateServer();
0128 ChildProcAppHandle handle = new ChildProcAppHandle(server);
0129 String secret = server.registerHandle(handle);
0130
0131 TestClient client = null;
0132 final Semaphore semaphore = new Semaphore(0);
0133 try {
0134 Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort());
0135 handle.addListener(new SparkAppHandle.Listener() {
0136 public void stateChanged(SparkAppHandle handle) {
0137 semaphore.release();
0138 }
0139 public void infoChanged(SparkAppHandle handle) {
0140 semaphore.release();
0141 }
0142 });
0143 client = new TestClient(s);
0144 client.send(new Hello(secret, "1.4.0"));
0145 assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS));
0146
0147 assertNotNull(handle.getConnection());
0148 client.close();
0149 handle.dispose();
0150 assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS));
0151 assertEquals(SparkAppHandle.State.LOST, handle.getState());
0152 } finally {
0153 handle.kill();
0154 close(client);
0155 client.clientThread.join();
0156 }
0157 }
0158
0159 @Test
0160 public void testStreamFiltering() throws Exception {
0161 LauncherServer server = LauncherServer.getOrCreateServer();
0162 ChildProcAppHandle handle = new ChildProcAppHandle(server);
0163 String secret = server.registerHandle(handle);
0164
0165 TestClient client = null;
0166 try {
0167 Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort());
0168
0169 client = new TestClient(s);
0170
0171 try {
0172 client.send(new EvilPayload());
0173 } catch (SocketException se) {
0174
0175
0176
0177 }
0178
0179 waitForError(client, secret);
0180 assertEquals(0, EvilPayload.EVIL_BIT);
0181 } finally {
0182 handle.kill();
0183 close(client);
0184 client.clientThread.join();
0185 }
0186 }
0187
0188 @Test
0189 public void testAppHandleDisconnect() throws Exception {
0190 LauncherServer server = LauncherServer.getOrCreateServer();
0191 ChildProcAppHandle handle = new ChildProcAppHandle(server);
0192 String secret = server.registerHandle(handle);
0193
0194 TestClient client = null;
0195 try {
0196 Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort());
0197 client = new TestClient(s);
0198 client.send(new Hello(secret, "1.4.0"));
0199 client.send(new SetAppId("someId"));
0200
0201
0202
0203 eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> {
0204 assertEquals("someId", handle.getAppId());
0205 });
0206
0207 handle.disconnect();
0208 waitForError(client, secret);
0209 } finally {
0210 handle.kill();
0211 close(client);
0212 client.clientThread.join();
0213 }
0214 }
0215
0216 private void close(Closeable c) {
0217 if (c != null) {
0218 try {
0219 c.close();
0220 } catch (Exception e) {
0221
0222 }
0223 }
0224 }
0225
0226
0227
0228
0229
0230 private void waitForError(TestClient client, String secret) throws Exception {
0231 final AtomicBoolean helloSent = new AtomicBoolean();
0232 eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> {
0233 try {
0234 if (!helloSent.get()) {
0235 client.send(new Hello(secret, "1.4.0"));
0236 helloSent.set(true);
0237 } else {
0238 client.send(new SetAppId("appId"));
0239 }
0240 fail("Expected error but message went through.");
0241 } catch (IllegalStateException | IOException e) {
0242
0243 }
0244 });
0245 }
0246
0247 private static class TestClient extends LauncherConnection {
0248
0249 final BlockingQueue<Message> inbound;
0250 final Thread clientThread;
0251
0252 TestClient(Socket s) throws IOException {
0253 super(s);
0254 this.inbound = new LinkedBlockingQueue<>();
0255 this.clientThread = new Thread(this);
0256 clientThread.setName("TestClient");
0257 clientThread.setDaemon(true);
0258 clientThread.start();
0259 }
0260
0261 @Override
0262 protected void handle(Message msg) throws IOException {
0263 inbound.offer(msg);
0264 }
0265
0266 }
0267
0268 private static class EvilPayload extends LauncherProtocol.Message {
0269
0270 static int EVIL_BIT = 0;
0271
0272
0273
0274 private List<String> notAllowedField = Arrays.asList("disallowed");
0275
0276 private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
0277 stream.defaultReadObject();
0278 EVIL_BIT = 1;
0279 }
0280
0281 }
0282
0283 }