Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 import sys
0019 import random
0020 import math
0021 
0022 
class RDDSamplerBase(object):
    """Shared state and random-variate helpers for the RDD samplers.

    Holds the sampling seed and the with/without-replacement flag, and owns a
    ``random.Random`` generator that ``initRandomGenerator`` rebuilds for each
    split. Subclasses must implement ``func``.
    """

    def __init__(self, withReplacement, seed=None):
        # When no seed is given, pick one at random in [0, sys.maxsize].
        if seed is None:
            seed = random.randint(0, sys.maxsize)
        self._seed = seed
        self._withReplacement = withReplacement
        # Not usable until initRandomGenerator() has been called.
        self._random = None

    def initRandomGenerator(self, split):
        """(Re)create the generator for ``split``, seeded with ``seed ^ split``.

        ``split`` must be an integer; it is XORed into the sampler seed so each
        split gets its own reproducible stream.
        """
        rng = random.Random(self._seed ^ split)
        # Seeds of different splits are close to each other (they differ only
        # through the XOR with ``split``), so discard a few draws to mix the
        # generator state before any real sampling happens.
        for _warmup in range(10):
            rng.randint(0, 1)
        self._random = rng

    def getUniformSample(self):
        """Return the next float in [0.0, 1.0) from the current generator."""
        return self._random.random()

    def getPoissonSample(self, mean):
        """Draw one Poisson(``mean``) variate from the current generator.

        Means below 20 use Knuth's multiplicative method (see
        http://en.wikipedia.org/wiki/Poisson_distribution). Larger means count
        exponential inter-arrival times instead, i.e. the same algorithm moved
        into the log domain.
        """
        if mean < 20.0:
            return self._poissonKnuth(mean)
        return self._poissonExponential(mean)

    def _poissonKnuth(self, mean):
        # Multiply uniforms until the running product drops to exp(-mean); the
        # number of extra factors needed is the sample (k + 1 uniform draws).
        threshold = math.exp(-mean)
        draw = self._random.random
        product = draw()
        count = 0
        while product > threshold:
            product *= draw()
            count += 1
        return count

    def _poissonExponential(self, mean):
        # Count how many Exp(mean) inter-arrival times fit before time 1.0
        # (k + 1 expovariate calls, each one a uniform draw plus a log).
        gap = self._random.expovariate
        elapsed = gap(mean)
        count = 0
        while elapsed < 1.0:
            elapsed += gap(mean)
            count += 1
        return count

    def func(self, split, iterator):
        """Sample the elements of ``iterator`` for ``split``; subclasses override."""
        raise NotImplementedError
0062 
0063 
class RDDSampler(RDDSamplerBase):
    """Sampler that keeps about ``fraction`` of each split's elements.

    Without replacement, each element is kept when its uniform draw is below
    ``fraction``. With replacement, each element is emitted a
    Poisson(``fraction``)-distributed number of times.
    """

    def __init__(self, withReplacement, fraction, seed=None):
        RDDSamplerBase.__init__(self, withReplacement, seed)
        # Sampling rate applied to every element.
        self._fraction = fraction

    def func(self, split, iterator):
        """Lazily yield the sampled elements of ``iterator`` for ``split``."""
        self.initRandomGenerator(split)
        replacing = self._withReplacement
        for element in iterator:
            if replacing:
                # For large datasets, the number of occurrences of an element
                # in a sample with replacement is approximately
                # Poisson(fraction), so draw that count directly.
                copies = self.getPoissonSample(self._fraction)
            elif self.getUniformSample() < self._fraction:
                copies = 1
            else:
                copies = 0
            for _ in range(copies):
                yield element
0084 
0085 
class RDDRangeSampler(RDDSamplerBase):
    """Keeps elements whose uniform draw lands in ``[lowerBound, upperBound)``.

    Always samples without replacement. Each element consumes exactly one
    uniform draw, so samplers that share a seed assign the same draw to the
    same element of the same split (given the same input order). Samplers with
    non-overlapping bounds therefore select non-overlapping elements.
    """

    def __init__(self, lowerBound, upperBound, seed=None):
        RDDSamplerBase.__init__(self, False, seed)
        # Half-open acceptance interval for the uniform draw.
        self._lowerBound = lowerBound
        self._upperBound = upperBound

    def func(self, split, iterator):
        """Lazily yield the elements of ``iterator`` whose draw is in range."""
        self.initRandomGenerator(split)
        for element in iterator:
            draw = self.getUniformSample()
            if self._lowerBound <= draw < self._upperBound:
                yield element
0098 
0099 
class RDDStratifiedSampler(RDDSamplerBase):
    """Per-key sampler for a stream of ``(key, value)`` pairs.

    ``fractions`` is indexed by key to get that key's sampling rate. Every key
    seen in the data must be present; a missing key fails the lookup (a
    ``KeyError`` when ``fractions`` is a dict).
    """

    def __init__(self, withReplacement, fractions, seed=None):
        RDDSamplerBase.__init__(self, withReplacement, seed)
        # key -> sampling rate for that key's pairs.
        self._fractions = fractions

    def func(self, split, iterator):
        """Lazily yield the sampled ``(key, value)`` pairs for ``split``."""
        self.initRandomGenerator(split)
        if not self._withReplacement:
            for key, val in iterator:
                # Draw first, then look up the key's rate.
                draw = self.getUniformSample()
                if draw < self._fractions[key]:
                    yield key, val
            return
        for key, val in iterator:
            # For large datasets, the number of occurrences of a pair in a
            # sample with replacement is approximately Poisson(rate for its key).
            copies = self.getPoissonSample(self._fractions[key])
            for _ in range(copies):
                yield key, val