0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.client;
0019
0020 import java.io.Closeable;
0021 import java.io.IOException;
0022 import java.net.InetSocketAddress;
0023 import java.net.SocketAddress;
0024 import java.util.List;
0025 import java.util.Random;
0026 import java.util.concurrent.ConcurrentHashMap;
0027 import java.util.concurrent.atomic.AtomicReference;
0028
0029 import com.codahale.metrics.MetricSet;
0030 import com.google.common.base.Preconditions;
0031 import com.google.common.base.Throwables;
0032 import com.google.common.collect.Lists;
0033 import io.netty.bootstrap.Bootstrap;
0034 import io.netty.buffer.PooledByteBufAllocator;
0035 import io.netty.channel.Channel;
0036 import io.netty.channel.ChannelFuture;
0037 import io.netty.channel.ChannelInitializer;
0038 import io.netty.channel.ChannelOption;
0039 import io.netty.channel.EventLoopGroup;
0040 import io.netty.channel.socket.SocketChannel;
0041 import org.slf4j.Logger;
0042 import org.slf4j.LoggerFactory;
0043
0044 import org.apache.spark.network.TransportContext;
0045 import org.apache.spark.network.server.TransportChannelHandler;
0046 import org.apache.spark.network.util.*;
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058 public class TransportClientFactory implements Closeable {
0059
0060
0061 private static class ClientPool {
0062 TransportClient[] clients;
0063 Object[] locks;
0064
0065 ClientPool(int size) {
0066 clients = new TransportClient[size];
0067 locks = new Object[size];
0068 for (int i = 0; i < size; i++) {
0069 locks[i] = new Object();
0070 }
0071 }
0072 }
0073
0074 private static final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
0075
0076 private final TransportContext context;
0077 private final TransportConf conf;
0078 private final List<TransportClientBootstrap> clientBootstraps;
0079 private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
0080
0081
0082 private final Random rand;
0083 private final int numConnectionsPerPeer;
0084
0085 private final Class<? extends Channel> socketChannelClass;
0086 private EventLoopGroup workerGroup;
0087 private final PooledByteBufAllocator pooledAllocator;
0088 private final NettyMemoryMetrics metrics;
0089
0090 public TransportClientFactory(
0091 TransportContext context,
0092 List<TransportClientBootstrap> clientBootstraps) {
0093 this.context = Preconditions.checkNotNull(context);
0094 this.conf = context.getConf();
0095 this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
0096 this.connectionPool = new ConcurrentHashMap<>();
0097 this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
0098 this.rand = new Random();
0099
0100 IOMode ioMode = IOMode.valueOf(conf.ioMode());
0101 this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
0102 this.workerGroup = NettyUtils.createEventLoop(
0103 ioMode,
0104 conf.clientThreads(),
0105 conf.getModuleName() + "-client");
0106 if (conf.sharedByteBufAllocators()) {
0107 this.pooledAllocator = NettyUtils.getSharedPooledByteBufAllocator(
0108 conf.preferDirectBufsForSharedByteBufAllocators(), false );
0109 } else {
0110 this.pooledAllocator = NettyUtils.createPooledByteBufAllocator(
0111 conf.preferDirectBufs(), false , conf.clientThreads());
0112 }
0113 this.metrics = new NettyMemoryMetrics(
0114 this.pooledAllocator, conf.getModuleName() + "-client", conf);
0115 }
0116
0117 public MetricSet getAllMetrics() {
0118 return metrics;
0119 }
0120
0121
0122
0123
0124
0125
0126
0127
0128
0129
0130
0131
0132
0133
0134
0135 public TransportClient createClient(String remoteHost, int remotePort)
0136 throws IOException, InterruptedException {
0137
0138
0139
0140 final InetSocketAddress unresolvedAddress =
0141 InetSocketAddress.createUnresolved(remoteHost, remotePort);
0142
0143
0144 ClientPool clientPool = connectionPool.get(unresolvedAddress);
0145 if (clientPool == null) {
0146 connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer));
0147 clientPool = connectionPool.get(unresolvedAddress);
0148 }
0149
0150 int clientIndex = rand.nextInt(numConnectionsPerPeer);
0151 TransportClient cachedClient = clientPool.clients[clientIndex];
0152
0153 if (cachedClient != null && cachedClient.isActive()) {
0154
0155
0156
0157 TransportChannelHandler handler = cachedClient.getChannel().pipeline()
0158 .get(TransportChannelHandler.class);
0159 synchronized (handler) {
0160 handler.getResponseHandler().updateTimeOfLastRequest();
0161 }
0162
0163 if (cachedClient.isActive()) {
0164 logger.trace("Returning cached connection to {}: {}",
0165 cachedClient.getSocketAddress(), cachedClient);
0166 return cachedClient;
0167 }
0168 }
0169
0170
0171
0172 final long preResolveHost = System.nanoTime();
0173 final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort);
0174 final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000;
0175 final String resolvMsg = resolvedAddress.isUnresolved() ? "failed" : "succeed";
0176 if (hostResolveTimeMs > 2000) {
0177 logger.warn("DNS resolution {} for {} took {} ms",
0178 resolvMsg, resolvedAddress, hostResolveTimeMs);
0179 } else {
0180 logger.trace("DNS resolution {} for {} took {} ms",
0181 resolvMsg, resolvedAddress, hostResolveTimeMs);
0182 }
0183
0184 synchronized (clientPool.locks[clientIndex]) {
0185 cachedClient = clientPool.clients[clientIndex];
0186
0187 if (cachedClient != null) {
0188 if (cachedClient.isActive()) {
0189 logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient);
0190 return cachedClient;
0191 } else {
0192 logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress);
0193 }
0194 }
0195 clientPool.clients[clientIndex] = createClient(resolvedAddress);
0196 return clientPool.clients[clientIndex];
0197 }
0198 }
0199
0200
0201
0202
0203
0204
0205
0206 public TransportClient createUnmanagedClient(String remoteHost, int remotePort)
0207 throws IOException, InterruptedException {
0208 final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
0209 return createClient(address);
0210 }
0211
0212
0213 private TransportClient createClient(InetSocketAddress address)
0214 throws IOException, InterruptedException {
0215 logger.debug("Creating new connection to {}", address);
0216
0217 Bootstrap bootstrap = new Bootstrap();
0218 bootstrap.group(workerGroup)
0219 .channel(socketChannelClass)
0220
0221 .option(ChannelOption.TCP_NODELAY, true)
0222 .option(ChannelOption.SO_KEEPALIVE, true)
0223 .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
0224 .option(ChannelOption.ALLOCATOR, pooledAllocator);
0225
0226 if (conf.receiveBuf() > 0) {
0227 bootstrap.option(ChannelOption.SO_RCVBUF, conf.receiveBuf());
0228 }
0229
0230 if (conf.sendBuf() > 0) {
0231 bootstrap.option(ChannelOption.SO_SNDBUF, conf.sendBuf());
0232 }
0233
0234 final AtomicReference<TransportClient> clientRef = new AtomicReference<>();
0235 final AtomicReference<Channel> channelRef = new AtomicReference<>();
0236
0237 bootstrap.handler(new ChannelInitializer<SocketChannel>() {
0238 @Override
0239 public void initChannel(SocketChannel ch) {
0240 TransportChannelHandler clientHandler = context.initializePipeline(ch);
0241 clientRef.set(clientHandler.getClient());
0242 channelRef.set(ch);
0243 }
0244 });
0245
0246
0247 long preConnect = System.nanoTime();
0248 ChannelFuture cf = bootstrap.connect(address);
0249 if (!cf.await(conf.connectionTimeoutMs())) {
0250 throw new IOException(
0251 String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
0252 } else if (cf.cause() != null) {
0253 throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
0254 }
0255
0256 TransportClient client = clientRef.get();
0257 Channel channel = channelRef.get();
0258 assert client != null : "Channel future completed successfully with null client";
0259
0260
0261 long preBootstrap = System.nanoTime();
0262 logger.debug("Connection to {} successful, running bootstraps...", address);
0263 try {
0264 for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
0265 clientBootstrap.doBootstrap(client, channel);
0266 }
0267 } catch (Exception e) {
0268 long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
0269 logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
0270 client.close();
0271 throw Throwables.propagate(e);
0272 }
0273 long postBootstrap = System.nanoTime();
0274
0275 logger.info("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
0276 address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);
0277
0278 return client;
0279 }
0280
0281
0282 @Override
0283 public void close() {
0284
0285 for (ClientPool clientPool : connectionPool.values()) {
0286 for (int i = 0; i < clientPool.clients.length; i++) {
0287 TransportClient client = clientPool.clients[i];
0288 if (client != null) {
0289 clientPool.clients[i] = null;
0290 JavaUtils.closeQuietly(client);
0291 }
0292 }
0293 }
0294 connectionPool.clear();
0295
0296 if (workerGroup != null && !workerGroup.isShuttingDown()) {
0297 workerGroup.shutdownGracefully();
0298 }
0299 }
0300 }