Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.sql.execution;
0019 
0020 import java.io.IOException;
0021 import java.util.function.Supplier;
0022 
0023 import scala.collection.Iterator;
0024 import scala.math.Ordering;
0025 
0026 import com.google.common.annotations.VisibleForTesting;
0027 
0028 import org.apache.spark.SparkEnv;
0029 import org.apache.spark.TaskContext;
0030 import org.apache.spark.internal.config.package$;
0031 import org.apache.spark.sql.catalyst.InternalRow;
0032 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
0033 import org.apache.spark.sql.types.StructType;
0034 import org.apache.spark.unsafe.Platform;
0035 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
0036 import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
0037 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
0038 import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
0039 
0040 public final class UnsafeExternalRowSorter {
0041 
0042   /**
0043    * If positive, forces records to be spilled to disk at the given frequency (measured in numbers
0044    * of records). This is only intended to be used in tests.
0045    */
0046   private int testSpillFrequency = 0;
0047 
0048   private long numRowsInserted = 0;
0049 
0050   private final StructType schema;
0051   private final UnsafeExternalRowSorter.PrefixComputer prefixComputer;
0052   private final UnsafeExternalSorter sorter;
0053 
0054   // This flag makes sure the cleanupResource() has been called. After the cleanup work,
0055   // iterator.next should always return false. Downstream operator triggers the resource
0056   // cleanup while they found there's no need to keep the iterator any more.
0057   // See more details in SPARK-21492.
0058   private boolean isReleased = false;
0059 
0060   public abstract static class PrefixComputer {
0061 
0062     public static class Prefix {
0063       /** Key prefix value, or the null prefix value if isNull = true. **/
0064       public long value;
0065 
0066       /** Whether the key is null. */
0067       public boolean isNull;
0068     }
0069 
0070     /**
0071      * Computes prefix for the given row. For efficiency, the returned object may be reused in
0072      * further calls to a given PrefixComputer.
0073      */
0074     public abstract Prefix computePrefix(InternalRow row);
0075   }
0076 
0077   public static UnsafeExternalRowSorter createWithRecordComparator(
0078       StructType schema,
0079       Supplier<RecordComparator> recordComparatorSupplier,
0080       PrefixComparator prefixComparator,
0081       UnsafeExternalRowSorter.PrefixComputer prefixComputer,
0082       long pageSizeBytes,
0083       boolean canUseRadixSort) throws IOException {
0084     return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator,
0085       prefixComputer, pageSizeBytes, canUseRadixSort);
0086   }
0087 
0088   public static UnsafeExternalRowSorter create(
0089       StructType schema,
0090       Ordering<InternalRow> ordering,
0091       PrefixComparator prefixComparator,
0092       UnsafeExternalRowSorter.PrefixComputer prefixComputer,
0093       long pageSizeBytes,
0094       boolean canUseRadixSort) throws IOException {
0095     Supplier<RecordComparator> recordComparatorSupplier =
0096       () -> new RowComparator(ordering, schema.length());
0097     return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator,
0098       prefixComputer, pageSizeBytes, canUseRadixSort);
0099   }
0100 
0101   private UnsafeExternalRowSorter(
0102       StructType schema,
0103       Supplier<RecordComparator> recordComparatorSupplier,
0104       PrefixComparator prefixComparator,
0105       UnsafeExternalRowSorter.PrefixComputer prefixComputer,
0106       long pageSizeBytes,
0107       boolean canUseRadixSort) {
0108     this.schema = schema;
0109     this.prefixComputer = prefixComputer;
0110     final SparkEnv sparkEnv = SparkEnv.get();
0111     final TaskContext taskContext = TaskContext.get();
0112     sorter = UnsafeExternalSorter.create(
0113       taskContext.taskMemoryManager(),
0114       sparkEnv.blockManager(),
0115       sparkEnv.serializerManager(),
0116       taskContext,
0117       recordComparatorSupplier,
0118       prefixComparator,
0119       (int) (long) sparkEnv.conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()),
0120       pageSizeBytes,
0121       (int) SparkEnv.get().conf().get(
0122         package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()),
0123       canUseRadixSort
0124     );
0125   }
0126 
0127   /**
0128    * Forces spills to occur every `frequency` records. Only for use in tests.
0129    */
0130   @VisibleForTesting
0131   void setTestSpillFrequency(int frequency) {
0132     assert frequency > 0 : "Frequency must be positive";
0133     testSpillFrequency = frequency;
0134   }
0135 
0136   public void insertRow(UnsafeRow row) throws IOException {
0137     final PrefixComputer.Prefix prefix = prefixComputer.computePrefix(row);
0138     sorter.insertRecord(
0139       row.getBaseObject(),
0140       row.getBaseOffset(),
0141       row.getSizeInBytes(),
0142       prefix.value,
0143       prefix.isNull
0144     );
0145     numRowsInserted++;
0146     if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) {
0147       sorter.spill();
0148     }
0149   }
0150 
0151   /**
0152    * Return the peak memory used so far, in bytes.
0153    */
0154   public long getPeakMemoryUsage() {
0155     return sorter.getPeakMemoryUsedBytes();
0156   }
0157 
0158   /**
0159    * @return the total amount of time spent sorting data (in-memory only).
0160    */
0161   public long getSortTimeNanos() {
0162     return sorter.getSortTimeNanos();
0163   }
0164 
0165   public void cleanupResources() {
0166     isReleased = true;
0167     sorter.cleanupResources();
0168   }
0169 
0170   public Iterator<InternalRow> sort() throws IOException {
0171     try {
0172       final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
0173       if (!sortedIterator.hasNext()) {
0174         // Since we won't ever call next() on an empty iterator, we need to clean up resources
0175         // here in order to prevent memory leaks.
0176         cleanupResources();
0177       }
0178       return new RowIterator() {
0179 
0180         private final int numFields = schema.length();
0181         private UnsafeRow row = new UnsafeRow(numFields);
0182 
0183         @Override
0184         public boolean advanceNext() {
0185           try {
0186             if (!isReleased && sortedIterator.hasNext()) {
0187               sortedIterator.loadNext();
0188               row.pointTo(
0189                   sortedIterator.getBaseObject(),
0190                   sortedIterator.getBaseOffset(),
0191                   sortedIterator.getRecordLength());
0192               // Here is the initial bug fix in SPARK-9364: the bug fix of use-after-free bug
0193               // when returning the last row from an iterator. For example, in
0194               // [[GroupedIterator]], we still use the last row after traversing the iterator
0195               // in `fetchNextGroupIterator`
0196               if (!sortedIterator.hasNext()) {
0197                 row = row.copy(); // so that we don't have dangling pointers to freed page
0198                 cleanupResources();
0199               }
0200               return true;
0201             } else {
0202               row = null; // so that we don't keep references to the base object
0203               return false;
0204             }
0205           } catch (IOException e) {
0206             cleanupResources();
0207             // Scala iterators don't declare any checked exceptions, so we need to use this hack
0208             // to re-throw the exception:
0209             Platform.throwException(e);
0210           }
0211           throw new RuntimeException("Exception should have been re-thrown in next()");
0212         }
0213 
0214         @Override
0215         public UnsafeRow getRow() { return row; }
0216 
0217       }.toScala();
0218     } catch (IOException e) {
0219       cleanupResources();
0220       throw e;
0221     }
0222   }
0223 
0224   public Iterator<InternalRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
0225     while (inputIterator.hasNext()) {
0226       insertRow(inputIterator.next());
0227     }
0228     return sort();
0229   }
0230 
0231   private static final class RowComparator extends RecordComparator {
0232     private final Ordering<InternalRow> ordering;
0233     private final UnsafeRow row1;
0234     private final UnsafeRow row2;
0235 
0236     RowComparator(Ordering<InternalRow> ordering, int numFields) {
0237       this.row1 = new UnsafeRow(numFields);
0238       this.row2 = new UnsafeRow(numFields);
0239       this.ordering = ordering;
0240     }
0241 
0242     @Override
0243     public int compare(
0244         Object baseObj1,
0245         long baseOff1,
0246         int baseLen1,
0247         Object baseObj2,
0248         long baseOff2,
0249         int baseLen2) {
0250       // Note that since ordering doesn't need the total length of the record, we just pass 0
0251       // into the row.
0252       row1.pointTo(baseObj1, baseOff1, 0);
0253       row2.pointTo(baseObj2, baseOff2, 0);
0254       return ordering.compare(row1, row2);
0255     }
0256   }
0257 }