Back to home page

OSCL-LXR

 
 

    


/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
0017 
0018 package org.apache.spark.launcher;
0019 
0020 import java.io.Closeable;
0021 import java.io.IOException;
0022 import java.io.ObjectInputStream;
0023 import java.net.InetAddress;
0024 import java.net.Socket;
0025 import java.net.SocketException;
0026 import java.time.Duration;
0027 import java.util.Arrays;
0028 import java.util.List;
0029 import java.util.concurrent.BlockingQueue;
0030 import java.util.concurrent.LinkedBlockingQueue;
0031 import java.util.concurrent.Semaphore;
0032 import java.util.concurrent.TimeUnit;
0033 import java.util.concurrent.atomic.AtomicBoolean;
0034 
0035 import org.junit.Test;
0036 import static org.junit.Assert.*;
0037 
0038 import static org.apache.spark.launcher.LauncherProtocol.*;
0039 
0040 public class LauncherServerSuite extends BaseSuite {
0041 
0042   @Test
0043   public void testLauncherServerReuse() throws Exception {
0044     LauncherServer server1 = LauncherServer.getOrCreateServer();
0045     ChildProcAppHandle handle = new ChildProcAppHandle(server1);
0046     handle.kill();
0047 
0048     LauncherServer server2 = LauncherServer.getOrCreateServer();
0049     try {
0050       assertNotSame(server1, server2);
0051     } finally {
0052       server2.unref();
0053     }
0054   }
0055 
0056   @Test
0057   public void testCommunication() throws Exception {
0058     LauncherServer server = LauncherServer.getOrCreateServer();
0059     ChildProcAppHandle handle = new ChildProcAppHandle(server);
0060     String secret = server.registerHandle(handle);
0061 
0062     TestClient client = null;
0063     try {
0064       Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort());
0065 
0066       final Semaphore semaphore = new Semaphore(0);
0067       handle.addListener(new SparkAppHandle.Listener() {
0068         @Override
0069         public void stateChanged(SparkAppHandle handle) {
0070           semaphore.release();
0071         }
0072         @Override
0073         public void infoChanged(SparkAppHandle handle) {
0074           semaphore.release();
0075         }
0076       });
0077 
0078       client = new TestClient(s);
0079       client.send(new Hello(secret, "1.4.0"));
0080       assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS));
0081 
0082       // Make sure the server matched the client to the handle.
0083       assertNotNull(handle.getConnection());
0084 
0085       client.send(new SetAppId("app-id"));
0086       assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS));
0087       assertEquals("app-id", handle.getAppId());
0088 
0089       client.send(new SetState(SparkAppHandle.State.RUNNING));
0090       assertTrue(semaphore.tryAcquire(1, TimeUnit.SECONDS));
0091       assertEquals(SparkAppHandle.State.RUNNING, handle.getState());
0092 
0093       handle.stop();
0094       Message stopMsg = client.inbound.poll(30, TimeUnit.SECONDS);
0095       assertTrue(stopMsg instanceof Stop);
0096     } finally {
0097       close(client);
0098       handle.kill();
0099       client.clientThread.join();
0100     }
0101   }
0102 
0103   @Test
0104   public void testTimeout() throws Exception {
0105     LauncherServer server = LauncherServer.getOrCreateServer();
0106     ChildProcAppHandle handle = new ChildProcAppHandle(server);
0107     String secret = server.registerHandle(handle);
0108 
0109     TestClient client = null;
0110     try {
0111       // LauncherServer will immediately close the server-side socket when the timeout is set
0112       // to 0.
0113       SparkLauncher.setConfig(SparkLauncher.CHILD_CONNECTION_TIMEOUT, "0");
0114 
0115       Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort());
0116       client = new TestClient(s);
0117       waitForError(client, secret);
0118     } finally {
0119       SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT);
0120       handle.kill();
0121       close(client);
0122     }
0123   }
0124 
0125   @Test
0126   public void testSparkSubmitVmShutsDown() throws Exception {
0127     LauncherServer server = LauncherServer.getOrCreateServer();
0128     ChildProcAppHandle handle = new ChildProcAppHandle(server);
0129     String secret = server.registerHandle(handle);
0130 
0131     TestClient client = null;
0132     final Semaphore semaphore = new Semaphore(0);
0133     try {
0134       Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort());
0135       handle.addListener(new SparkAppHandle.Listener() {
0136         public void stateChanged(SparkAppHandle handle) {
0137           semaphore.release();
0138         }
0139         public void infoChanged(SparkAppHandle handle) {
0140           semaphore.release();
0141         }
0142       });
0143       client = new TestClient(s);
0144       client.send(new Hello(secret, "1.4.0"));
0145       assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS));
0146       // Make sure the server matched the client to the handle.
0147       assertNotNull(handle.getConnection());
0148       client.close();
0149       handle.dispose();
0150       assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS));
0151       assertEquals(SparkAppHandle.State.LOST, handle.getState());
0152     } finally {
0153       handle.kill();
0154       close(client);
0155       client.clientThread.join();
0156     }
0157   }
0158 
0159   @Test
0160   public void testStreamFiltering() throws Exception {
0161     LauncherServer server = LauncherServer.getOrCreateServer();
0162     ChildProcAppHandle handle = new ChildProcAppHandle(server);
0163     String secret = server.registerHandle(handle);
0164 
0165     TestClient client = null;
0166     try {
0167       Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort());
0168 
0169       client = new TestClient(s);
0170 
0171       try {
0172         client.send(new EvilPayload());
0173       } catch (SocketException se) {
0174         // SPARK-21522: this can happen if the server closes the socket before the full message has
0175         // been written, so it's expected. It may cause false positives though (socket errors
0176         // happening for other reasons).
0177       }
0178 
0179       waitForError(client, secret);
0180       assertEquals(0, EvilPayload.EVIL_BIT);
0181     } finally {
0182       handle.kill();
0183       close(client);
0184       client.clientThread.join();
0185     }
0186   }
0187 
0188   @Test
0189   public void testAppHandleDisconnect() throws Exception {
0190     LauncherServer server = LauncherServer.getOrCreateServer();
0191     ChildProcAppHandle handle = new ChildProcAppHandle(server);
0192     String secret = server.registerHandle(handle);
0193 
0194     TestClient client = null;
0195     try {
0196       Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort());
0197       client = new TestClient(s);
0198       client.send(new Hello(secret, "1.4.0"));
0199       client.send(new SetAppId("someId"));
0200 
0201       // Wait until we know the server has received the messages and matched the handle to the
0202       // connection before disconnecting.
0203       eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> {
0204         assertEquals("someId", handle.getAppId());
0205       });
0206 
0207       handle.disconnect();
0208       waitForError(client, secret);
0209     } finally {
0210       handle.kill();
0211       close(client);
0212       client.clientThread.join();
0213     }
0214   }
0215 
0216   private void close(Closeable c) {
0217     if (c != null) {
0218       try {
0219         c.close();
0220       } catch (Exception e) {
0221         // no-op.
0222       }
0223     }
0224   }
0225 
0226   /**
0227    * Try a few times to get a client-side error, since the client-side socket may not reflect the
0228    * server-side close immediately.
0229    */
0230   private void waitForError(TestClient client, String secret) throws Exception {
0231     final AtomicBoolean helloSent = new AtomicBoolean();
0232     eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> {
0233       try {
0234         if (!helloSent.get()) {
0235           client.send(new Hello(secret, "1.4.0"));
0236           helloSent.set(true);
0237         } else {
0238           client.send(new SetAppId("appId"));
0239         }
0240         fail("Expected error but message went through.");
0241       } catch (IllegalStateException | IOException e) {
0242         // Expected.
0243       }
0244     });
0245   }
0246 
0247   private static class TestClient extends LauncherConnection {
0248 
0249     final BlockingQueue<Message> inbound;
0250     final Thread clientThread;
0251 
0252     TestClient(Socket s) throws IOException {
0253       super(s);
0254       this.inbound = new LinkedBlockingQueue<>();
0255       this.clientThread = new Thread(this);
0256       clientThread.setName("TestClient");
0257       clientThread.setDaemon(true);
0258       clientThread.start();
0259     }
0260 
0261     @Override
0262     protected void handle(Message msg) throws IOException {
0263       inbound.offer(msg);
0264     }
0265 
0266   }
0267 
0268   private static class EvilPayload extends LauncherProtocol.Message {
0269 
0270     static int EVIL_BIT = 0;
0271 
0272     // This field should cause the launcher server to throw an error and not deserialize the
0273     // message.
0274     private List<String> notAllowedField = Arrays.asList("disallowed");
0275 
0276     private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
0277       stream.defaultReadObject();
0278       EVIL_BIT = 1;
0279     }
0280 
0281   }
0282 
0283 }