0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.server;
0019
0020 import java.util.Iterator;
0021 import java.util.Map;
0022 import java.util.Random;
0023 import java.util.concurrent.ConcurrentHashMap;
0024 import java.util.concurrent.atomic.AtomicLong;
0025
0026 import com.google.common.annotations.VisibleForTesting;
0027 import com.google.common.base.Preconditions;
0028 import io.netty.channel.Channel;
0029 import org.apache.commons.lang3.tuple.ImmutablePair;
0030 import org.apache.commons.lang3.tuple.Pair;
0031 import org.slf4j.Logger;
0032 import org.slf4j.LoggerFactory;
0033
0034 import org.apache.spark.network.buffer.ManagedBuffer;
0035 import org.apache.spark.network.client.TransportClient;
0036
0037
0038
0039
0040
0041 public class OneForOneStreamManager extends StreamManager {
0042 private static final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class);
0043
0044 private final AtomicLong nextStreamId;
0045 private final ConcurrentHashMap<Long, StreamState> streams;
0046
0047
0048 private static class StreamState {
0049 final String appId;
0050 final Iterator<ManagedBuffer> buffers;
0051
0052
0053 final Channel associatedChannel;
0054
0055
0056
0057 int curChunk = 0;
0058
0059
0060 final AtomicLong chunksBeingTransferred = new AtomicLong(0L);
0061
0062 StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
0063 this.appId = appId;
0064 this.buffers = Preconditions.checkNotNull(buffers);
0065 this.associatedChannel = channel;
0066 }
0067 }
0068
0069 public OneForOneStreamManager() {
0070
0071
0072 nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
0073 streams = new ConcurrentHashMap<>();
0074 }
0075
0076 @Override
0077 public ManagedBuffer getChunk(long streamId, int chunkIndex) {
0078 StreamState state = streams.get(streamId);
0079 if (chunkIndex != state.curChunk) {
0080 throw new IllegalStateException(String.format(
0081 "Received out-of-order chunk index %s (expected %s)", chunkIndex, state.curChunk));
0082 } else if (!state.buffers.hasNext()) {
0083 throw new IllegalStateException(String.format(
0084 "Requested chunk index beyond end %s", chunkIndex));
0085 }
0086 state.curChunk += 1;
0087 ManagedBuffer nextChunk = state.buffers.next();
0088
0089 if (!state.buffers.hasNext()) {
0090 logger.trace("Removing stream id {}", streamId);
0091 streams.remove(streamId);
0092 }
0093
0094 return nextChunk;
0095 }
0096
0097 @Override
0098 public ManagedBuffer openStream(String streamChunkId) {
0099 Pair<Long, Integer> streamChunkIdPair = parseStreamChunkId(streamChunkId);
0100 return getChunk(streamChunkIdPair.getLeft(), streamChunkIdPair.getRight());
0101 }
0102
0103 public static String genStreamChunkId(long streamId, int chunkId) {
0104 return String.format("%d_%d", streamId, chunkId);
0105 }
0106
0107
0108
0109 public static Pair<Long, Integer> parseStreamChunkId(String streamChunkId) {
0110 String[] array = streamChunkId.split("_");
0111 assert array.length == 2:
0112 "Stream id and chunk index should be specified.";
0113 long streamId = Long.valueOf(array[0]);
0114 int chunkIndex = Integer.valueOf(array[1]);
0115 return ImmutablePair.of(streamId, chunkIndex);
0116 }
0117
0118 @Override
0119 public void connectionTerminated(Channel channel) {
0120 RuntimeException failedToReleaseBufferException = null;
0121
0122
0123 for (Map.Entry<Long, StreamState> entry: streams.entrySet()) {
0124 StreamState state = entry.getValue();
0125 if (state.associatedChannel == channel) {
0126 streams.remove(entry.getKey());
0127
0128 try {
0129
0130 while (state.buffers.hasNext()) {
0131 ManagedBuffer buffer = state.buffers.next();
0132 if (buffer != null) {
0133 buffer.release();
0134 }
0135 }
0136 } catch (RuntimeException e) {
0137 if (failedToReleaseBufferException == null) {
0138 failedToReleaseBufferException = e;
0139 } else {
0140 logger.error("Exception trying to release remaining StreamState buffers", e);
0141 }
0142 }
0143 }
0144 }
0145
0146 if (failedToReleaseBufferException != null) {
0147 throw failedToReleaseBufferException;
0148 }
0149 }
0150
0151 @Override
0152 public void checkAuthorization(TransportClient client, long streamId) {
0153 if (client.getClientId() != null) {
0154 StreamState state = streams.get(streamId);
0155 Preconditions.checkArgument(state != null, "Unknown stream ID.");
0156 if (!client.getClientId().equals(state.appId)) {
0157 throw new SecurityException(String.format(
0158 "Client %s not authorized to read stream %d (app %s).",
0159 client.getClientId(),
0160 streamId,
0161 state.appId));
0162 }
0163 }
0164 }
0165
0166 @Override
0167 public void chunkBeingSent(long streamId) {
0168 StreamState streamState = streams.get(streamId);
0169 if (streamState != null) {
0170 streamState.chunksBeingTransferred.incrementAndGet();
0171 }
0172
0173 }
0174
0175 @Override
0176 public void streamBeingSent(String streamId) {
0177 chunkBeingSent(parseStreamChunkId(streamId).getLeft());
0178 }
0179
0180 @Override
0181 public void chunkSent(long streamId) {
0182 StreamState streamState = streams.get(streamId);
0183 if (streamState != null) {
0184 streamState.chunksBeingTransferred.decrementAndGet();
0185 }
0186 }
0187
0188 @Override
0189 public void streamSent(String streamId) {
0190 chunkSent(OneForOneStreamManager.parseStreamChunkId(streamId).getLeft());
0191 }
0192
0193 @Override
0194 public long chunksBeingTransferred() {
0195 long sum = 0L;
0196 for (StreamState streamState: streams.values()) {
0197 sum += streamState.chunksBeingTransferred.get();
0198 }
0199 return sum;
0200 }
0201
0202
0203
0204
0205
0206
0207
0208
0209
0210
0211
0212
0213
0214
0215 public long registerStream(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
0216 long myStreamId = nextStreamId.getAndIncrement();
0217 streams.put(myStreamId, new StreamState(appId, buffers, channel));
0218 return myStreamId;
0219 }
0220
0221 @VisibleForTesting
0222 public int numStreamStates() {
0223 return streams.size();
0224 }
0225 }