0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.shuffle.sort.io;
0019
0020 import java.io.BufferedOutputStream;
0021 import java.io.File;
0022 import java.io.FileOutputStream;
0023 import java.io.IOException;
0024 import java.io.OutputStream;
0025 import java.nio.channels.FileChannel;
0026 import java.nio.channels.WritableByteChannel;
0027 import java.util.Optional;
0028
0029 import org.slf4j.Logger;
0030 import org.slf4j.LoggerFactory;
0031
0032 import org.apache.spark.SparkConf;
0033 import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
0034 import org.apache.spark.shuffle.api.ShufflePartitionWriter;
0035 import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
0036 import org.apache.spark.internal.config.package$;
0037 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
0038 import org.apache.spark.util.Utils;
0039
0040
0041
0042
0043
0044
0045 public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter {
0046
0047 private static final Logger log =
0048 LoggerFactory.getLogger(LocalDiskShuffleMapOutputWriter.class);
0049
0050 private final int shuffleId;
0051 private final long mapId;
0052 private final IndexShuffleBlockResolver blockResolver;
0053 private final long[] partitionLengths;
0054 private final int bufferSize;
0055 private int lastPartitionId = -1;
0056 private long currChannelPosition;
0057 private long bytesWrittenToMergedFile = 0L;
0058
0059 private final File outputFile;
0060 private File outputTempFile;
0061 private FileOutputStream outputFileStream;
0062 private FileChannel outputFileChannel;
0063 private BufferedOutputStream outputBufferedFileStream;
0064
0065 public LocalDiskShuffleMapOutputWriter(
0066 int shuffleId,
0067 long mapId,
0068 int numPartitions,
0069 IndexShuffleBlockResolver blockResolver,
0070 SparkConf sparkConf) {
0071 this.shuffleId = shuffleId;
0072 this.mapId = mapId;
0073 this.blockResolver = blockResolver;
0074 this.bufferSize =
0075 (int) (long) sparkConf.get(
0076 package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024;
0077 this.partitionLengths = new long[numPartitions];
0078 this.outputFile = blockResolver.getDataFile(shuffleId, mapId);
0079 this.outputTempFile = null;
0080 }
0081
0082 @Override
0083 public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException {
0084 if (reducePartitionId <= lastPartitionId) {
0085 throw new IllegalArgumentException("Partitions should be requested in increasing order.");
0086 }
0087 lastPartitionId = reducePartitionId;
0088 if (outputTempFile == null) {
0089 outputTempFile = Utils.tempFileWith(outputFile);
0090 }
0091 if (outputFileChannel != null) {
0092 currChannelPosition = outputFileChannel.position();
0093 } else {
0094 currChannelPosition = 0L;
0095 }
0096 return new LocalDiskShufflePartitionWriter(reducePartitionId);
0097 }
0098
0099 @Override
0100 public long[] commitAllPartitions() throws IOException {
0101
0102
0103
0104
0105 if (outputFileChannel != null && outputFileChannel.position() != bytesWrittenToMergedFile) {
0106 throw new IOException(
0107 "Current position " + outputFileChannel.position() + " does not equal expected " +
0108 "position " + bytesWrittenToMergedFile + " after transferTo. Please check your " +
0109 " kernel version to see if it is 2.6.32, as there is a kernel bug which will lead " +
0110 "to unexpected behavior when using transferTo. You can set " +
0111 "spark.file.transferTo=false to disable this NIO feature.");
0112 }
0113 cleanUp();
0114 File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
0115 blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
0116 return partitionLengths;
0117 }
0118
0119 @Override
0120 public void abort(Throwable error) throws IOException {
0121 cleanUp();
0122 if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) {
0123 log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath());
0124 }
0125 }
0126
0127 private void cleanUp() throws IOException {
0128 if (outputBufferedFileStream != null) {
0129 outputBufferedFileStream.close();
0130 }
0131 if (outputFileChannel != null) {
0132 outputFileChannel.close();
0133 }
0134 if (outputFileStream != null) {
0135 outputFileStream.close();
0136 }
0137 }
0138
0139 private void initStream() throws IOException {
0140 if (outputFileStream == null) {
0141 outputFileStream = new FileOutputStream(outputTempFile, true);
0142 }
0143 if (outputBufferedFileStream == null) {
0144 outputBufferedFileStream = new BufferedOutputStream(outputFileStream, bufferSize);
0145 }
0146 }
0147
0148 private void initChannel() throws IOException {
0149
0150
0151 if (outputFileChannel == null) {
0152 outputFileChannel = new FileOutputStream(outputTempFile, true).getChannel();
0153 }
0154 }
0155
0156 private class LocalDiskShufflePartitionWriter implements ShufflePartitionWriter {
0157
0158 private final int partitionId;
0159 private PartitionWriterStream partStream = null;
0160 private PartitionWriterChannel partChannel = null;
0161
0162 private LocalDiskShufflePartitionWriter(int partitionId) {
0163 this.partitionId = partitionId;
0164 }
0165
0166 @Override
0167 public OutputStream openStream() throws IOException {
0168 if (partStream == null) {
0169 if (outputFileChannel != null) {
0170 throw new IllegalStateException("Requested an output channel for a previous write but" +
0171 " now an output stream has been requested. Should not be using both channels" +
0172 " and streams to write.");
0173 }
0174 initStream();
0175 partStream = new PartitionWriterStream(partitionId);
0176 }
0177 return partStream;
0178 }
0179
0180 @Override
0181 public Optional<WritableByteChannelWrapper> openChannelWrapper() throws IOException {
0182 if (partChannel == null) {
0183 if (partStream != null) {
0184 throw new IllegalStateException("Requested an output stream for a previous write but" +
0185 " now an output channel has been requested. Should not be using both channels" +
0186 " and streams to write.");
0187 }
0188 initChannel();
0189 partChannel = new PartitionWriterChannel(partitionId);
0190 }
0191 return Optional.of(partChannel);
0192 }
0193
0194 @Override
0195 public long getNumBytesWritten() {
0196 if (partChannel != null) {
0197 try {
0198 return partChannel.getCount();
0199 } catch (IOException e) {
0200 throw new RuntimeException(e);
0201 }
0202 } else if (partStream != null) {
0203 return partStream.getCount();
0204 } else {
0205
0206 return 0;
0207 }
0208 }
0209 }
0210
0211 private class PartitionWriterStream extends OutputStream {
0212 private final int partitionId;
0213 private int count = 0;
0214 private boolean isClosed = false;
0215
0216 PartitionWriterStream(int partitionId) {
0217 this.partitionId = partitionId;
0218 }
0219
0220 public int getCount() {
0221 return count;
0222 }
0223
0224 @Override
0225 public void write(int b) throws IOException {
0226 verifyNotClosed();
0227 outputBufferedFileStream.write(b);
0228 count++;
0229 }
0230
0231 @Override
0232 public void write(byte[] buf, int pos, int length) throws IOException {
0233 verifyNotClosed();
0234 outputBufferedFileStream.write(buf, pos, length);
0235 count += length;
0236 }
0237
0238 @Override
0239 public void close() {
0240 isClosed = true;
0241 partitionLengths[partitionId] = count;
0242 bytesWrittenToMergedFile += count;
0243 }
0244
0245 private void verifyNotClosed() {
0246 if (isClosed) {
0247 throw new IllegalStateException("Attempting to write to a closed block output stream.");
0248 }
0249 }
0250 }
0251
0252 private class PartitionWriterChannel implements WritableByteChannelWrapper {
0253
0254 private final int partitionId;
0255
0256 PartitionWriterChannel(int partitionId) {
0257 this.partitionId = partitionId;
0258 }
0259
0260 public long getCount() throws IOException {
0261 long writtenPosition = outputFileChannel.position();
0262 return writtenPosition - currChannelPosition;
0263 }
0264
0265 @Override
0266 public WritableByteChannel channel() {
0267 return outputFileChannel;
0268 }
0269
0270 @Override
0271 public void close() throws IOException {
0272 partitionLengths[partitionId] = getCount();
0273 bytesWrittenToMergedFile += partitionLengths[partitionId];
0274 }
0275 }
0276 }