0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.launcher;
0019
0020 import java.io.Closeable;
0021 import java.io.IOException;
0022 import java.net.InetAddress;
0023 import java.net.InetSocketAddress;
0024 import java.net.ServerSocket;
0025 import java.net.Socket;
0026 import java.security.SecureRandom;
0027 import java.util.ArrayList;
0028 import java.util.List;
0029 import java.util.Map;
0030 import java.util.Timer;
0031 import java.util.TimerTask;
0032 import java.util.concurrent.ConcurrentHashMap;
0033 import java.util.concurrent.ConcurrentMap;
0034 import java.util.concurrent.ThreadFactory;
0035 import java.util.concurrent.atomic.AtomicLong;
0036 import java.util.logging.Level;
0037 import java.util.logging.Logger;
0038
0039 import static org.apache.spark.launcher.LauncherProtocol.*;
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067
0068
0069
0070
0071
0072
0073
0074
0075
0076
0077
0078
0079
0080
0081 class LauncherServer implements Closeable {
0082
0083 private static final Logger LOG = Logger.getLogger(LauncherServer.class.getName());
0084 private static final String THREAD_NAME_FMT = "LauncherServer-%d";
0085 private static final long DEFAULT_CONNECT_TIMEOUT = 10000L;
0086
0087
0088 private static final SecureRandom RND = new SecureRandom();
0089
0090 private static volatile LauncherServer serverInstance;
0091
0092 static synchronized LauncherServer getOrCreateServer() throws IOException {
0093 LauncherServer server;
0094 do {
0095 server = serverInstance != null ? serverInstance : new LauncherServer();
0096 } while (!server.running);
0097
0098 server.ref();
0099 serverInstance = server;
0100 return server;
0101 }
0102
0103
0104 static synchronized LauncherServer getServer() {
0105 return serverInstance;
0106 }
0107
0108 private final AtomicLong refCount;
0109 private final AtomicLong threadIds;
0110 private final ConcurrentMap<String, AbstractAppHandle> secretToPendingApps;
0111 private final List<ServerConnection> clients;
0112 private final ServerSocket server;
0113 private final Thread serverThread;
0114 private final ThreadFactory factory;
0115 private final Timer timeoutTimer;
0116
0117 private volatile boolean running;
0118
0119 private LauncherServer() throws IOException {
0120 this.refCount = new AtomicLong(0);
0121
0122 ServerSocket server = new ServerSocket();
0123 try {
0124 server.setReuseAddress(true);
0125 server.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
0126
0127 this.clients = new ArrayList<>();
0128 this.threadIds = new AtomicLong();
0129 this.factory = new NamedThreadFactory(THREAD_NAME_FMT);
0130 this.secretToPendingApps = new ConcurrentHashMap<>();
0131 this.timeoutTimer = new Timer("LauncherServer-TimeoutTimer", true);
0132 this.server = server;
0133 this.running = true;
0134
0135 this.serverThread = factory.newThread(this::acceptConnections);
0136 serverThread.start();
0137 } catch (IOException ioe) {
0138 close();
0139 throw ioe;
0140 } catch (Exception e) {
0141 close();
0142 throw new IOException(e);
0143 }
0144 }
0145
0146
0147
0148
0149
0150 synchronized String registerHandle(AbstractAppHandle handle) {
0151 String secret = createSecret();
0152 secretToPendingApps.put(secret, handle);
0153 return secret;
0154 }
0155
0156 @Override
0157 public void close() throws IOException {
0158 synchronized (this) {
0159 if (!running) {
0160 return;
0161 }
0162 running = false;
0163 }
0164
0165 synchronized(LauncherServer.class) {
0166 serverInstance = null;
0167 }
0168
0169 timeoutTimer.cancel();
0170 server.close();
0171 synchronized (clients) {
0172 List<ServerConnection> copy = new ArrayList<>(clients);
0173 clients.clear();
0174 for (ServerConnection client : copy) {
0175 client.close();
0176 }
0177 }
0178
0179 if (serverThread != null) {
0180 try {
0181 serverThread.join();
0182 } catch (InterruptedException ie) {
0183
0184 }
0185 }
0186 }
0187
0188 void ref() {
0189 refCount.incrementAndGet();
0190 }
0191
0192 void unref() {
0193 synchronized(LauncherServer.class) {
0194 if (refCount.decrementAndGet() == 0) {
0195 try {
0196 close();
0197 } catch (IOException ioe) {
0198
0199 }
0200 }
0201 }
0202 }
0203
0204 int getPort() {
0205 return server.getLocalPort();
0206 }
0207
0208
0209
0210
0211
0212 void unregister(AbstractAppHandle handle) {
0213 for (Map.Entry<String, AbstractAppHandle> e : secretToPendingApps.entrySet()) {
0214 if (e.getValue().equals(handle)) {
0215 String secret = e.getKey();
0216 secretToPendingApps.remove(secret);
0217 break;
0218 }
0219 }
0220
0221 unref();
0222 }
0223
0224 private void acceptConnections() {
0225 try {
0226 while (running) {
0227 final Socket client = server.accept();
0228 TimerTask timeout = new TimerTask() {
0229 @Override
0230 public void run() {
0231 LOG.warning("Timed out waiting for hello message from client.");
0232 try {
0233 client.close();
0234 } catch (IOException ioe) {
0235
0236 }
0237 }
0238 };
0239 ServerConnection clientConnection = new ServerConnection(client, timeout);
0240 Thread clientThread = factory.newThread(clientConnection);
0241 clientConnection.setConnectionThread(clientThread);
0242 synchronized (clients) {
0243 clients.add(clientConnection);
0244 }
0245
0246 long timeoutMs = getConnectionTimeout();
0247
0248
0249 if (timeoutMs > 0) {
0250 timeoutTimer.schedule(timeout, timeoutMs);
0251 } else {
0252 timeout.run();
0253 }
0254
0255 clientThread.start();
0256 }
0257 } catch (IOException ioe) {
0258 if (running) {
0259 LOG.log(Level.SEVERE, "Error in accept loop.", ioe);
0260 }
0261 }
0262 }
0263
0264 private long getConnectionTimeout() {
0265 String value = SparkLauncher.launcherConfig.get(SparkLauncher.CHILD_CONNECTION_TIMEOUT);
0266 return (value != null) ? Long.parseLong(value) : DEFAULT_CONNECT_TIMEOUT;
0267 }
0268
0269 private String createSecret() {
0270 while (true) {
0271 byte[] secret = new byte[128];
0272 RND.nextBytes(secret);
0273
0274 StringBuilder sb = new StringBuilder();
0275 for (byte b : secret) {
0276 int ival = b >= 0 ? b : Byte.MAX_VALUE - b;
0277 if (ival < 0x10) {
0278 sb.append("0");
0279 }
0280 sb.append(Integer.toHexString(ival));
0281 }
0282
0283 String secretStr = sb.toString();
0284 if (!secretToPendingApps.containsKey(secretStr)) {
0285 return secretStr;
0286 }
0287 }
0288 }
0289
0290 class ServerConnection extends LauncherConnection {
0291
0292 private TimerTask timeout;
0293 private volatile Thread connectionThread;
0294 private volatile AbstractAppHandle handle;
0295
0296 ServerConnection(Socket socket, TimerTask timeout) throws IOException {
0297 super(socket);
0298 this.timeout = timeout;
0299 }
0300
0301 void setConnectionThread(Thread t) {
0302 this.connectionThread = t;
0303 }
0304
0305 @Override
0306 protected void handle(Message msg) throws IOException {
0307 try {
0308 if (msg instanceof Hello) {
0309 timeout.cancel();
0310 timeout = null;
0311 Hello hello = (Hello) msg;
0312 AbstractAppHandle handle = secretToPendingApps.remove(hello.secret);
0313 if (handle != null) {
0314 handle.setConnection(this);
0315 handle.setState(SparkAppHandle.State.CONNECTED);
0316 this.handle = handle;
0317 } else {
0318 throw new IllegalArgumentException("Received Hello for unknown client.");
0319 }
0320 } else {
0321 String msgClassName = msg != null ? msg.getClass().getName() : "no message";
0322 if (handle == null) {
0323 throw new IllegalArgumentException("Expected hello, got: " + msgClassName);
0324 }
0325 if (msg instanceof SetAppId) {
0326 SetAppId set = (SetAppId) msg;
0327 handle.setAppId(set.appId);
0328 } else if (msg instanceof SetState) {
0329 handle.setState(((SetState)msg).state);
0330 } else {
0331 throw new IllegalArgumentException("Invalid message: " + msgClassName);
0332 }
0333 }
0334 } catch (Exception e) {
0335 LOG.log(Level.INFO, "Error handling message from client.", e);
0336 if (timeout != null) {
0337 timeout.cancel();
0338 }
0339 close();
0340 if (handle != null) {
0341 handle.dispose();
0342 }
0343 } finally {
0344 timeoutTimer.purge();
0345 }
0346 }
0347
0348 @Override
0349 public void close() throws IOException {
0350 if (!isOpen()) {
0351 return;
0352 }
0353
0354 synchronized (clients) {
0355 clients.remove(this);
0356 }
0357
0358 super.close();
0359 }
0360
0361
0362
0363
0364
0365
0366
0367
0368
0369
0370
0371
0372
0373 public void waitForClose() throws IOException {
0374 Thread connThread = this.connectionThread;
0375 if (Thread.currentThread() != connThread) {
0376 try {
0377 connThread.join(getConnectionTimeout());
0378 } catch (InterruptedException ie) {
0379
0380 }
0381
0382 if (connThread.isAlive()) {
0383 LOG.log(Level.WARNING, "Timed out waiting for child connection to close.");
0384 close();
0385 }
0386 }
0387 }
0388
0389 }
0390
0391 }