0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 """
0019 >>> from pyspark.context import SparkContext
0020 >>> sc = SparkContext('local', 'test')
0021 >>> a = sc.accumulator(1)
0022 >>> a.value
0023 1
0024 >>> a.value = 2
0025 >>> a.value
0026 2
0027 >>> a += 5
0028 >>> a.value
0029 7
0030
0031 >>> sc.accumulator(1.0).value
0032 1.0
0033
0034 >>> sc.accumulator(1j).value
0035 1j
0036
0037 >>> rdd = sc.parallelize([1,2,3])
0038 >>> def f(x):
0039 ... global a
0040 ... a += x
0041 >>> rdd.foreach(f)
0042 >>> a.value
0043 13
0044
0045 >>> b = sc.accumulator(0)
0046 >>> def g(x):
0047 ... b.add(x)
0048 >>> rdd.foreach(g)
0049 >>> b.value
0050 6
0051
0052 >>> from pyspark.accumulators import AccumulatorParam
0053 >>> class VectorAccumulatorParam(AccumulatorParam):
0054 ... def zero(self, value):
0055 ... return [0.0] * len(value)
0056 ... def addInPlace(self, val1, val2):
0057 ... for i in range(len(val1)):
0058 ... val1[i] += val2[i]
0059 ... return val1
0060 >>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
0061 >>> va.value
0062 [1.0, 2.0, 3.0]
0063 >>> def g(x):
0064 ... global va
0065 ... va += [x] * 3
0066 >>> rdd.foreach(g)
0067 >>> va.value
0068 [7.0, 8.0, 9.0]
0069
0070 >>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
0071 Traceback (most recent call last):
0072 ...
0073 Py4JJavaError:...
0074
0075 >>> def h(x):
0076 ... global a
0077 ... a.value = 7
0078 >>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
0079 Traceback (most recent call last):
0080 ...
0081 Py4JJavaError:...
0082
0083 >>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
0084 Traceback (most recent call last):
0085 ...
0086 TypeError:...
0087 """
0088
0089 import sys
0090 import select
0091 import struct
0092 if sys.version < '3':
0093 import SocketServer
0094 else:
0095 import socketserver as SocketServer
0096 import threading
0097 from pyspark.serializers import read_int, PickleSerializer
0098
0099
# Public API of this module; helper classes and the update server stay private.
__all__ = ['Accumulator', 'AccumulatorParam']


# Serializer used to decode the (accumulator id, update) pairs that arrive on
# the accumulator update socket (see _UpdateRequestHandler below).
pickleSer = PickleSerializer()


# Maps accumulator id -> Accumulator instance in this process.  Deserialized
# (worker-side) copies register themselves here so repeated deserialization of
# the same accumulator yields one shared object, and the update server uses it
# to apply incoming updates (see _deserialize_accumulator and
# _UpdateRequestHandler).
_accumulatorRegistry = {}
0108
0109
def _deserialize_accumulator(aid, zero_value, accum_param):
    """Return the process-local Accumulator for ``aid``, creating it if needed.

    Used as the reconstruction callable in Accumulator.__reduce__.  A freshly
    rebuilt accumulator starts from the zero value and is flagged as
    deserialized, so its ``value`` property cannot be read on a worker.
    """
    from pyspark.accumulators import _accumulatorRegistry

    try:
        # Already rebuilt once in this process -- reuse the shared copy.
        return _accumulatorRegistry[aid]
    except KeyError:
        accum = Accumulator(aid, zero_value, accum_param)
        accum._deserialized = True
        _accumulatorRegistry[aid] = accum
        return accum
0120
0121
class Accumulator(object):

    """
    A shared variable with a commutative and associative "add" operation.

    Tasks running on a Spark cluster may add to an Accumulator with the `+=`
    operator, but only the driver program may read (or assign) its `value`;
    updates made by workers are propagated back to the driver automatically.

    Primitive numeric types are supported out of the box; other types can be
    accumulated by supplying a custom :class:`AccumulatorParam` object.  See
    this module's doctest for a worked example.
    """

    def __init__(self, aid, value, accum_param):
        """Create an Accumulator with id `aid`, initial `value`, and the
        given AccumulatorParam object, and register it in this process."""
        from pyspark.accumulators import _accumulatorRegistry
        self.aid = aid
        self.accum_param = accum_param
        self._value = value
        self._deserialized = False
        # Register last, once the instance is fully initialized.
        _accumulatorRegistry[aid] = self

    def __reduce__(self):
        """Custom serialization; ships the id, a zero value derived from our
        AccumulatorParam, and the param itself (not the current value)."""
        param = self.accum_param
        return _deserialize_accumulator, (self.aid, param.zero(self._value), param)

    def _check_driver_side(self):
        # Reading or writing `value` on a deserialized (worker-side) copy is
        # an error; only the driver program may do so.
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")

    @property
    def value(self):
        """Get the accumulator's value; only usable in driver program"""
        self._check_driver_side()
        return self._value

    @value.setter
    def value(self, value):
        """Sets the accumulator's value; only usable in driver program"""
        self._check_driver_side()
        self._value = value

    def add(self, term):
        """Adds a term to this accumulator's value"""
        self._value = self.accum_param.addInPlace(self._value, term)

    def __iadd__(self, term):
        """The += operator; adds a term to this accumulator's value"""
        self.add(term)
        return self

    def __str__(self):
        return str(self._value)

    def __repr__(self):
        return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)
