0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.util.sketch;
0019
0020 import java.io.*;
0021 import java.util.Arrays;
0022 import java.util.Random;
0023
0024 class CountMinSketchImpl extends CountMinSketch implements Serializable {
0025 private static final long PRIME_MODULUS = (1L << 31) - 1;
0026
0027 private int depth;
0028 private int width;
0029 private long[][] table;
0030 private long[] hashA;
0031 private long totalCount;
0032 private double eps;
0033 private double confidence;
0034
0035 private CountMinSketchImpl() {}
0036
0037 CountMinSketchImpl(int depth, int width, int seed) {
0038 if (depth <= 0 || width <= 0) {
0039 throw new IllegalArgumentException("Depth and width must be both positive");
0040 }
0041
0042 this.depth = depth;
0043 this.width = width;
0044 this.eps = 2.0 / width;
0045 this.confidence = 1 - 1 / Math.pow(2, depth);
0046 initTablesWith(depth, width, seed);
0047 }
0048
0049 CountMinSketchImpl(double eps, double confidence, int seed) {
0050 if (eps <= 0D) {
0051 throw new IllegalArgumentException("Relative error must be positive");
0052 }
0053
0054 if (confidence <= 0D || confidence >= 1D) {
0055 throw new IllegalArgumentException("Confidence must be within range (0.0, 1.0)");
0056 }
0057
0058
0059
0060 this.eps = eps;
0061 this.confidence = confidence;
0062 this.width = (int) Math.ceil(2 / eps);
0063 this.depth = (int) Math.ceil(-Math.log1p(-confidence) / Math.log(2));
0064 initTablesWith(depth, width, seed);
0065 }
0066
0067 @Override
0068 public boolean equals(Object other) {
0069 if (other == this) {
0070 return true;
0071 }
0072
0073 if (other == null || !(other instanceof CountMinSketchImpl)) {
0074 return false;
0075 }
0076
0077 CountMinSketchImpl that = (CountMinSketchImpl) other;
0078
0079 return
0080 this.depth == that.depth &&
0081 this.width == that.width &&
0082 this.totalCount == that.totalCount &&
0083 Arrays.equals(this.hashA, that.hashA) &&
0084 Arrays.deepEquals(this.table, that.table);
0085 }
0086
0087 @Override
0088 public int hashCode() {
0089 int hash = depth;
0090
0091 hash = hash * 31 + width;
0092 hash = hash * 31 + (int) (totalCount ^ (totalCount >>> 32));
0093 hash = hash * 31 + Arrays.hashCode(hashA);
0094 hash = hash * 31 + Arrays.deepHashCode(table);
0095
0096 return hash;
0097 }
0098
0099 private void initTablesWith(int depth, int width, int seed) {
0100 this.table = new long[depth][width];
0101 this.hashA = new long[depth];
0102 Random r = new Random(seed);
0103
0104
0105
0106
0107
0108
0109 for (int i = 0; i < depth; ++i) {
0110 hashA[i] = r.nextInt(Integer.MAX_VALUE);
0111 }
0112 }
0113
0114 @Override
0115 public double relativeError() {
0116 return eps;
0117 }
0118
0119 @Override
0120 public double confidence() {
0121 return confidence;
0122 }
0123
0124 @Override
0125 public int depth() {
0126 return depth;
0127 }
0128
0129 @Override
0130 public int width() {
0131 return width;
0132 }
0133
0134 @Override
0135 public long totalCount() {
0136 return totalCount;
0137 }
0138
0139 @Override
0140 public void add(Object item) {
0141 add(item, 1);
0142 }
0143
0144 @Override
0145 public void add(Object item, long count) {
0146 if (item instanceof String) {
0147 addString((String) item, count);
0148 } else if (item instanceof byte[]) {
0149 addBinary((byte[]) item, count);
0150 } else {
0151 addLong(Utils.integralToLong(item), count);
0152 }
0153 }
0154
0155 @Override
0156 public void addString(String item) {
0157 addString(item, 1);
0158 }
0159
0160 @Override
0161 public void addString(String item, long count) {
0162 addBinary(Utils.getBytesFromUTF8String(item), count);
0163 }
0164
0165 @Override
0166 public void addLong(long item) {
0167 addLong(item, 1);
0168 }
0169
0170 @Override
0171 public void addLong(long item, long count) {
0172 if (count < 0) {
0173 throw new IllegalArgumentException("Negative increments not implemented");
0174 }
0175
0176 for (int i = 0; i < depth; ++i) {
0177 table[i][hash(item, i)] += count;
0178 }
0179
0180 totalCount += count;
0181 }
0182
0183 @Override
0184 public void addBinary(byte[] item) {
0185 addBinary(item, 1);
0186 }
0187
0188 @Override
0189 public void addBinary(byte[] item, long count) {
0190 if (count < 0) {
0191 throw new IllegalArgumentException("Negative increments not implemented");
0192 }
0193
0194 int[] buckets = getHashBuckets(item, depth, width);
0195
0196 for (int i = 0; i < depth; ++i) {
0197 table[i][buckets[i]] += count;
0198 }
0199
0200 totalCount += count;
0201 }
0202
0203 private int hash(long item, int count) {
0204 long hash = hashA[count] * item;
0205
0206
0207
0208 hash += hash >> 32;
0209 hash &= PRIME_MODULUS;
0210
0211 return ((int) hash) % width;
0212 }
0213
0214 private static int[] getHashBuckets(String key, int hashCount, int max) {
0215 return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
0216 }
0217
0218 private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
0219 int[] result = new int[hashCount];
0220 int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
0221 int hash2 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, hash1);
0222 for (int i = 0; i < hashCount; i++) {
0223 result[i] = Math.abs((hash1 + i * hash2) % max);
0224 }
0225 return result;
0226 }
0227
0228 @Override
0229 public long estimateCount(Object item) {
0230 if (item instanceof String) {
0231 return estimateCountForStringItem((String) item);
0232 } else if (item instanceof byte[]) {
0233 return estimateCountForBinaryItem((byte[]) item);
0234 } else {
0235 return estimateCountForLongItem(Utils.integralToLong(item));
0236 }
0237 }
0238
0239 private long estimateCountForLongItem(long item) {
0240 long res = Long.MAX_VALUE;
0241 for (int i = 0; i < depth; ++i) {
0242 res = Math.min(res, table[i][hash(item, i)]);
0243 }
0244 return res;
0245 }
0246
0247 private long estimateCountForStringItem(String item) {
0248 long res = Long.MAX_VALUE;
0249 int[] buckets = getHashBuckets(item, depth, width);
0250 for (int i = 0; i < depth; ++i) {
0251 res = Math.min(res, table[i][buckets[i]]);
0252 }
0253 return res;
0254 }
0255
0256 private long estimateCountForBinaryItem(byte[] item) {
0257 long res = Long.MAX_VALUE;
0258 int[] buckets = getHashBuckets(item, depth, width);
0259 for (int i = 0; i < depth; ++i) {
0260 res = Math.min(res, table[i][buckets[i]]);
0261 }
0262 return res;
0263 }
0264
0265 @Override
0266 public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMergeException {
0267 if (other == null) {
0268 throw new IncompatibleMergeException("Cannot merge null estimator");
0269 }
0270
0271 if (!(other instanceof CountMinSketchImpl)) {
0272 throw new IncompatibleMergeException(
0273 "Cannot merge estimator of class " + other.getClass().getName()
0274 );
0275 }
0276
0277 CountMinSketchImpl that = (CountMinSketchImpl) other;
0278
0279 if (this.depth != that.depth) {
0280 throw new IncompatibleMergeException("Cannot merge estimators of different depth");
0281 }
0282
0283 if (this.width != that.width) {
0284 throw new IncompatibleMergeException("Cannot merge estimators of different width");
0285 }
0286
0287 if (!Arrays.equals(this.hashA, that.hashA)) {
0288 throw new IncompatibleMergeException("Cannot merge estimators of different seed");
0289 }
0290
0291 for (int i = 0; i < this.table.length; ++i) {
0292 for (int j = 0; j < this.table[i].length; ++j) {
0293 this.table[i][j] = this.table[i][j] + that.table[i][j];
0294 }
0295 }
0296
0297 this.totalCount += that.totalCount;
0298
0299 return this;
0300 }
0301
0302 @Override
0303 public void writeTo(OutputStream out) throws IOException {
0304 DataOutputStream dos = new DataOutputStream(out);
0305
0306 dos.writeInt(Version.V1.getVersionNumber());
0307
0308 dos.writeLong(this.totalCount);
0309 dos.writeInt(this.depth);
0310 dos.writeInt(this.width);
0311
0312 for (int i = 0; i < this.depth; ++i) {
0313 dos.writeLong(this.hashA[i]);
0314 }
0315
0316 for (int i = 0; i < this.depth; ++i) {
0317 for (int j = 0; j < this.width; ++j) {
0318 dos.writeLong(table[i][j]);
0319 }
0320 }
0321 }
0322
0323 @Override
0324 public byte[] toByteArray() throws IOException {
0325 try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
0326 writeTo(out);
0327 return out.toByteArray();
0328 }
0329 }
0330
0331 public static CountMinSketchImpl readFrom(InputStream in) throws IOException {
0332 CountMinSketchImpl sketch = new CountMinSketchImpl();
0333 sketch.readFrom0(in);
0334 return sketch;
0335 }
0336
0337 private void readFrom0(InputStream in) throws IOException {
0338 DataInputStream dis = new DataInputStream(in);
0339
0340 int version = dis.readInt();
0341 if (version != Version.V1.getVersionNumber()) {
0342 throw new IOException("Unexpected Count-Min Sketch version number (" + version + ")");
0343 }
0344
0345 this.totalCount = dis.readLong();
0346 this.depth = dis.readInt();
0347 this.width = dis.readInt();
0348 this.eps = 2.0 / width;
0349 this.confidence = 1 - 1 / Math.pow(2, depth);
0350
0351 this.hashA = new long[depth];
0352 for (int i = 0; i < depth; ++i) {
0353 this.hashA[i] = dis.readLong();
0354 }
0355
0356 this.table = new long[depth][width];
0357 for (int i = 0; i < depth; ++i) {
0358 for (int j = 0; j < width; ++j) {
0359 this.table[i][j] = dis.readLong();
0360 }
0361 }
0362 }
0363
0364 private void writeObject(ObjectOutputStream out) throws IOException {
0365 this.writeTo(out);
0366 }
0367
0368 private void readObject(ObjectInputStream in) throws IOException {
0369 this.readFrom0(in);
0370 }
0371 }