0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.sql.execution.vectorized;
0019
0020 import java.util.Arrays;
0021
0022 import com.google.common.annotations.VisibleForTesting;
0023
0024 import org.apache.spark.sql.types.StructType;
0025
0026 import static org.apache.spark.sql.types.DataTypes.LongType;
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041 public class AggregateHashMap {
0042
0043 private OnHeapColumnVector[] columnVectors;
0044 private MutableColumnarRow aggBufferRow;
0045 private int[] buckets;
0046 private int numBuckets;
0047 private int numRows = 0;
0048 private int maxSteps = 3;
0049
0050 private static int DEFAULT_CAPACITY = 1 << 16;
0051 private static double DEFAULT_LOAD_FACTOR = 0.25;
0052 private static int DEFAULT_MAX_STEPS = 3;
0053
0054 public AggregateHashMap(StructType schema, int capacity, double loadFactor, int maxSteps) {
0055
0056
0057 assert (schema.size() == 2 && schema.fields()[0].dataType() == LongType &&
0058 schema.fields()[1].dataType() == LongType);
0059
0060
0061 assert (capacity > 0 && ((capacity & (capacity - 1)) == 0));
0062
0063 this.maxSteps = maxSteps;
0064 numBuckets = (int) (capacity / loadFactor);
0065 columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema);
0066 aggBufferRow = new MutableColumnarRow(columnVectors);
0067 buckets = new int[numBuckets];
0068 Arrays.fill(buckets, -1);
0069 }
0070
0071 public AggregateHashMap(StructType schema) {
0072 this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS);
0073 }
0074
0075 public MutableColumnarRow findOrInsert(long key) {
0076 int idx = find(key);
0077 if (idx != -1 && buckets[idx] == -1) {
0078 columnVectors[0].putLong(numRows, key);
0079 columnVectors[1].putLong(numRows, 0);
0080 buckets[idx] = numRows++;
0081 }
0082 aggBufferRow.rowId = buckets[idx];
0083 return aggBufferRow;
0084 }
0085
0086 @VisibleForTesting
0087 public int find(long key) {
0088 long h = hash(key);
0089 int step = 0;
0090 int idx = (int) h & (numBuckets - 1);
0091 while (step < maxSteps) {
0092
0093 if (buckets[idx] == -1) {
0094 return idx;
0095 } else if (equals(idx, key)) {
0096 return idx;
0097 }
0098 idx = (idx + 1) & (numBuckets - 1);
0099 step++;
0100 }
0101
0102 return -1;
0103 }
0104
0105 private long hash(long key) {
0106 return key;
0107 }
0108
0109 private boolean equals(int idx, long key1) {
0110 return columnVectors[0].getLong(buckets[idx]) == key1;
0111 }
0112 }