0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.network.shuffle;
0019
0020 import java.io.IOException;
0021 import java.nio.ByteBuffer;
0022 import java.util.ArrayList;
0023 import java.util.Arrays;
0024 import java.util.HashMap;
0025
0026 import com.google.common.primitives.Ints;
0027 import com.google.common.primitives.Longs;
0028 import org.slf4j.Logger;
0029 import org.slf4j.LoggerFactory;
0030
0031 import org.apache.spark.network.buffer.ManagedBuffer;
0032 import org.apache.spark.network.client.ChunkReceivedCallback;
0033 import org.apache.spark.network.client.RpcResponseCallback;
0034 import org.apache.spark.network.client.StreamCallback;
0035 import org.apache.spark.network.client.TransportClient;
0036 import org.apache.spark.network.server.OneForOneStreamManager;
0037 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
0038 import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
0039 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
0040 import org.apache.spark.network.shuffle.protocol.StreamHandle;
0041 import org.apache.spark.network.util.TransportConf;
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052 public class OneForOneBlockFetcher {
0053 private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);
0054
0055 private final TransportClient client;
0056 private final BlockTransferMessage message;
0057 private final String[] blockIds;
0058 private final BlockFetchingListener listener;
0059 private final ChunkReceivedCallback chunkCallback;
0060 private final TransportConf transportConf;
0061 private final DownloadFileManager downloadFileManager;
0062
0063 private StreamHandle streamHandle = null;
0064
0065 public OneForOneBlockFetcher(
0066 TransportClient client,
0067 String appId,
0068 String execId,
0069 String[] blockIds,
0070 BlockFetchingListener listener,
0071 TransportConf transportConf) {
0072 this(client, appId, execId, blockIds, listener, transportConf, null);
0073 }
0074
0075 public OneForOneBlockFetcher(
0076 TransportClient client,
0077 String appId,
0078 String execId,
0079 String[] blockIds,
0080 BlockFetchingListener listener,
0081 TransportConf transportConf,
0082 DownloadFileManager downloadFileManager) {
0083 this.client = client;
0084 this.blockIds = blockIds;
0085 this.listener = listener;
0086 this.chunkCallback = new ChunkCallback();
0087 this.transportConf = transportConf;
0088 this.downloadFileManager = downloadFileManager;
0089 if (blockIds.length == 0) {
0090 throw new IllegalArgumentException("Zero-sized blockIds array");
0091 }
0092 if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
0093 this.message = createFetchShuffleBlocksMsg(appId, execId, blockIds);
0094 } else {
0095 this.message = new OpenBlocks(appId, execId, blockIds);
0096 }
0097 }
0098
0099 private boolean isShuffleBlocks(String[] blockIds) {
0100 for (String blockId : blockIds) {
0101 if (!blockId.startsWith("shuffle_")) {
0102 return false;
0103 }
0104 }
0105 return true;
0106 }
0107
0108
0109
0110
0111
0112
0113 private FetchShuffleBlocks createFetchShuffleBlocksMsg(
0114 String appId, String execId, String[] blockIds) {
0115 String[] firstBlock = splitBlockId(blockIds[0]);
0116 int shuffleId = Integer.parseInt(firstBlock[1]);
0117 boolean batchFetchEnabled = firstBlock.length == 5;
0118
0119 HashMap<Long, ArrayList<Integer>> mapIdToReduceIds = new HashMap<>();
0120 for (String blockId : blockIds) {
0121 String[] blockIdParts = splitBlockId(blockId);
0122 if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
0123 throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
0124 ", got:" + blockId);
0125 }
0126 long mapId = Long.parseLong(blockIdParts[2]);
0127 if (!mapIdToReduceIds.containsKey(mapId)) {
0128 mapIdToReduceIds.put(mapId, new ArrayList<>());
0129 }
0130 mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[3]));
0131 if (batchFetchEnabled) {
0132
0133
0134
0135 assert(blockIdParts.length == 5);
0136 mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[4]));
0137 }
0138 }
0139 long[] mapIds = Longs.toArray(mapIdToReduceIds.keySet());
0140 int[][] reduceIdArr = new int[mapIds.length][];
0141 for (int i = 0; i < mapIds.length; i++) {
0142 reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i]));
0143 }
0144 return new FetchShuffleBlocks(
0145 appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
0146 }
0147
0148
0149 private String[] splitBlockId(String blockId) {
0150 String[] blockIdParts = blockId.split("_");
0151
0152
0153 if (blockIdParts.length < 4 || blockIdParts.length > 5 || !blockIdParts[0].equals("shuffle")) {
0154 throw new IllegalArgumentException(
0155 "Unexpected shuffle block id format: " + blockId);
0156 }
0157 return blockIdParts;
0158 }
0159
0160
0161 private class ChunkCallback implements ChunkReceivedCallback {
0162 @Override
0163 public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
0164
0165 listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
0166 }
0167
0168 @Override
0169 public void onFailure(int chunkIndex, Throwable e) {
0170
0171 String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
0172 failRemainingBlocks(remainingBlockIds, e);
0173 }
0174 }
0175
0176
0177
0178
0179
0180
0181 public void start() {
0182 client.sendRpc(message.toByteBuffer(), new RpcResponseCallback() {
0183 @Override
0184 public void onSuccess(ByteBuffer response) {
0185 try {
0186 streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
0187 logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);
0188
0189
0190
0191 for (int i = 0; i < streamHandle.numChunks; i++) {
0192 if (downloadFileManager != null) {
0193 client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
0194 new DownloadCallback(i));
0195 } else {
0196 client.fetchChunk(streamHandle.streamId, i, chunkCallback);
0197 }
0198 }
0199 } catch (Exception e) {
0200 logger.error("Failed while starting block fetches after success", e);
0201 failRemainingBlocks(blockIds, e);
0202 }
0203 }
0204
0205 @Override
0206 public void onFailure(Throwable e) {
0207 logger.error("Failed while starting block fetches", e);
0208 failRemainingBlocks(blockIds, e);
0209 }
0210 });
0211 }
0212
0213
0214 private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
0215 for (String blockId : failedBlockIds) {
0216 try {
0217 listener.onBlockFetchFailure(blockId, e);
0218 } catch (Exception e2) {
0219 logger.error("Error in block fetch failure callback", e2);
0220 }
0221 }
0222 }
0223
0224 private class DownloadCallback implements StreamCallback {
0225
0226 private DownloadFileWritableChannel channel = null;
0227 private DownloadFile targetFile = null;
0228 private int chunkIndex;
0229
0230 DownloadCallback(int chunkIndex) throws IOException {
0231 this.targetFile = downloadFileManager.createTempFile(transportConf);
0232 this.channel = targetFile.openForWriting();
0233 this.chunkIndex = chunkIndex;
0234 }
0235
0236 @Override
0237 public void onData(String streamId, ByteBuffer buf) throws IOException {
0238 while (buf.hasRemaining()) {
0239 channel.write(buf);
0240 }
0241 }
0242
0243 @Override
0244 public void onComplete(String streamId) throws IOException {
0245 listener.onBlockFetchSuccess(blockIds[chunkIndex], channel.closeAndRead());
0246 if (!downloadFileManager.registerTempFileToClean(targetFile)) {
0247 targetFile.delete();
0248 }
0249 }
0250
0251 @Override
0252 public void onFailure(String streamId, Throwable cause) throws IOException {
0253 channel.close();
0254
0255 String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
0256 failRemainingBlocks(remainingBlockIds, cause);
0257 targetFile.delete();
0258 }
0259 }
0260 }