0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020 import copy
0021 import math
0022
0023 try:
0024 from numpy import maximum, minimum, sqrt
0025 except ImportError:
0026 maximum = max
0027 minimum = min
0028 sqrt = math.sqrt
0029
0030
class StatCounter(object):
    """Running count, mean, variance, min and max over a stream of values.

    Single values are folded in with Welford's online update (:meth:`merge`).
    Two independently built counters can be combined with :meth:`mergeStats`
    using the pairwise (Chan et al.) combination of count, mean and sum of
    squared deviations.

    The arithmetic is generic, so when the module-level ``maximum``,
    ``minimum`` and ``sqrt`` are the NumPy ufuncs, NumPy arrays can be used
    as values and the statistics are computed element-wise.
    """

    def __init__(self, values=None):
        """Create a counter, optionally pre-loaded with ``values``.

        :param values: optional iterable; each element is passed to
            :meth:`merge`.
        """
        self.n = 0                     # number of values merged so far
        self.mu = 0.0                  # running mean
        self.m2 = 0.0                  # sum of squared deviations from the mean
        self.maxValue = float("-inf")  # identity element for maximum
        self.minValue = float("inf")   # identity element for minimum

        # Compare against None explicitly: truth-testing a NumPy array raises.
        if values is not None:
            for v in values:
                self.merge(v)

    def merge(self, value):
        """Add a single value to the statistics and return ``self``."""
        delta = value - self.mu
        self.n += 1
        self.mu += delta / self.n
        # Welford: m2 grows by (x - old_mean) * (x - new_mean).
        self.m2 += delta * (value - self.mu)
        self.maxValue = maximum(self.maxValue, value)
        self.minValue = minimum(self.minValue, value)

        return self

    def mergeStats(self, other):
        """Fold another :class:`StatCounter` into this one and return ``self``.

        ``other`` itself is not modified.

        :param other: the counter whose statistics are merged in.
        :raises TypeError: if ``other`` is not a :class:`StatCounter`
            (a subclass of ``Exception``, so existing broad handlers still
            catch it).
        """
        if not isinstance(other, StatCounter):
            raise TypeError("Can only merge Statcounters!")

        if other is self:
            # Snapshot first: the update below reads other's fields after
            # it has started overwriting self's, which would be the same object.
            return self.mergeStats(other.copy())

        if self.n == 0:
            # Adopt other's state. Copy mu/m2 because later in-place ``+=``
            # updates on NumPy arrays would otherwise write back into ``other``.
            self.mu = copy.copy(other.mu)
            self.m2 = copy.copy(other.m2)
            self.n = other.n
            self.maxValue = other.maxValue
            self.minValue = other.minValue

        elif other.n != 0:
            n_total = self.n + other.n
            delta = other.mu - self.mu
            # When one side is much larger, step from the larger side's mean
            # (presumably to limit rounding error); otherwise take the
            # count-weighted average.
            if other.n * 10 < self.n:
                self.mu = self.mu + (delta * other.n) / n_total
            elif self.n * 10 < other.n:
                self.mu = other.mu - (delta * self.n) / n_total
            else:
                self.mu = (self.mu * self.n + other.mu * other.n) / n_total

            self.maxValue = maximum(self.maxValue, other.maxValue)
            self.minValue = minimum(self.minValue, other.minValue)

            # Must use the pre-merge self.n, so update n only afterwards.
            self.m2 += other.m2 + (delta * delta * self.n * other.n) / n_total
            self.n = n_total

        return self

    def copy(self):
        """Return an independent deep copy of this counter."""
        return copy.deepcopy(self)

    def count(self):
        """Return the number of merged values as an ``int``."""
        return int(self.n)

    def mean(self):
        """Return the running mean (``0.0`` when empty)."""
        return self.mu

    def sum(self):
        """Return ``count * mean`` (the total, up to floating-point rounding)."""
        return self.n * self.mu

    def min(self):
        """Return the smallest merged value (``inf`` when empty)."""
        return self.minValue

    def max(self):
        """Return the largest merged value (``-inf`` when empty)."""
        return self.maxValue

    def variance(self):
        """Return the population variance ``m2 / n``, or NaN when empty."""
        if self.n == 0:
            return float('nan')
        return self.m2 / self.n

    def sampleVariance(self):
        """Return the sample variance ``m2 / (n - 1)``.

        Returns NaN when fewer than two values have been merged.
        """
        if self.n <= 1:
            return float('nan')
        return self.m2 / (self.n - 1)

    def stdev(self):
        """Return the population standard deviation."""
        return sqrt(self.variance())

    def sampleStdev(self):
        """Return the sample standard deviation."""
        return sqrt(self.sampleVariance())

    def asDict(self, sample=False):
        """Return the :class:`StatCounter` members as a ``dict``.

        Keys: ``count``, ``mean``, ``sum``, ``min``, ``max``, ``stdev``,
        ``variance``.

        :param sample: NOTE: this flag is inverted relative to its name.
            With the default ``False``, ``stdev`` and ``variance`` are the
            *sample* (``n - 1``) statistics. With ``True`` they are the
            *population* (``n``) statistics. The behavior is kept as-is
            because existing callers depend on the default output.

        Example: for the values ``[1., 2., 3., 4.]`` the default call gives
        count 4, mean 2.5, sum 10.0, min 1.0, max 4.0,
        stdev ~1.2909944487358056 and variance ~1.6666666666666667.
        """
        return {
            'count': self.count(),
            'mean': self.mean(),
            'sum': self.sum(),
            'min': self.min(),
            'max': self.max(),
            'stdev': self.stdev() if sample else self.sampleStdev(),
            'variance': self.variance() if sample else self.sampleVariance(),
        }

    def __repr__(self):
        return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" %
                (self.count(), self.mean(), self.stdev(), self.max(), self.min()))