Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network.client;
0019 
0020 import java.io.Closeable;
0021 import java.io.IOException;
0022 import java.net.InetSocketAddress;
0023 import java.net.SocketAddress;
0024 import java.util.List;
0025 import java.util.Random;
0026 import java.util.concurrent.ConcurrentHashMap;
0027 import java.util.concurrent.atomic.AtomicReference;
0028 
0029 import com.codahale.metrics.MetricSet;
0030 import com.google.common.base.Preconditions;
0031 import com.google.common.base.Throwables;
0032 import com.google.common.collect.Lists;
0033 import io.netty.bootstrap.Bootstrap;
0034 import io.netty.buffer.PooledByteBufAllocator;
0035 import io.netty.channel.Channel;
0036 import io.netty.channel.ChannelFuture;
0037 import io.netty.channel.ChannelInitializer;
0038 import io.netty.channel.ChannelOption;
0039 import io.netty.channel.EventLoopGroup;
0040 import io.netty.channel.socket.SocketChannel;
0041 import org.slf4j.Logger;
0042 import org.slf4j.LoggerFactory;
0043 
0044 import org.apache.spark.network.TransportContext;
0045 import org.apache.spark.network.server.TransportChannelHandler;
0046 import org.apache.spark.network.util.*;
0047 
0048 /**
0049  * Factory for creating {@link TransportClient}s by using createClient.
0050  *
0051  * The factory maintains a connection pool to other hosts and should return the same
0052  * TransportClient for the same remote host. It also shares a single worker thread pool for
0053  * all TransportClients.
0054  *
0055  * TransportClients will be reused whenever possible. Prior to completing the creation of a new
0056  * TransportClient, all given {@link TransportClientBootstrap}s will be run.
0057  */
0058 public class TransportClientFactory implements Closeable {
0059 
0060   /** A simple data structure to track the pool of clients between two peer nodes. */
0061   private static class ClientPool {
0062     TransportClient[] clients;
0063     Object[] locks;
0064 
0065     ClientPool(int size) {
0066       clients = new TransportClient[size];
0067       locks = new Object[size];
0068       for (int i = 0; i < size; i++) {
0069         locks[i] = new Object();
0070       }
0071     }
0072   }
0073 
0074   private static final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
0075 
0076   private final TransportContext context;
0077   private final TransportConf conf;
0078   private final List<TransportClientBootstrap> clientBootstraps;
0079   private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
0080 
0081   /** Random number generator for picking connections between peers. */
0082   private final Random rand;
0083   private final int numConnectionsPerPeer;
0084 
0085   private final Class<? extends Channel> socketChannelClass;
0086   private EventLoopGroup workerGroup;
0087   private final PooledByteBufAllocator pooledAllocator;
0088   private final NettyMemoryMetrics metrics;
0089 
0090   public TransportClientFactory(
0091       TransportContext context,
0092       List<TransportClientBootstrap> clientBootstraps) {
0093     this.context = Preconditions.checkNotNull(context);
0094     this.conf = context.getConf();
0095     this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
0096     this.connectionPool = new ConcurrentHashMap<>();
0097     this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
0098     this.rand = new Random();
0099 
0100     IOMode ioMode = IOMode.valueOf(conf.ioMode());
0101     this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
0102     this.workerGroup = NettyUtils.createEventLoop(
0103         ioMode,
0104         conf.clientThreads(),
0105         conf.getModuleName() + "-client");
0106     if (conf.sharedByteBufAllocators()) {
0107       this.pooledAllocator = NettyUtils.getSharedPooledByteBufAllocator(
0108           conf.preferDirectBufsForSharedByteBufAllocators(), false /* allowCache */);
0109     } else {
0110       this.pooledAllocator = NettyUtils.createPooledByteBufAllocator(
0111           conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads());
0112     }
0113     this.metrics = new NettyMemoryMetrics(
0114       this.pooledAllocator, conf.getModuleName() + "-client", conf);
0115   }
0116 
0117   public MetricSet getAllMetrics() {
0118     return metrics;
0119   }
0120 
0121   /**
0122    * Create a {@link TransportClient} connecting to the given remote host / port.
0123    *
0124    * We maintains an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer)
0125    * and randomly picks one to use. If no client was previously created in the randomly selected
0126    * spot, this function creates a new client and places it there.
0127    *
0128    * Prior to the creation of a new TransportClient, we will execute all
0129    * {@link TransportClientBootstrap}s that are registered with this factory.
0130    *
0131    * This blocks until a connection is successfully established and fully bootstrapped.
0132    *
0133    * Concurrency: This method is safe to call from multiple threads.
0134    */
0135   public TransportClient createClient(String remoteHost, int remotePort)
0136       throws IOException, InterruptedException {
0137     // Get connection from the connection pool first.
0138     // If it is not found or not active, create a new one.
0139     // Use unresolved address here to avoid DNS resolution each time we creates a client.
0140     final InetSocketAddress unresolvedAddress =
0141       InetSocketAddress.createUnresolved(remoteHost, remotePort);
0142 
0143     // Create the ClientPool if we don't have it yet.
0144     ClientPool clientPool = connectionPool.get(unresolvedAddress);
0145     if (clientPool == null) {
0146       connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer));
0147       clientPool = connectionPool.get(unresolvedAddress);
0148     }
0149 
0150     int clientIndex = rand.nextInt(numConnectionsPerPeer);
0151     TransportClient cachedClient = clientPool.clients[clientIndex];
0152 
0153     if (cachedClient != null && cachedClient.isActive()) {
0154       // Make sure that the channel will not timeout by updating the last use time of the
0155       // handler. Then check that the client is still alive, in case it timed out before
0156       // this code was able to update things.
0157       TransportChannelHandler handler = cachedClient.getChannel().pipeline()
0158         .get(TransportChannelHandler.class);
0159       synchronized (handler) {
0160         handler.getResponseHandler().updateTimeOfLastRequest();
0161       }
0162 
0163       if (cachedClient.isActive()) {
0164         logger.trace("Returning cached connection to {}: {}",
0165           cachedClient.getSocketAddress(), cachedClient);
0166         return cachedClient;
0167       }
0168     }
0169 
0170     // If we reach here, we don't have an existing connection open. Let's create a new one.
0171     // Multiple threads might race here to create new connections. Keep only one of them active.
0172     final long preResolveHost = System.nanoTime();
0173     final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort);
0174     final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000;
0175     final String resolvMsg = resolvedAddress.isUnresolved() ? "failed" : "succeed";
0176     if (hostResolveTimeMs > 2000) {
0177       logger.warn("DNS resolution {} for {} took {} ms",
0178           resolvMsg, resolvedAddress, hostResolveTimeMs);
0179     } else {
0180       logger.trace("DNS resolution {} for {} took {} ms",
0181           resolvMsg, resolvedAddress, hostResolveTimeMs);
0182     }
0183 
0184     synchronized (clientPool.locks[clientIndex]) {
0185       cachedClient = clientPool.clients[clientIndex];
0186 
0187       if (cachedClient != null) {
0188         if (cachedClient.isActive()) {
0189           logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient);
0190           return cachedClient;
0191         } else {
0192           logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress);
0193         }
0194       }
0195       clientPool.clients[clientIndex] = createClient(resolvedAddress);
0196       return clientPool.clients[clientIndex];
0197     }
0198   }
0199 
0200   /**
0201    * Create a completely new {@link TransportClient} to the given remote host / port.
0202    * This connection is not pooled.
0203    *
0204    * As with {@link #createClient(String, int)}, this method is blocking.
0205    */
0206   public TransportClient createUnmanagedClient(String remoteHost, int remotePort)
0207       throws IOException, InterruptedException {
0208     final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
0209     return createClient(address);
0210   }
0211 
0212   /** Create a completely new {@link TransportClient} to the remote address. */
0213   private TransportClient createClient(InetSocketAddress address)
0214       throws IOException, InterruptedException {
0215     logger.debug("Creating new connection to {}", address);
0216 
0217     Bootstrap bootstrap = new Bootstrap();
0218     bootstrap.group(workerGroup)
0219       .channel(socketChannelClass)
0220       // Disable Nagle's Algorithm since we don't want packets to wait
0221       .option(ChannelOption.TCP_NODELAY, true)
0222       .option(ChannelOption.SO_KEEPALIVE, true)
0223       .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
0224       .option(ChannelOption.ALLOCATOR, pooledAllocator);
0225 
0226     if (conf.receiveBuf() > 0) {
0227       bootstrap.option(ChannelOption.SO_RCVBUF, conf.receiveBuf());
0228     }
0229 
0230     if (conf.sendBuf() > 0) {
0231       bootstrap.option(ChannelOption.SO_SNDBUF, conf.sendBuf());
0232     }
0233 
0234     final AtomicReference<TransportClient> clientRef = new AtomicReference<>();
0235     final AtomicReference<Channel> channelRef = new AtomicReference<>();
0236 
0237     bootstrap.handler(new ChannelInitializer<SocketChannel>() {
0238       @Override
0239       public void initChannel(SocketChannel ch) {
0240         TransportChannelHandler clientHandler = context.initializePipeline(ch);
0241         clientRef.set(clientHandler.getClient());
0242         channelRef.set(ch);
0243       }
0244     });
0245 
0246     // Connect to the remote server
0247     long preConnect = System.nanoTime();
0248     ChannelFuture cf = bootstrap.connect(address);
0249     if (!cf.await(conf.connectionTimeoutMs())) {
0250       throw new IOException(
0251         String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
0252     } else if (cf.cause() != null) {
0253       throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
0254     }
0255 
0256     TransportClient client = clientRef.get();
0257     Channel channel = channelRef.get();
0258     assert client != null : "Channel future completed successfully with null client";
0259 
0260     // Execute any client bootstraps synchronously before marking the Client as successful.
0261     long preBootstrap = System.nanoTime();
0262     logger.debug("Connection to {} successful, running bootstraps...", address);
0263     try {
0264       for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
0265         clientBootstrap.doBootstrap(client, channel);
0266       }
0267     } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
0268       long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
0269       logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
0270       client.close();
0271       throw Throwables.propagate(e);
0272     }
0273     long postBootstrap = System.nanoTime();
0274 
0275     logger.info("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
0276       address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);
0277 
0278     return client;
0279   }
0280 
0281   /** Close all connections in the connection pool, and shutdown the worker thread pool. */
0282   @Override
0283   public void close() {
0284     // Go through all clients and close them if they are active.
0285     for (ClientPool clientPool : connectionPool.values()) {
0286       for (int i = 0; i < clientPool.clients.length; i++) {
0287         TransportClient client = clientPool.clients[i];
0288         if (client != null) {
0289           clientPool.clients[i] = null;
0290           JavaUtils.closeQuietly(client);
0291         }
0292       }
0293     }
0294     connectionPool.clear();
0295 
0296     if (workerGroup != null && !workerGroup.isShuttingDown()) {
0297       workerGroup.shutdownGracefully();
0298     }
0299   }
0300 }