Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.network;
0019 
0020 import java.io.File;
0021 import java.io.RandomAccessFile;
0022 import java.nio.ByteBuffer;
0023 import java.util.Arrays;
0024 import java.util.Collections;
0025 import java.util.HashSet;
0026 import java.util.LinkedList;
0027 import java.util.List;
0028 import java.util.Random;
0029 import java.util.Set;
0030 import java.util.concurrent.Semaphore;
0031 import java.util.concurrent.TimeUnit;
0032 
0033 import com.google.common.collect.Sets;
0034 import com.google.common.io.Closeables;
0035 import org.junit.AfterClass;
0036 import org.junit.BeforeClass;
0037 import org.junit.Test;
0038 
0039 import static org.junit.Assert.*;
0040 
0041 import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
0042 import org.apache.spark.network.buffer.ManagedBuffer;
0043 import org.apache.spark.network.buffer.NioManagedBuffer;
0044 import org.apache.spark.network.client.ChunkReceivedCallback;
0045 import org.apache.spark.network.client.RpcResponseCallback;
0046 import org.apache.spark.network.client.TransportClient;
0047 import org.apache.spark.network.client.TransportClientFactory;
0048 import org.apache.spark.network.server.RpcHandler;
0049 import org.apache.spark.network.server.TransportServer;
0050 import org.apache.spark.network.server.StreamManager;
0051 import org.apache.spark.network.util.MapConfigProvider;
0052 import org.apache.spark.network.util.TransportConf;
0053 
0054 public class ChunkFetchIntegrationSuite {
0055   static final long STREAM_ID = 1;
0056   static final int BUFFER_CHUNK_INDEX = 0;
0057   static final int FILE_CHUNK_INDEX = 1;
0058 
0059   static TransportContext context;
0060   static TransportServer server;
0061   static TransportClientFactory clientFactory;
0062   static StreamManager streamManager;
0063   static File testFile;
0064 
0065   static ManagedBuffer bufferChunk;
0066   static ManagedBuffer fileChunk;
0067 
0068   @BeforeClass
0069   public static void setUp() throws Exception {
0070     int bufSize = 100000;
0071     final ByteBuffer buf = ByteBuffer.allocate(bufSize);
0072     for (int i = 0; i < bufSize; i ++) {
0073       buf.put((byte) i);
0074     }
0075     buf.flip();
0076     bufferChunk = new NioManagedBuffer(buf);
0077 
0078     testFile = File.createTempFile("shuffle-test-file", "txt");
0079     testFile.deleteOnExit();
0080     RandomAccessFile fp = new RandomAccessFile(testFile, "rw");
0081     boolean shouldSuppressIOException = true;
0082     try {
0083       byte[] fileContent = new byte[1024];
0084       new Random().nextBytes(fileContent);
0085       fp.write(fileContent);
0086       shouldSuppressIOException = false;
0087     } finally {
0088       Closeables.close(fp, shouldSuppressIOException);
0089     }
0090 
0091     final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
0092     fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
0093 
0094     streamManager = new StreamManager() {
0095       @Override
0096       public ManagedBuffer getChunk(long streamId, int chunkIndex) {
0097         assertEquals(STREAM_ID, streamId);
0098         if (chunkIndex == BUFFER_CHUNK_INDEX) {
0099           return new NioManagedBuffer(buf);
0100         } else if (chunkIndex == FILE_CHUNK_INDEX) {
0101           return new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
0102         } else {
0103           throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex);
0104         }
0105       }
0106     };
0107     RpcHandler handler = new RpcHandler() {
0108       @Override
0109       public void receive(
0110           TransportClient client,
0111           ByteBuffer message,
0112           RpcResponseCallback callback) {
0113         throw new UnsupportedOperationException();
0114       }
0115 
0116       @Override
0117       public StreamManager getStreamManager() {
0118         return streamManager;
0119       }
0120     };
0121     context = new TransportContext(conf, handler);
0122     server = context.createServer();
0123     clientFactory = context.createClientFactory();
0124   }
0125 
0126   @AfterClass
0127   public static void tearDown() {
0128     bufferChunk.release();
0129     server.close();
0130     clientFactory.close();
0131     context.close();
0132     testFile.delete();
0133   }
0134 
0135   static class FetchResult {
0136     public Set<Integer> successChunks;
0137     public Set<Integer> failedChunks;
0138     public List<ManagedBuffer> buffers;
0139 
0140     public void releaseBuffers() {
0141       for (ManagedBuffer buffer : buffers) {
0142         buffer.release();
0143       }
0144     }
0145   }
0146 
0147   private FetchResult fetchChunks(List<Integer> chunkIndices) throws Exception {
0148     final FetchResult res = new FetchResult();
0149 
0150     try (TransportClient client =
0151       clientFactory.createClient(TestUtils.getLocalHost(), server.getPort())) {
0152       final Semaphore sem = new Semaphore(0);
0153 
0154       res.successChunks = Collections.synchronizedSet(new HashSet<>());
0155       res.failedChunks = Collections.synchronizedSet(new HashSet<>());
0156       res.buffers = Collections.synchronizedList(new LinkedList<>());
0157 
0158       ChunkReceivedCallback callback = new ChunkReceivedCallback() {
0159         @Override
0160         public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
0161           buffer.retain();
0162           res.successChunks.add(chunkIndex);
0163           res.buffers.add(buffer);
0164           sem.release();
0165         }
0166 
0167         @Override
0168         public void onFailure(int chunkIndex, Throwable e) {
0169           res.failedChunks.add(chunkIndex);
0170           sem.release();
0171         }
0172       };
0173 
0174       for (int chunkIndex : chunkIndices) {
0175         client.fetchChunk(STREAM_ID, chunkIndex, callback);
0176       }
0177       if (!sem.tryAcquire(chunkIndices.size(), 60, TimeUnit.SECONDS)) {
0178         fail("Timeout getting response from the server");
0179       }
0180     }
0181     return res;
0182   }
0183 
0184   @Test
0185   public void fetchBufferChunk() throws Exception {
0186     FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX));
0187     assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX), res.successChunks);
0188     assertTrue(res.failedChunks.isEmpty());
0189     assertBufferListsEqual(Arrays.asList(bufferChunk), res.buffers);
0190     res.releaseBuffers();
0191   }
0192 
0193   @Test
0194   public void fetchFileChunk() throws Exception {
0195     FetchResult res = fetchChunks(Arrays.asList(FILE_CHUNK_INDEX));
0196     assertEquals(Sets.newHashSet(FILE_CHUNK_INDEX), res.successChunks);
0197     assertTrue(res.failedChunks.isEmpty());
0198     assertBufferListsEqual(Arrays.asList(fileChunk), res.buffers);
0199     res.releaseBuffers();
0200   }
0201 
0202   @Test
0203   public void fetchNonExistentChunk() throws Exception {
0204     FetchResult res = fetchChunks(Arrays.asList(12345));
0205     assertTrue(res.successChunks.isEmpty());
0206     assertEquals(Sets.newHashSet(12345), res.failedChunks);
0207     assertTrue(res.buffers.isEmpty());
0208   }
0209 
0210   @Test
0211   public void fetchBothChunks() throws Exception {
0212     FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX));
0213     assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX), res.successChunks);
0214     assertTrue(res.failedChunks.isEmpty());
0215     assertBufferListsEqual(Arrays.asList(bufferChunk, fileChunk), res.buffers);
0216     res.releaseBuffers();
0217   }
0218 
0219   @Test
0220   public void fetchChunkAndNonExistent() throws Exception {
0221     FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX, 12345));
0222     assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX), res.successChunks);
0223     assertEquals(Sets.newHashSet(12345), res.failedChunks);
0224     assertBufferListsEqual(Arrays.asList(bufferChunk), res.buffers);
0225     res.releaseBuffers();
0226   }
0227 
0228   private static void assertBufferListsEqual(List<ManagedBuffer> list0, List<ManagedBuffer> list1)
0229       throws Exception {
0230     assertEquals(list0.size(), list1.size());
0231     for (int i = 0; i < list0.size(); i ++) {
0232       assertBuffersEqual(list0.get(i), list1.get(i));
0233     }
0234   }
0235 
0236   private static void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1)
0237       throws Exception {
0238     ByteBuffer nio0 = buffer0.nioByteBuffer();
0239     ByteBuffer nio1 = buffer1.nioByteBuffer();
0240 
0241     int len = nio0.remaining();
0242     assertEquals(nio0.remaining(), nio1.remaining());
0243     for (int i = 0; i < len; i ++) {
0244       assertEquals(nio0.get(), nio1.get());
0245     }
0246   }
0247 }