0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.sql.execution;
0019
0020 import java.io.IOException;
0021 import java.util.function.Supplier;
0022
0023 import scala.collection.Iterator;
0024 import scala.math.Ordering;
0025
0026 import com.google.common.annotations.VisibleForTesting;
0027
0028 import org.apache.spark.SparkEnv;
0029 import org.apache.spark.TaskContext;
0030 import org.apache.spark.internal.config.package$;
0031 import org.apache.spark.sql.catalyst.InternalRow;
0032 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
0033 import org.apache.spark.sql.types.StructType;
0034 import org.apache.spark.unsafe.Platform;
0035 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
0036 import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
0037 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
0038 import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
0039
0040 public final class UnsafeExternalRowSorter {
0041
0042
0043
0044
0045
0046 private int testSpillFrequency = 0;
0047
0048 private long numRowsInserted = 0;
0049
0050 private final StructType schema;
0051 private final UnsafeExternalRowSorter.PrefixComputer prefixComputer;
0052 private final UnsafeExternalSorter sorter;
0053
0054
0055
0056
0057
0058 private boolean isReleased = false;
0059
0060 public abstract static class PrefixComputer {
0061
0062 public static class Prefix {
0063
0064 public long value;
0065
0066
0067 public boolean isNull;
0068 }
0069
0070
0071
0072
0073
0074 public abstract Prefix computePrefix(InternalRow row);
0075 }
0076
0077 public static UnsafeExternalRowSorter createWithRecordComparator(
0078 StructType schema,
0079 Supplier<RecordComparator> recordComparatorSupplier,
0080 PrefixComparator prefixComparator,
0081 UnsafeExternalRowSorter.PrefixComputer prefixComputer,
0082 long pageSizeBytes,
0083 boolean canUseRadixSort) throws IOException {
0084 return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator,
0085 prefixComputer, pageSizeBytes, canUseRadixSort);
0086 }
0087
0088 public static UnsafeExternalRowSorter create(
0089 StructType schema,
0090 Ordering<InternalRow> ordering,
0091 PrefixComparator prefixComparator,
0092 UnsafeExternalRowSorter.PrefixComputer prefixComputer,
0093 long pageSizeBytes,
0094 boolean canUseRadixSort) throws IOException {
0095 Supplier<RecordComparator> recordComparatorSupplier =
0096 () -> new RowComparator(ordering, schema.length());
0097 return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator,
0098 prefixComputer, pageSizeBytes, canUseRadixSort);
0099 }
0100
0101 private UnsafeExternalRowSorter(
0102 StructType schema,
0103 Supplier<RecordComparator> recordComparatorSupplier,
0104 PrefixComparator prefixComparator,
0105 UnsafeExternalRowSorter.PrefixComputer prefixComputer,
0106 long pageSizeBytes,
0107 boolean canUseRadixSort) {
0108 this.schema = schema;
0109 this.prefixComputer = prefixComputer;
0110 final SparkEnv sparkEnv = SparkEnv.get();
0111 final TaskContext taskContext = TaskContext.get();
0112 sorter = UnsafeExternalSorter.create(
0113 taskContext.taskMemoryManager(),
0114 sparkEnv.blockManager(),
0115 sparkEnv.serializerManager(),
0116 taskContext,
0117 recordComparatorSupplier,
0118 prefixComparator,
0119 (int) (long) sparkEnv.conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()),
0120 pageSizeBytes,
0121 (int) SparkEnv.get().conf().get(
0122 package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()),
0123 canUseRadixSort
0124 );
0125 }
0126
0127
0128
0129
0130 @VisibleForTesting
0131 void setTestSpillFrequency(int frequency) {
0132 assert frequency > 0 : "Frequency must be positive";
0133 testSpillFrequency = frequency;
0134 }
0135
0136 public void insertRow(UnsafeRow row) throws IOException {
0137 final PrefixComputer.Prefix prefix = prefixComputer.computePrefix(row);
0138 sorter.insertRecord(
0139 row.getBaseObject(),
0140 row.getBaseOffset(),
0141 row.getSizeInBytes(),
0142 prefix.value,
0143 prefix.isNull
0144 );
0145 numRowsInserted++;
0146 if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) {
0147 sorter.spill();
0148 }
0149 }
0150
0151
0152
0153
0154 public long getPeakMemoryUsage() {
0155 return sorter.getPeakMemoryUsedBytes();
0156 }
0157
0158
0159
0160
0161 public long getSortTimeNanos() {
0162 return sorter.getSortTimeNanos();
0163 }
0164
0165 public void cleanupResources() {
0166 isReleased = true;
0167 sorter.cleanupResources();
0168 }
0169
0170 public Iterator<InternalRow> sort() throws IOException {
0171 try {
0172 final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
0173 if (!sortedIterator.hasNext()) {
0174
0175
0176 cleanupResources();
0177 }
0178 return new RowIterator() {
0179
0180 private final int numFields = schema.length();
0181 private UnsafeRow row = new UnsafeRow(numFields);
0182
0183 @Override
0184 public boolean advanceNext() {
0185 try {
0186 if (!isReleased && sortedIterator.hasNext()) {
0187 sortedIterator.loadNext();
0188 row.pointTo(
0189 sortedIterator.getBaseObject(),
0190 sortedIterator.getBaseOffset(),
0191 sortedIterator.getRecordLength());
0192
0193
0194
0195
0196 if (!sortedIterator.hasNext()) {
0197 row = row.copy();
0198 cleanupResources();
0199 }
0200 return true;
0201 } else {
0202 row = null;
0203 return false;
0204 }
0205 } catch (IOException e) {
0206 cleanupResources();
0207
0208
0209 Platform.throwException(e);
0210 }
0211 throw new RuntimeException("Exception should have been re-thrown in next()");
0212 }
0213
0214 @Override
0215 public UnsafeRow getRow() { return row; }
0216
0217 }.toScala();
0218 } catch (IOException e) {
0219 cleanupResources();
0220 throw e;
0221 }
0222 }
0223
0224 public Iterator<InternalRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
0225 while (inputIterator.hasNext()) {
0226 insertRow(inputIterator.next());
0227 }
0228 return sort();
0229 }
0230
0231 private static final class RowComparator extends RecordComparator {
0232 private final Ordering<InternalRow> ordering;
0233 private final UnsafeRow row1;
0234 private final UnsafeRow row2;
0235
0236 RowComparator(Ordering<InternalRow> ordering, int numFields) {
0237 this.row1 = new UnsafeRow(numFields);
0238 this.row2 = new UnsafeRow(numFields);
0239 this.ordering = ordering;
0240 }
0241
0242 @Override
0243 public int compare(
0244 Object baseObj1,
0245 long baseOff1,
0246 int baseLen1,
0247 Object baseObj2,
0248 long baseOff2,
0249 int baseLen2) {
0250
0251
0252 row1.pointTo(baseObj1, baseOff1, 0);
0253 row2.pointTo(baseObj2, baseOff2, 0);
0254 return ordering.compare(row1, row2);
0255 }
0256 }
0257 }