0177
0178
class AccumulatorParam(object):

    """
    Strategy object defining how values of a given type are accumulated:
    a zero element plus an (optionally in-place) addition.  Subclasses must
    override both methods.
    """

    def zero(self, value):
        """
        Return a "zero value" for the type, dimensionally compatible with
        the supplied `value` (e.g., an all-zeros vector of the same length).
        """
        raise NotImplementedError

    def addInPlace(self, value1, value2):
        """
        Combine two values of the accumulator's data type and return the
        result; implementations may mutate and return `value1` for speed.
        """
        raise NotImplementedError
0198
0199
class AddingAccumulatorParam(AccumulatorParam):

    """
    An AccumulatorParam for any type supporting the ``+`` operator -- ints,
    floats, complex numbers, lists, and so on.  The zero element of the
    underlying type is supplied by the caller at construction time.
    """

    def __init__(self, zero_value):
        # Identity element handed back by zero() regardless of `value`.
        self.zero_value = zero_value

    def zero(self, value):
        return self.zero_value

    def addInPlace(self, value1, value2):
        # `+=` mutates in place for types that support it (e.g. lists) and
        # falls back to plain addition otherwise.
        value1 += value2
        return value1
0217
0218
0219
# Shared AddingAccumulatorParam singletons for plain numeric accumulators.
# NOTE(review): presumably the defaults chosen by SparkContext.accumulator()
# for int/float/complex initial values (the module doctest shows those three
# types working without an explicit param) -- confirm in context.py.
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
0223
0224
class _UpdateRequestHandler(SocketServer.StreamRequestHandler):

    """
    This handler will keep polling updates from the same socket until the
    server is shutdown.
    """

    def handle(self):
        """Authenticate the connection, then repeatedly apply accumulator
        updates received on it until the server shuts down."""
        from pyspark.accumulators import _accumulatorRegistry
        auth_token = self.server.auth_token

        def poll(func):
            # Wake at least once a second instead of blocking on the socket,
            # so the server_shutdown flag set by AccumulatorServer.shutdown()
            # is noticed; stop early once func() reports it is done.
            while not self.server.server_shutdown:

                r, _, _ = select.select([self.rfile], [], [], 1)
                if self.rfile in r:
                    if func():
                        break

        def accum_updates():
            # One batch: an int count followed by that many pickled
            # (accumulator id, update) pairs.  Apply each to the local
            # registry, then ack with a single byte.
            num_updates = read_int(self.rfile)
            for _ in range(num_updates):
                (aid, update) = pickleSer._read_with_length(self.rfile)
                _accumulatorRegistry[aid] += update

            self.wfile.write(struct.pack("!b", 1))
            # Returning False keeps poll() looping: more batches may follow.
            return False

        def authenticate_and_accum_updates():
            # The first payload on the socket must be the auth token verbatim.
            received_token = self.rfile.read(len(auth_token))
            if isinstance(received_token, bytes):
                received_token = received_token.decode("utf-8")
            if (received_token == auth_token):
                accum_updates()

                # Authenticated: return True so the first poll() loop exits
                # and the plain accum_updates loop below takes over.
                return True
            else:
                raise Exception(
                    "The value of the provided token to the AccumulatorServer is not correct.")


        # The first message must authenticate; afterwards only updates flow.
        poll(authenticate_and_accum_updates)

        poll(accum_updates)
0269
0270
class AccumulatorServer(SocketServer.TCPServer):

    """
    A simple TCP server that intercepts shutdown() in order to interrupt
    our continuous polling on the handler.
    """
    # Fix: in the previous revision this docstring sat *after* __init__ as a
    # stray class-level string statement, leaving AccumulatorServer.__doc__
    # empty; it now leads the class body as a proper docstring.

    # Flag polled once a second by _UpdateRequestHandler loops; flipped to
    # True by shutdown() so handler threads can exit their select() loops.
    server_shutdown = False

    def __init__(self, server_address, RequestHandlerClass, auth_token):
        """Bind the TCP server and remember the token clients must present
        before any accumulator updates are accepted."""
        SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass)
        self.auth_token = auth_token

    def shutdown(self):
        """Signal handler polling loops to stop, stop serve_forever(), and
        close the listening socket."""
        self.server_shutdown = True
        SocketServer.TCPServer.shutdown(self)
        self.server_close()
0287
0288
def _start_update_server(auth_token):
    """Launch an AccumulatorServer on an ephemeral localhost port, serve it
    from a daemon thread, and return the server object."""
    update_server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token)
    serving_thread = threading.Thread(target=update_server.serve_forever)
    # Daemonize so a lingering server thread never blocks interpreter exit.
    serving_thread.daemon = True
    serving_thread.start()
    return update_server
0296
if __name__ == "__main__":
    # Run this module's doctests; a nonzero failure count means a bad exit.
    import doctest
    failures, _tests = doctest.testmod()
    if failures:
        sys.exit(-1)