Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 # This file is ported from spark/util/StatCounter.scala
0019 
0020 import copy
0021 import math
0022 
0023 try:
0024     from numpy import maximum, minimum, sqrt
0025 except ImportError:
0026     maximum = max
0027     minimum = min
0028     sqrt = math.sqrt
0029 
0030 
class StatCounter(object):
    """Running count, mean, variance, min and max of a stream of numbers.

    Individual values are folded in with :meth:`merge` using Welford's online
    update (running mean ``mu`` plus ``m2``, the running sum of squared
    deviations from the mean). Two counters can be combined with
    :meth:`mergeStats` using the pairwise update, so statistics accumulated
    separately can be reduced into one.
    """

    def __init__(self, values=None):
        """Create a counter, optionally pre-loaded with ``values``.

        :param values: optional iterable of numbers; each one is folded in
            via :meth:`merge`.
        """
        if values is None:
            values = []
        self.n = 0  # Running count of our values
        self.mu = 0.0  # Running mean of our values
        self.m2 = 0.0  # Running variance numerator (sum of (x - mean)^2)
        self.maxValue = float("-inf")
        self.minValue = float("inf")

        for v in values:
            self.merge(v)

    def merge(self, value):
        """Add a single value into this StatCounter, updating the statistics.

        :param value: the number to add.
        :return: ``self``, to allow chaining.
        """
        delta = value - self.mu
        self.n += 1
        self.mu += delta / self.n
        # Uses the *updated* mean on purpose (Welford's update of m2).
        self.m2 += delta * (value - self.mu)
        self.maxValue = maximum(self.maxValue, value)
        self.minValue = minimum(self.minValue, value)

        return self

    def mergeStats(self, other):
        """Merge another StatCounter into this one, adding up the statistics.

        :param other: the :class:`StatCounter` to fold into ``self``; it is
            not modified.
        :return: ``self``, to allow chaining.
        :raises TypeError: if ``other`` is not a :class:`StatCounter`
            (a subclass of ``Exception``, so existing broad handlers still
            catch it).
        """
        if not isinstance(other, StatCounter):
            raise TypeError("Can only merge Statcounters!")

        if other is self:
            # Merging a counter with itself: work from a snapshot so the
            # update below never reads a field it has already overwritten.
            # (Previously this called merge(), which expects a number and
            # raised TypeError.)
            return self.mergeStats(other.copy())

        if self.n == 0:
            # Nothing accumulated yet: adopt the other counter's state.
            self.mu = other.mu
            self.m2 = other.m2
            self.n = other.n
            self.maxValue = other.maxValue
            self.minValue = other.minValue

        elif other.n != 0:
            delta = other.mu - self.mu
            # When one side is much larger, nudge the larger side's mean
            # instead of re-weighting both, to limit rounding error.
            if other.n * 10 < self.n:
                self.mu = self.mu + (delta * other.n) / (self.n + other.n)
            elif self.n * 10 < other.n:
                self.mu = other.mu - (delta * self.n) / (self.n + other.n)
            else:
                self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n)

            self.maxValue = maximum(self.maxValue, other.maxValue)
            self.minValue = minimum(self.minValue, other.minValue)

            # m2 must be updated while self.n still holds the old count.
            self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n)
            self.n += other.n
        return self

    def copy(self):
        """Return an independent deep copy of this StatCounter."""
        return copy.deepcopy(self)

    def count(self):
        """Return the number of values seen, as an ``int``."""
        return int(self.n)

    def mean(self):
        """Return the running mean (``0.0`` when no values were seen)."""
        return self.mu

    def sum(self):
        """Return the sum, reconstructed as ``count * mean``.

        Because it is derived from the running mean, it may differ from a
        direct sum of the inputs by floating-point rounding.
        """
        return self.n * self.mu

    def min(self):
        """Return the smallest value seen (``inf`` when empty)."""
        return self.minValue

    def max(self):
        """Return the largest value seen (``-inf`` when empty)."""
        return self.maxValue

    def variance(self):
        """Return the population variance (divides by N); NaN when empty."""
        if self.n == 0:
            return float('nan')
        else:
            return self.m2 / self.n

    def sampleVariance(self):
        """Return the sample variance, dividing by N-1 instead of N.

        This corrects for bias when estimating the variance from a sample.
        Returns NaN when fewer than two values were seen.
        """
        if self.n <= 1:
            return float('nan')
        else:
            return self.m2 / (self.n - 1)

    def stdev(self):
        """Return the population standard deviation of the values."""
        return sqrt(self.variance())

    def sampleStdev(self):
        """Return the sample standard deviation (N-1 denominator)."""
        return sqrt(self.sampleVariance())

    def asDict(self, sample=False):
        """Returns the :class:`StatCounter` members as a ``dict``.

        NOTE: the ``sample`` flag reads inverted, and this is kept for
        backward compatibility. With the default ``sample=False``, ``stdev``
        and ``variance`` are the *sample* (N-1) statistics. With
        ``sample=True`` they are the *population* (N) statistics.

        For example, ``sc.parallelize([1., 2., 3., 4.]).stats().asDict()``
        returns::

            {'count': 4,
             'max': 4.0,
             'mean': 2.5,
             'min': 1.0,
             'stdev': 1.2909944487358056,
             'sum': 10.0,
             'variance': 1.6666666666666667}
        """
        return {
            'count': self.count(),
            'mean': self.mean(),
            'sum': self.sum(),
            'min': self.min(),
            'max': self.max(),
            'stdev': self.stdev() if sample else self.sampleStdev(),
            'variance': self.variance() if sample else self.sampleVariance()
        }

    def __repr__(self):
        # stdev here is the population standard deviation.
        return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" %
                (self.count(), self.mean(), self.stdev(), self.max(), self.min()))