Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.util.sketch;
0019 
0020 import java.io.*;
0021 import java.util.Arrays;
0022 import java.util.Random;
0023 
0024 class CountMinSketchImpl extends CountMinSketch implements Serializable {
0025   private static final long PRIME_MODULUS = (1L << 31) - 1;
0026 
0027   private int depth;
0028   private int width;
0029   private long[][] table;
0030   private long[] hashA;
0031   private long totalCount;
0032   private double eps;
0033   private double confidence;
0034 
0035   private CountMinSketchImpl() {}
0036 
0037   CountMinSketchImpl(int depth, int width, int seed) {
0038     if (depth <= 0 || width <= 0) {
0039       throw new IllegalArgumentException("Depth and width must be both positive");
0040     }
0041 
0042     this.depth = depth;
0043     this.width = width;
0044     this.eps = 2.0 / width;
0045     this.confidence = 1 - 1 / Math.pow(2, depth);
0046     initTablesWith(depth, width, seed);
0047   }
0048 
0049   CountMinSketchImpl(double eps, double confidence, int seed) {
0050     if (eps <= 0D) {
0051       throw new IllegalArgumentException("Relative error must be positive");
0052     }
0053 
0054     if (confidence <= 0D || confidence >= 1D) {
0055       throw new IllegalArgumentException("Confidence must be within range (0.0, 1.0)");
0056     }
0057 
0058     // 2/w = eps ; w = 2/eps
0059     // 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence)
0060     this.eps = eps;
0061     this.confidence = confidence;
0062     this.width = (int) Math.ceil(2 / eps);
0063     this.depth = (int) Math.ceil(-Math.log1p(-confidence) / Math.log(2));
0064     initTablesWith(depth, width, seed);
0065   }
0066 
0067   @Override
0068   public boolean equals(Object other) {
0069     if (other == this) {
0070       return true;
0071     }
0072 
0073     if (other == null || !(other instanceof CountMinSketchImpl)) {
0074       return false;
0075     }
0076 
0077     CountMinSketchImpl that = (CountMinSketchImpl) other;
0078 
0079     return
0080       this.depth == that.depth &&
0081       this.width == that.width &&
0082       this.totalCount == that.totalCount &&
0083       Arrays.equals(this.hashA, that.hashA) &&
0084       Arrays.deepEquals(this.table, that.table);
0085   }
0086 
0087   @Override
0088   public int hashCode() {
0089     int hash = depth;
0090 
0091     hash = hash * 31 + width;
0092     hash = hash * 31 + (int) (totalCount ^ (totalCount >>> 32));
0093     hash = hash * 31 + Arrays.hashCode(hashA);
0094     hash = hash * 31 + Arrays.deepHashCode(table);
0095 
0096     return hash;
0097   }
0098 
0099   private void initTablesWith(int depth, int width, int seed) {
0100     this.table = new long[depth][width];
0101     this.hashA = new long[depth];
0102     Random r = new Random(seed);
0103     // We're using a linear hash functions
0104     // of the form (a*x+b) mod p.
0105     // a,b are chosen independently for each hash function.
0106     // However we can set b = 0 as all it does is shift the results
0107     // without compromising their uniformity or independence with
0108     // the other hashes.
0109     for (int i = 0; i < depth; ++i) {
0110       hashA[i] = r.nextInt(Integer.MAX_VALUE);
0111     }
0112   }
0113 
0114   @Override
0115   public double relativeError() {
0116     return eps;
0117   }
0118 
0119   @Override
0120   public double confidence() {
0121     return confidence;
0122   }
0123 
0124   @Override
0125   public int depth() {
0126     return depth;
0127   }
0128 
0129   @Override
0130   public int width() {
0131     return width;
0132   }
0133 
0134   @Override
0135   public long totalCount() {
0136     return totalCount;
0137   }
0138 
0139   @Override
0140   public void add(Object item) {
0141     add(item, 1);
0142   }
0143 
0144   @Override
0145   public void add(Object item, long count) {
0146     if (item instanceof String) {
0147       addString((String) item, count);
0148     } else if (item instanceof byte[]) {
0149       addBinary((byte[]) item, count);
0150     } else {
0151       addLong(Utils.integralToLong(item), count);
0152     }
0153   }
0154 
0155   @Override
0156   public void addString(String item) {
0157     addString(item, 1);
0158   }
0159 
0160   @Override
0161   public void addString(String item, long count) {
0162     addBinary(Utils.getBytesFromUTF8String(item), count);
0163   }
0164 
0165   @Override
0166   public void addLong(long item) {
0167     addLong(item, 1);
0168   }
0169 
0170   @Override
0171   public void addLong(long item, long count) {
0172     if (count < 0) {
0173       throw new IllegalArgumentException("Negative increments not implemented");
0174     }
0175 
0176     for (int i = 0; i < depth; ++i) {
0177       table[i][hash(item, i)] += count;
0178     }
0179 
0180     totalCount += count;
0181   }
0182 
0183   @Override
0184   public void addBinary(byte[] item) {
0185     addBinary(item, 1);
0186   }
0187 
0188   @Override
0189   public void addBinary(byte[] item, long count) {
0190     if (count < 0) {
0191       throw new IllegalArgumentException("Negative increments not implemented");
0192     }
0193 
0194     int[] buckets = getHashBuckets(item, depth, width);
0195 
0196     for (int i = 0; i < depth; ++i) {
0197       table[i][buckets[i]] += count;
0198     }
0199 
0200     totalCount += count;
0201   }
0202 
0203   private int hash(long item, int count) {
0204     long hash = hashA[count] * item;
0205     // A super fast way of computing x mod 2^p-1
0206     // See http://www.cs.princeton.edu/courses/archive/fall09/cos521/Handouts/universalclasses.pdf
0207     // page 149, right after Proposition 7.
0208     hash += hash >> 32;
0209     hash &= PRIME_MODULUS;
0210     // Doing "%" after (int) conversion is ~2x faster than %'ing longs.
0211     return ((int) hash) % width;
0212   }
0213 
0214   private static int[] getHashBuckets(String key, int hashCount, int max) {
0215     return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
0216   }
0217 
0218   private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
0219     int[] result = new int[hashCount];
0220     int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
0221     int hash2 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, hash1);
0222     for (int i = 0; i < hashCount; i++) {
0223       result[i] = Math.abs((hash1 + i * hash2) % max);
0224     }
0225     return result;
0226   }
0227 
0228   @Override
0229   public long estimateCount(Object item) {
0230     if (item instanceof String) {
0231       return estimateCountForStringItem((String) item);
0232     } else if (item instanceof byte[]) {
0233       return estimateCountForBinaryItem((byte[]) item);
0234     } else {
0235       return estimateCountForLongItem(Utils.integralToLong(item));
0236     }
0237   }
0238 
0239   private long estimateCountForLongItem(long item) {
0240     long res = Long.MAX_VALUE;
0241     for (int i = 0; i < depth; ++i) {
0242       res = Math.min(res, table[i][hash(item, i)]);
0243     }
0244     return res;
0245   }
0246 
0247   private long estimateCountForStringItem(String item) {
0248     long res = Long.MAX_VALUE;
0249     int[] buckets = getHashBuckets(item, depth, width);
0250     for (int i = 0; i < depth; ++i) {
0251       res = Math.min(res, table[i][buckets[i]]);
0252     }
0253     return res;
0254   }
0255 
0256   private long estimateCountForBinaryItem(byte[] item) {
0257     long res = Long.MAX_VALUE;
0258     int[] buckets = getHashBuckets(item, depth, width);
0259     for (int i = 0; i < depth; ++i) {
0260       res = Math.min(res, table[i][buckets[i]]);
0261     }
0262     return res;
0263   }
0264 
0265   @Override
0266   public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMergeException {
0267     if (other == null) {
0268       throw new IncompatibleMergeException("Cannot merge null estimator");
0269     }
0270 
0271     if (!(other instanceof CountMinSketchImpl)) {
0272       throw new IncompatibleMergeException(
0273           "Cannot merge estimator of class " + other.getClass().getName()
0274       );
0275     }
0276 
0277     CountMinSketchImpl that = (CountMinSketchImpl) other;
0278 
0279     if (this.depth != that.depth) {
0280       throw new IncompatibleMergeException("Cannot merge estimators of different depth");
0281     }
0282 
0283     if (this.width != that.width) {
0284       throw new IncompatibleMergeException("Cannot merge estimators of different width");
0285     }
0286 
0287     if (!Arrays.equals(this.hashA, that.hashA)) {
0288       throw new IncompatibleMergeException("Cannot merge estimators of different seed");
0289     }
0290 
0291     for (int i = 0; i < this.table.length; ++i) {
0292       for (int j = 0; j < this.table[i].length; ++j) {
0293         this.table[i][j] = this.table[i][j] + that.table[i][j];
0294       }
0295     }
0296 
0297     this.totalCount += that.totalCount;
0298 
0299     return this;
0300   }
0301 
0302   @Override
0303   public void writeTo(OutputStream out) throws IOException {
0304     DataOutputStream dos = new DataOutputStream(out);
0305 
0306     dos.writeInt(Version.V1.getVersionNumber());
0307 
0308     dos.writeLong(this.totalCount);
0309     dos.writeInt(this.depth);
0310     dos.writeInt(this.width);
0311 
0312     for (int i = 0; i < this.depth; ++i) {
0313       dos.writeLong(this.hashA[i]);
0314     }
0315 
0316     for (int i = 0; i < this.depth; ++i) {
0317       for (int j = 0; j < this.width; ++j) {
0318         dos.writeLong(table[i][j]);
0319       }
0320     }
0321   }
0322 
0323   @Override
0324   public byte[] toByteArray() throws IOException {
0325     try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
0326       writeTo(out);
0327       return out.toByteArray();
0328     }
0329   }
0330 
0331   public static CountMinSketchImpl readFrom(InputStream in) throws IOException {
0332     CountMinSketchImpl sketch = new CountMinSketchImpl();
0333     sketch.readFrom0(in);
0334     return sketch;
0335   }
0336 
0337   private void readFrom0(InputStream in) throws IOException {
0338     DataInputStream dis = new DataInputStream(in);
0339 
0340     int version = dis.readInt();
0341     if (version != Version.V1.getVersionNumber()) {
0342       throw new IOException("Unexpected Count-Min Sketch version number (" + version + ")");
0343     }
0344 
0345     this.totalCount = dis.readLong();
0346     this.depth = dis.readInt();
0347     this.width = dis.readInt();
0348     this.eps = 2.0 / width;
0349     this.confidence = 1 - 1 / Math.pow(2, depth);
0350 
0351     this.hashA = new long[depth];
0352     for (int i = 0; i < depth; ++i) {
0353       this.hashA[i] = dis.readLong();
0354     }
0355 
0356     this.table = new long[depth][width];
0357     for (int i = 0; i < depth; ++i) {
0358       for (int j = 0; j < width; ++j) {
0359         this.table[i][j] = dis.readLong();
0360       }
0361     }
0362   }
0363 
0364   private void writeObject(ObjectOutputStream out) throws IOException {
0365     this.writeTo(out);
0366   }
0367 
0368   private void readObject(ObjectInputStream in) throws IOException {
0369     this.readFrom0(in);
0370   }
0371 }