Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.sql.execution;
0019 
0020 import java.io.IOException;
0021 
0022 import org.apache.spark.SparkEnv;
0023 import org.apache.spark.TaskContext;
0024 import org.apache.spark.internal.config.package$;
0025 import org.apache.spark.sql.catalyst.InternalRow;
0026 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
0027 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
0028 import org.apache.spark.sql.types.StructField;
0029 import org.apache.spark.sql.types.StructType;
0030 import org.apache.spark.unsafe.KVIterator;
0031 import org.apache.spark.unsafe.Platform;
0032 import org.apache.spark.unsafe.map.BytesToBytesMap;
0033 
0034 /**
0035  * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
0036  *
0037  * This map supports a maximum of 2 billion keys.
0038  */
0039 public final class UnsafeFixedWidthAggregationMap {
0040 
0041   /**
0042    * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
0043    * map, we copy this buffer and use it as the value.
0044    */
0045   private final byte[] emptyAggregationBuffer;
0046 
0047   private final StructType aggregationBufferSchema;
0048 
0049   private final StructType groupingKeySchema;
0050 
0051   /**
0052    * Encodes grouping keys as UnsafeRows.
0053    */
0054   private final UnsafeProjection groupingKeyProjection;
0055 
0056   /**
0057    * A hashmap which maps from opaque bytearray keys to bytearray values.
0058    */
0059   private final BytesToBytesMap map;
0060 
0061   /**
0062    * Re-used pointer to the current aggregation buffer
0063    */
0064   private final UnsafeRow currentAggregationBuffer;
0065 
0066   /**
0067    * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
0068    *         schema, false otherwise.
0069    */
0070   public static boolean supportsAggregationBufferSchema(StructType schema) {
0071     for (StructField field: schema.fields()) {
0072       if (!UnsafeRow.isMutable(field.dataType())) {
0073         return false;
0074       }
0075     }
0076     return true;
0077   }
0078 
0079   /**
0080    * Create a new UnsafeFixedWidthAggregationMap.
0081    *
0082    * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
0083    * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
0084    * @param groupingKeySchema the schema of the grouping key, used for row conversion.
0085    * @param taskContext the current task context.
0086    * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
0087    * @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
0088    */
0089   public UnsafeFixedWidthAggregationMap(
0090       InternalRow emptyAggregationBuffer,
0091       StructType aggregationBufferSchema,
0092       StructType groupingKeySchema,
0093       TaskContext taskContext,
0094       int initialCapacity,
0095       long pageSizeBytes) {
0096     this.aggregationBufferSchema = aggregationBufferSchema;
0097     this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length());
0098     this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
0099     this.groupingKeySchema = groupingKeySchema;
0100     this.map = new BytesToBytesMap(
0101       taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes);
0102 
0103     // Initialize the buffer for aggregation value
0104     final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
0105     this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
0106 
0107     // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
0108     // the end of the task. This is necessary to avoid memory leaks in when the downstream operator
0109     // does not fully consume the aggregation map's output (e.g. aggregate followed by limit).
0110     taskContext.addTaskCompletionListener(context -> {
0111       free();
0112     });
0113   }
0114 
0115   /**
0116    * Return the aggregation buffer for the current group. For efficiency, all calls to this method
0117    * return the same object. If additional memory could not be allocated, then this method will
0118    * signal an error by returning null.
0119    */
0120   public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
0121     final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey);
0122 
0123     return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow);
0124   }
0125 
0126   public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key) {
0127     return getAggregationBufferFromUnsafeRow(key, key.hashCode());
0128   }
0129 
0130   public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) {
0131     // Probe our map using the serialized key
0132     final BytesToBytesMap.Location loc = map.lookup(
0133       key.getBaseObject(),
0134       key.getBaseOffset(),
0135       key.getSizeInBytes(),
0136       hash);
0137     if (!loc.isDefined()) {
0138       // This is the first time that we've seen this grouping key, so we'll insert a copy of the
0139       // empty aggregation buffer into the map:
0140       boolean putSucceeded = loc.append(
0141         key.getBaseObject(),
0142         key.getBaseOffset(),
0143         key.getSizeInBytes(),
0144         emptyAggregationBuffer,
0145         Platform.BYTE_ARRAY_OFFSET,
0146         emptyAggregationBuffer.length
0147       );
0148       if (!putSucceeded) {
0149         return null;
0150       }
0151     }
0152 
0153     // Reset the pointer to point to the value that we just stored or looked up:
0154     currentAggregationBuffer.pointTo(
0155       loc.getValueBase(),
0156       loc.getValueOffset(),
0157       loc.getValueLength()
0158     );
0159     return currentAggregationBuffer;
0160   }
0161 
0162   /**
0163    * Returns an iterator over the keys and values in this map. This uses destructive iterator of
0164    * BytesToBytesMap. So it is illegal to call any other method on this map after `iterator()` has
0165    * been called.
0166    *
0167    * For efficiency, each call returns the same object.
0168    */
0169   public KVIterator<UnsafeRow, UnsafeRow> iterator() {
0170     return new KVIterator<UnsafeRow, UnsafeRow>() {
0171 
0172       private final BytesToBytesMap.MapIterator mapLocationIterator =
0173         map.destructiveIterator();
0174       private final UnsafeRow key = new UnsafeRow(groupingKeySchema.length());
0175       private final UnsafeRow value = new UnsafeRow(aggregationBufferSchema.length());
0176 
0177       @Override
0178       public boolean next() {
0179         if (mapLocationIterator.hasNext()) {
0180           final BytesToBytesMap.Location loc = mapLocationIterator.next();
0181           key.pointTo(
0182             loc.getKeyBase(),
0183             loc.getKeyOffset(),
0184             loc.getKeyLength()
0185           );
0186           value.pointTo(
0187             loc.getValueBase(),
0188             loc.getValueOffset(),
0189             loc.getValueLength()
0190           );
0191           return true;
0192         } else {
0193           return false;
0194         }
0195       }
0196 
0197       @Override
0198       public UnsafeRow getKey() {
0199         return key;
0200       }
0201 
0202       @Override
0203       public UnsafeRow getValue() {
0204         return value;
0205       }
0206 
0207       @Override
0208       public void close() {
0209         // Do nothing.
0210       }
0211     };
0212   }
0213 
0214   /**
0215    * Return the peak memory used so far, in bytes.
0216    */
0217   public long getPeakMemoryUsedBytes() {
0218     return map.getPeakMemoryUsedBytes();
0219   }
0220 
0221   /**
0222    * Free the memory associated with this map. This is idempotent and can be called multiple times.
0223    */
0224   public void free() {
0225     map.free();
0226   }
0227 
0228   /**
0229    * Gets the average bucket list iterations per lookup in the underlying `BytesToBytesMap`.
0230    */
0231   public double getAvgHashProbeBucketListIterations() {
0232     return map.getAvgHashProbeBucketListIterations();
0233   }
0234 
0235   /**
0236    * Sorts the map's records in place, spill them to disk, and returns an [[UnsafeKVExternalSorter]]
0237    *
0238    * Note that the map will be reset for inserting new records, and the returned sorter can NOT be
0239    * used to insert records.
0240    */
0241   public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException {
0242     return new UnsafeKVExternalSorter(
0243       groupingKeySchema,
0244       aggregationBufferSchema,
0245       SparkEnv.get().blockManager(),
0246       SparkEnv.get().serializerManager(),
0247       map.getPageSizeBytes(),
0248       (int) SparkEnv.get().conf().get(
0249         package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()),
0250       map);
0251   }
0252 }