0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.sql.execution;
0019
0020 import java.io.IOException;
0021
0022 import org.apache.spark.SparkEnv;
0023 import org.apache.spark.TaskContext;
0024 import org.apache.spark.internal.config.package$;
0025 import org.apache.spark.sql.catalyst.InternalRow;
0026 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
0027 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
0028 import org.apache.spark.sql.types.StructField;
0029 import org.apache.spark.sql.types.StructType;
0030 import org.apache.spark.unsafe.KVIterator;
0031 import org.apache.spark.unsafe.Platform;
0032 import org.apache.spark.unsafe.map.BytesToBytesMap;
0033
0034
0035
0036
0037
0038
0039 public final class UnsafeFixedWidthAggregationMap {
0040
0041
0042
0043
0044
0045 private final byte[] emptyAggregationBuffer;
0046
0047 private final StructType aggregationBufferSchema;
0048
0049 private final StructType groupingKeySchema;
0050
0051
0052
0053
0054 private final UnsafeProjection groupingKeyProjection;
0055
0056
0057
0058
0059 private final BytesToBytesMap map;
0060
0061
0062
0063
0064 private final UnsafeRow currentAggregationBuffer;
0065
0066
0067
0068
0069
0070 public static boolean supportsAggregationBufferSchema(StructType schema) {
0071 for (StructField field: schema.fields()) {
0072 if (!UnsafeRow.isMutable(field.dataType())) {
0073 return false;
0074 }
0075 }
0076 return true;
0077 }
0078
0079
0080
0081
0082
0083
0084
0085
0086
0087
0088
0089 public UnsafeFixedWidthAggregationMap(
0090 InternalRow emptyAggregationBuffer,
0091 StructType aggregationBufferSchema,
0092 StructType groupingKeySchema,
0093 TaskContext taskContext,
0094 int initialCapacity,
0095 long pageSizeBytes) {
0096 this.aggregationBufferSchema = aggregationBufferSchema;
0097 this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length());
0098 this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
0099 this.groupingKeySchema = groupingKeySchema;
0100 this.map = new BytesToBytesMap(
0101 taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes);
0102
0103
0104 final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
0105 this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
0106
0107
0108
0109
0110 taskContext.addTaskCompletionListener(context -> {
0111 free();
0112 });
0113 }
0114
0115
0116
0117
0118
0119
0120 public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
0121 final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey);
0122
0123 return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow);
0124 }
0125
0126 public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key) {
0127 return getAggregationBufferFromUnsafeRow(key, key.hashCode());
0128 }
0129
0130 public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) {
0131
0132 final BytesToBytesMap.Location loc = map.lookup(
0133 key.getBaseObject(),
0134 key.getBaseOffset(),
0135 key.getSizeInBytes(),
0136 hash);
0137 if (!loc.isDefined()) {
0138
0139
0140 boolean putSucceeded = loc.append(
0141 key.getBaseObject(),
0142 key.getBaseOffset(),
0143 key.getSizeInBytes(),
0144 emptyAggregationBuffer,
0145 Platform.BYTE_ARRAY_OFFSET,
0146 emptyAggregationBuffer.length
0147 );
0148 if (!putSucceeded) {
0149 return null;
0150 }
0151 }
0152
0153
0154 currentAggregationBuffer.pointTo(
0155 loc.getValueBase(),
0156 loc.getValueOffset(),
0157 loc.getValueLength()
0158 );
0159 return currentAggregationBuffer;
0160 }
0161
0162
0163
0164
0165
0166
0167
0168
0169 public KVIterator<UnsafeRow, UnsafeRow> iterator() {
0170 return new KVIterator<UnsafeRow, UnsafeRow>() {
0171
0172 private final BytesToBytesMap.MapIterator mapLocationIterator =
0173 map.destructiveIterator();
0174 private final UnsafeRow key = new UnsafeRow(groupingKeySchema.length());
0175 private final UnsafeRow value = new UnsafeRow(aggregationBufferSchema.length());
0176
0177 @Override
0178 public boolean next() {
0179 if (mapLocationIterator.hasNext()) {
0180 final BytesToBytesMap.Location loc = mapLocationIterator.next();
0181 key.pointTo(
0182 loc.getKeyBase(),
0183 loc.getKeyOffset(),
0184 loc.getKeyLength()
0185 );
0186 value.pointTo(
0187 loc.getValueBase(),
0188 loc.getValueOffset(),
0189 loc.getValueLength()
0190 );
0191 return true;
0192 } else {
0193 return false;
0194 }
0195 }
0196
0197 @Override
0198 public UnsafeRow getKey() {
0199 return key;
0200 }
0201
0202 @Override
0203 public UnsafeRow getValue() {
0204 return value;
0205 }
0206
0207 @Override
0208 public void close() {
0209
0210 }
0211 };
0212 }
0213
0214
0215
0216
0217 public long getPeakMemoryUsedBytes() {
0218 return map.getPeakMemoryUsedBytes();
0219 }
0220
0221
0222
0223
0224 public void free() {
0225 map.free();
0226 }
0227
0228
0229
0230
0231 public double getAvgHashProbeBucketListIterations() {
0232 return map.getAvgHashProbeBucketListIterations();
0233 }
0234
0235
0236
0237
0238
0239
0240
0241 public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException {
0242 return new UnsafeKVExternalSorter(
0243 groupingKeySchema,
0244 aggregationBufferSchema,
0245 SparkEnv.get().blockManager(),
0246 SparkEnv.get().serializerManager(),
0247 map.getPageSizeBytes(),
0248 (int) SparkEnv.get().conf().get(
0249 package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()),
0250 map);
0251 }
0252 }