/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.shuffle;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;

import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.TransportConf;

/**
 * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and
 * invokes the BlockFetchingListener appropriately. This class is agnostic to the actual RPC
 * handler, as long as there is a single "open blocks" message which returns a
 * {@link StreamHandle}, and Java serialization is used.
 *
 * Note that this typically corresponds to a
 * {@link org.apache.spark.network.server.OneForOneStreamManager} on the server side.
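 *
 * A minimal usage sketch (the client, listener and conf values are assumed to exist already;
 * the block ids are illustrative):
 * <pre>{@code
 * String[] blockIds = { "shuffle_0_5_1", "shuffle_0_5_2" };
 * new OneForOneBlockFetcher(client, "app-1", "exec-0", blockIds, listener, conf).start();
 * }</pre>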
 */
public class OneForOneBlockFetcher {
  private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);

  private final TransportClient client;
  private final BlockTransferMessage message;
  private final String[] blockIds;
  private final BlockFetchingListener listener;
  private final ChunkReceivedCallback chunkCallback;
  private final TransportConf transportConf;
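  // When non-null, fetched blocks are streamed to temporary files on disk instead of memory.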
  private final DownloadFileManager downloadFileManager;

  private StreamHandle streamHandle = null;

  public OneForOneBlockFetcher(
      TransportClient client,
      String appId,
      String execId,
      String[] blockIds,
      BlockFetchingListener listener,
      TransportConf transportConf) {
    this(client, appId, execId, blockIds, listener, transportConf, null);
  }

  public OneForOneBlockFetcher(
      TransportClient client,
      String appId,
      String execId,
      String[] blockIds,
      BlockFetchingListener listener,
      TransportConf transportConf,
      DownloadFileManager downloadFileManager) {
    this.client = client;
    this.blockIds = blockIds;
    this.listener = listener;
    this.chunkCallback = new ChunkCallback();
    this.transportConf = transportConf;
    this.downloadFileManager = downloadFileManager;
    if (blockIds.length == 0) {
      throw new IllegalArgumentException("Zero-sized blockIds array");
    }
    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
      this.message = createFetchShuffleBlocksMsg(appId, execId, blockIds);
    } else {
      this.message = new OpenBlocks(appId, execId, blockIds);
    }
  }

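  /** Returns true only if every given block id starts with the "shuffle_" prefix. */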
  private boolean isShuffleBlocks(String[] blockIds) {
    for (String blockId : blockIds) {
      if (!blockId.startsWith("shuffle_")) {
        return false;
      }
    }
    return true;
  }

  /**
   * Analyzes the passed-in blockIds and creates a FetchShuffleBlocks message.
   * The blockIds have been sorted by mapId and reduceId; they are produced in
   * org.apache.spark.MapOutputTracker.convertMapStatuses.
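   *
   * For example (the ids are illustrative), blockIds {"shuffle_0_1_0", "shuffle_0_1_1",
   * "shuffle_0_2_0"} are grouped as mapId 1 -> reduceIds {0, 1} and mapId 2 -> reduceIds {0}.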
   */
  private FetchShuffleBlocks createFetchShuffleBlocksMsg(
      String appId, String execId, String[] blockIds) {
    String[] firstBlock = splitBlockId(blockIds[0]);
    int shuffleId = Integer.parseInt(firstBlock[1]);
    boolean batchFetchEnabled = firstBlock.length == 5;

    HashMap<Long, ArrayList<Integer>> mapIdToReduceIds = new HashMap<>();
    for (String blockId : blockIds) {
      String[] blockIdParts = splitBlockId(blockId);
      if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
        throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
          ", got:" + blockId);
      }
      long mapId = Long.parseLong(blockIdParts[2]);
      if (!mapIdToReduceIds.containsKey(mapId)) {
        mapIdToReduceIds.put(mapId, new ArrayList<>());
      }
      mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[3]));
      if (batchFetchEnabled) {
        // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
        // FetchShuffleBlocks to store the start and end reduce id for range
        // [startReduceId, endReduceId).
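        // For example, a batch id ending in "_2_5" stores {2, 5} and covers partitions 2, 3, 4.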
        assert(blockIdParts.length == 5);
        mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[4]));
      }
    }
    long[] mapIds = Longs.toArray(mapIdToReduceIds.keySet());
    int[][] reduceIdArr = new int[mapIds.length][];
    for (int i = 0; i < mapIds.length; i++) {
      reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i]));
    }
    return new FetchShuffleBlocks(
      appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
  }

  /** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */
  private String[] splitBlockId(String blockId) {
    String[] blockIdParts = blockId.split("_");
    // For batch block id, the format contains shuffleId, mapId, begin reduceId, end reduceId.
    // For single block id, the format contains shuffleId, mapId, reduceId.
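    // For example: "shuffle_3_7_2" (single block) or "shuffle_3_7_2_5" (batch).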
    if (blockIdParts.length < 4 || blockIdParts.length > 5 || !blockIdParts[0].equals("shuffle")) {
      throw new IllegalArgumentException(
        "Unexpected shuffle block id format: " + blockId);
    }
    return blockIdParts;
  }

  /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
  private class ChunkCallback implements ChunkReceivedCallback {
    @Override
    public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
      // On receipt of a chunk, pass it upwards as a block.
      listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
    }

    @Override
    public void onFailure(int chunkIndex, Throwable e) {
      // On receipt of a failure, fail every block from chunkIndex onwards.
      String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
      failRemainingBlocks(remainingBlockIds, e);
    }
  }

  /**
   * Begins the fetching process, calling the listener with every block fetched.
   * The given message will be serialized with the Java serializer, and the RPC must return a
   * {@link StreamHandle}. We will send all fetch requests immediately, without throttling.
   */
  public void start() {
    client.sendRpc(message.toByteBuffer(), new RpcResponseCallback() {
      @Override
      public void onSuccess(ByteBuffer response) {
        try {
          streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
          logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);

          // Immediately request all chunks -- we expect that the total size of the request is
          // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
          for (int i = 0; i < streamHandle.numChunks; i++) {
            if (downloadFileManager != null) {
              client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
                new DownloadCallback(i));
            } else {
              client.fetchChunk(streamHandle.streamId, i, chunkCallback);
            }
          }
        } catch (Exception e) {
          logger.error("Failed while starting block fetches after success", e);
          failRemainingBlocks(blockIds, e);
        }
      }

      @Override
      public void onFailure(Throwable e) {
        logger.error("Failed while starting block fetches", e);
        failRemainingBlocks(blockIds, e);
      }
    });
  }

  /** Invokes the "onBlockFetchFailure" callback for every listed block id. */
  private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
    for (String blockId : failedBlockIds) {
      try {
        listener.onBlockFetchFailure(blockId, e);
      } catch (Exception e2) {
        logger.error("Error in block fetch failure callback", e2);
      }
    }
  }

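  /**
   * StreamCallback that writes the incoming block to a temporary file obtained from the
   * DownloadFileManager, then hands the file-backed buffer to the listener on completion.
   */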
  private class DownloadCallback implements StreamCallback {

    private DownloadFileWritableChannel channel = null;
    private DownloadFile targetFile = null;
    private int chunkIndex;

    DownloadCallback(int chunkIndex) throws IOException {
      this.targetFile = downloadFileManager.createTempFile(transportConf);
      this.channel = targetFile.openForWriting();
      this.chunkIndex = chunkIndex;
    }

    @Override
    public void onData(String streamId, ByteBuffer buf) throws IOException {
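      // A single channel.write() may not consume the whole buffer, so loop until it is drained.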
      while (buf.hasRemaining()) {
        channel.write(buf);
      }
    }

    @Override
    public void onComplete(String streamId) throws IOException {
      listener.onBlockFetchSuccess(blockIds[chunkIndex], channel.closeAndRead());
      if (!downloadFileManager.registerTempFileToClean(targetFile)) {
        targetFile.delete();
      }
    }

    @Override
    public void onFailure(String streamId, Throwable cause) throws IOException {
      channel.close();
      // On receipt of a failure, fail every block from chunkIndex onwards.
      String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
      failRemainingBlocks(remainingBlockIds, cause);
      targetFile.delete();
    }
  }
}