0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import sys
0019 import random
0020 import math
0021
0022
class RDDSamplerBase(object):
    """Shared seeding and random-draw helpers for the RDD sampler classes.

    ``func`` is abstract here (it raises ``NotImplementedError``); concrete
    samplers override it and call ``initRandomGenerator(split)`` before they
    draw anything.
    """

    def __init__(self, withReplacement, seed=None):
        """Record the sampling mode and the base seed.

        :param withReplacement: stored as ``self._withReplacement``; this
            class never reads it, subclasses decide what it means.
        :param seed: integer base seed; when ``None`` a random integer in
            ``[0, sys.maxsize]`` is picked.
        """
        if seed is None:
            seed = random.randint(0, sys.maxsize)
        self._seed = seed
        self._withReplacement = withReplacement
        # Built by initRandomGenerator(); the draw helpers below fail with
        # AttributeError until then.
        self._random = None

    def initRandomGenerator(self, split):
        """(Re)build the generator used for one split.

        The generator is seeded with ``self._seed ^ split``, so for a fixed
        base seed every distinct ``split`` value gets its own reproducible
        stream. The first ten draws are discarded — presumably to decorrelate
        streams whose seeds differ only in a few low bits
        (NOTE(review): confirm intent).

        :param split: integer mixed into the seed (presumably the partition
            index — verify against callers).
        """
        generator = random.Random(self._seed ^ split)
        discarded = 0
        while discarded < 10:
            generator.randint(0, 1)
            discarded += 1
        self._random = generator

    def getUniformSample(self):
        """Return the next float in ``[0.0, 1.0)`` from the split's generator."""
        rng = self._random
        return rng.random()

    def getPoissonSample(self, mean):
        """Draw a Poisson-distributed count with the given mean.

        When ``mean < 20.0``, Knuth's method is used: multiply uniforms until
        the running product falls to ``exp(-mean)``. Otherwise the result is
        the number of exponential inter-arrival times (rate ``mean``) that
        complete before time ``1.0``.

        :param mean: the expected count.
        :returns: a non-negative ``int``.
        """
        if mean < 20.0:
            return self._poissonByProducts(mean)
        return self._poissonByArrivals(mean)

    def _poissonByProducts(self, mean):
        # Knuth: count how many extra uniforms it takes for the running
        # product to drop to exp(-mean) or below.
        threshold = math.exp(-mean)
        product = self._random.random()
        count = 0
        while product > threshold:
            count += 1
            product *= self._random.random()
        return count

    def _poissonByArrivals(self, mean):
        # Accumulate exponential gaps of rate `mean`; the number of arrivals
        # strictly before time 1.0 is the Poisson(mean) count.
        elapsed = self._random.expovariate(mean)
        count = 0
        while elapsed < 1.0:
            count += 1
            elapsed += self._random.expovariate(mean)
        return count

    def func(self, split, iterator):
        """Sample ``iterator`` for ``split``; must be overridden by subclasses."""
        raise NotImplementedError
0062
0063
class RDDSampler(RDDSamplerBase):
    """Sample a split at one fixed fraction, with or without replacement."""

    def __init__(self, withReplacement, fraction, seed=None):
        """
        :param withReplacement: if true, every item is emitted a
            Poisson(``fraction``)-distributed number of times; otherwise an
            item is kept when its uniform draw in ``[0, 1)`` is below
            ``fraction``.
        :param fraction: the sampling fraction, used as described above.
        :param seed: base seed, forwarded to ``RDDSamplerBase``.
        """
        RDDSamplerBase.__init__(self, withReplacement, seed)
        self._fraction = fraction

    def func(self, split, iterator):
        """Lazily yield the sample of ``iterator`` for ``split``.

        This is a generator, so the per-split generator is seeded only when
        iteration starts.
        """
        self.initRandomGenerator(split)
        fraction = self._fraction
        if not self._withReplacement:
            # Bernoulli sampling: one uniform draw per item.
            for obj in iterator:
                if self.getUniformSample() < fraction:
                    yield obj
            return
        # With replacement: the number of copies of each item is a
        # Poisson draw whose mean is the sampling fraction.
        for obj in iterator:
            copies = self.getPoissonSample(fraction)
            while copies > 0:
                yield obj
                copies -= 1
0084
0085
class RDDRangeSampler(RDDSamplerBase):
    """Keep the items whose uniform draw falls in ``[lowerBound, upperBound)``.

    Sampling is always without replacement (``withReplacement`` is fixed to
    ``False``). With the same seed and split, the draw sequence is identical.
    Presumably, callers therefore use disjoint ranges to split a dataset into
    non-overlapping pieces — TODO confirm against callers.
    """

    def __init__(self, lowerBound, upperBound, seed=None):
        """
        :param lowerBound: inclusive lower edge of the accepted draw range.
        :param upperBound: exclusive upper edge of the accepted draw range.
        :param seed: base seed, forwarded to ``RDDSamplerBase``.
        """
        RDDSamplerBase.__init__(self, False, seed)
        self._lowerBound = lowerBound
        self._upperBound = upperBound

    def func(self, split, iterator):
        """Lazily yield the items of ``iterator`` whose draw is in range."""
        self.initRandomGenerator(split)
        low = self._lowerBound
        high = self._upperBound
        for obj in iterator:
            draw = self.getUniformSample()
            if low <= draw < high:
                yield obj
0098
0099
class RDDStratifiedSampler(RDDSamplerBase):
    """Sample ``(key, val)`` records with a separate fraction per key."""

    def __init__(self, withReplacement, fractions, seed=None):
        """
        :param withReplacement: if true, each record is emitted a
            Poisson(``fractions[key]``)-distributed number of times; otherwise
            it is kept when its uniform draw is below ``fractions[key]``.
        :param fractions: mapping indexed with ``fractions[key]`` for every
            record. It is presumably a dict covering every key; a missing key
            raises whatever the mapping raises (``KeyError`` for a dict).
        :param seed: base seed, forwarded to ``RDDSamplerBase``.
        """
        RDDSamplerBase.__init__(self, withReplacement, seed)
        self._fractions = fractions

    def func(self, split, iterator):
        """Lazily yield the sampled records of ``iterator`` for ``split``.

        Each element of ``iterator`` is unpacked into ``key, val``, and the
        element is re-emitted as the tuple ``(key, val)``.
        """
        self.initRandomGenerator(split)
        fractions = self._fractions
        if not self._withReplacement:
            for key, val in iterator:
                # Draw first, then look up the key's fraction.
                draw = self.getUniformSample()
                if draw < fractions[key]:
                    yield key, val
            return
        for key, val in iterator:
            # Number of copies of this record ~ Poisson(fraction for its key).
            copies = self.getPoissonSample(fractions[key])
            while copies > 0:
                yield key, val
                copies -= 1