#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
>>> a = sc.accumulator(1)
>>> a.value
1
>>> a.value = 2
>>> a.value
2
>>> a += 5
>>> a.value
7

>>> sc.accumulator(1.0).value
1.0

>>> sc.accumulator(1j).value
1j

>>> rdd = sc.parallelize([1, 2, 3])
>>> def f(x):
...     global a
...     a += x
>>> rdd.foreach(f)
>>> a.value
13

>>> b = sc.accumulator(0)
>>> def g(x):
...     b.add(x)
>>> rdd.foreach(g)
>>> b.value
6

>>> from pyspark.accumulators import AccumulatorParam
>>> class VectorAccumulatorParam(AccumulatorParam):
...     def zero(self, value):
...         return [0.0] * len(value)
...     def addInPlace(self, val1, val2):
...         for i in range(len(val1)):
...             val1[i] += val2[i]
...         return val1
>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
>>> va.value
[1.0, 2.0, 3.0]
>>> def g(x):
...     global va
...     va += [x] * 3
>>> rdd.foreach(g)
>>> va.value
[7.0, 8.0, 9.0]

>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
Py4JJavaError:...

>>> def h(x):
...     global a
...     a.value = 7
>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
Py4JJavaError:...

>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
TypeError:...
"""

import sys
import select
import struct
if sys.version < '3':
    import SocketServer
else:
    import socketserver as SocketServer
import threading
from pyspark.serializers import read_int, PickleSerializer


__all__ = ['Accumulator', 'AccumulatorParam']


pickleSer = PickleSerializer()

# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
_accumulatorRegistry = {}


def _deserialize_accumulator(aid, zero_value, accum_param):
    from pyspark.accumulators import _accumulatorRegistry
    # If this accumulator has already been deserialized on this machine, reuse the
    # existing instance rather than overwriting it.
    if aid in _accumulatorRegistry:
        return _accumulatorRegistry[aid]
    else:
        accum = Accumulator(aid, zero_value, accum_param)
        accum._deserialized = True
        _accumulatorRegistry[aid] = accum
        return accum

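# Hedged illustration (not part of the module API): Accumulator.__reduce__ below
# ships only the accumulator's id, its zero value, and its AccumulatorParam, so
# unpickling on a fresh worker yields a zero-valued, worker-side copy:
#
#     import pickle
#     param = AddingAccumulatorParam(0)
#     acc = Accumulator(7, 10, param)
#     payload = pickle.dumps(acc)
#     _accumulatorRegistry.clear()     # simulate a fresh worker process
#     copy = pickle.loads(payload)
#     assert copy._value == 0          # the zero value, not the driver's 10
#     assert copy._deserialized        # so copy.value would raise on this "worker"
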

class Accumulator(object):

    """
    A shared variable that can be accumulated, i.e., has a commutative and associative "add"
    operation. Worker tasks on a Spark cluster can add values to an Accumulator with the `+=`
    operator, but only the driver program is allowed to access its value, using `value`.
    Updates from the workers get propagated automatically to the driver program.

    While :class:`SparkContext` supports accumulators for primitive data types like :class:`int` and
    :class:`float`, users can also define accumulators for custom types by providing a custom
    :class:`AccumulatorParam` object. Refer to the doctest of this module for an example.
    """

    def __init__(self, aid, value, accum_param):
        """Create a new Accumulator with a given initial value and AccumulatorParam object"""
        from pyspark.accumulators import _accumulatorRegistry
        self.aid = aid
        self.accum_param = accum_param
        self._value = value
        self._deserialized = False
        _accumulatorRegistry[aid] = self

    def __reduce__(self):
        """Custom serialization; saves the zero value from our AccumulatorParam"""
        param = self.accum_param
        return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))

    @property
    def value(self):
        """Get the accumulator's value; only usable in the driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        return self._value

    @value.setter
    def value(self, value):
        """Set the accumulator's value; only usable in the driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        self._value = value

    def add(self, term):
        """Adds a term to this accumulator's value"""
        self._value = self.accum_param.addInPlace(self._value, term)

    def __iadd__(self, term):
        """The += operator; adds a term to this accumulator's value"""
        self.add(term)
        return self

    def __str__(self):
        return str(self._value)

    def __repr__(self):
        return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)


class AccumulatorParam(object):

    """
    Helper object that defines how to accumulate values of a given type.
    """

    def zero(self, value):
        """
        Provide a "zero value" for the type, compatible in dimensions with the
        provided `value` (e.g., a zero vector)
        """
        raise NotImplementedError

    def addInPlace(self, value1, value2):
        """
        Add two values of the accumulator's data type, returning a new value;
        for efficiency, can also update `value1` in place and return it.
        """
        raise NotImplementedError

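# Hedged sketch (illustrative only, assuming an active SparkContext `sc`): a custom
# AccumulatorParam that merges Python sets, e.g. to collect the distinct keys seen
# across tasks:
#
#     class SetAccumulatorParam(AccumulatorParam):
#         def zero(self, value):
#             return set()             # start each copy from an empty set
#         def addInPlace(self, s1, s2):
#             s1 |= s2                 # in-place union, then return the left value
#             return s1
#
#     seen = sc.accumulator(set(), SetAccumulatorParam())
#     sc.parallelize(["a", "b", "a"]).foreach(lambda k: seen.add({k}))
#     # seen.value == {"a", "b"} on the driver once the action completes
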

class AddingAccumulatorParam(AccumulatorParam):

    """
    An AccumulatorParam that uses the + operator to add values. Designed for simple types
    such as integers, floats, and lists. Requires the zero value for the underlying type
    as a parameter.
    """

    def __init__(self, zero_value):
        self.zero_value = zero_value

    def zero(self, value):
        return self.zero_value

    def addInPlace(self, value1, value2):
        value1 += value2
        return value1


# Singleton accumulator params for some standard types
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
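
# For context (hedged; grounded in the module doctest above): SparkContext.accumulator
# falls back to these params when no AccumulatorParam is given, e.g.:
#
#     sc.accumulator(1)      # uses INT_ACCUMULATOR_PARAM
#     sc.accumulator(1.0)    # uses FLOAT_ACCUMULATOR_PARAM
#     sc.accumulator(1j)     # uses COMPLEX_ACCUMULATOR_PARAM
#
# Any other type (e.g. a list) needs an explicit AccumulatorParam, which is why the
# doctest's sc.accumulator([1.0, 2.0, 3.0]) with no param raises TypeError.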


class _UpdateRequestHandler(SocketServer.StreamRequestHandler):

    """
    This handler keeps polling for updates from the same socket until the
    server is shut down.
    """

    def handle(self):
        from pyspark.accumulators import _accumulatorRegistry
        auth_token = self.server.auth_token

        def poll(func):
            while not self.server.server_shutdown:
                # Poll every 1 second for new data -- don't block in case of shutdown.
                r, _, _ = select.select([self.rfile], [], [], 1)
                if self.rfile in r:
                    if func():
                        break

        def accum_updates():
            num_updates = read_int(self.rfile)
            for _ in range(num_updates):
                (aid, update) = pickleSer._read_with_length(self.rfile)
                _accumulatorRegistry[aid] += update
            # Write a byte in acknowledgement
            self.wfile.write(struct.pack("!b", 1))
            return False

        def authenticate_and_accum_updates():
            received_token = self.rfile.read(len(auth_token))
            if isinstance(received_token, bytes):
                received_token = received_token.decode("utf-8")
            if received_token == auth_token:
                accum_updates()
                # We've authenticated, so we can break out of the first loop now.
                return True
            else:
                raise Exception(
                    "The token provided to the AccumulatorServer did not match the "
                    "expected auth token.")

        # First, keep polling until we've received the authentication token.
        poll(authenticate_and_accum_updates)
        # Now that we've authenticated, we don't need to check for the token anymore.
        poll(accum_updates)

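# Hedged sketch of the wire protocol handled above (illustrative only; `port`,
# `aid`, and `update` are placeholders): a client first sends the auth token,
# then a count of updates, then length-prefixed pickled (aid, update) pairs,
# and finally reads the one-byte acknowledgement:
#
#     import socket
#     sock = socket.create_connection(("localhost", port))
#     sock.sendall(auth_token.encode("utf-8"))                 # authenticate first
#     sock.sendall(struct.pack("!i", 1))                       # one update follows
#     payload = pickleSer.dumps((aid, update))
#     sock.sendall(struct.pack("!i", len(payload)) + payload)  # length-prefixed pickle
#     ack = sock.recv(1)                                       # b"\x01" on success
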

class AccumulatorServer(SocketServer.TCPServer):

    """
    A simple TCP server that intercepts shutdown() in order to interrupt
    our continuous polling on the handler.
    """

    server_shutdown = False

    def __init__(self, server_address, RequestHandlerClass, auth_token):
        SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass)
        self.auth_token = auth_token

    def shutdown(self):
        self.server_shutdown = True
        SocketServer.TCPServer.shutdown(self)
        self.server_close()


def _start_update_server(auth_token):
    """Start a TCP server to receive accumulator updates in a daemon thread, and return it"""
    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token)
    thread = threading.Thread(target=server.serve_forever)
    thread.daemon = True
    thread.start()
    return server

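# Hedged usage sketch (illustrative only): the driver starts this server and passes
# its address to workers so they can connect and report updates:
#
#     server = _start_update_server("secret-token")
#     host, port = server.server_address   # workers send updates to this address
#     ...
#     server.shutdown()                    # interrupts the handler's polling loop
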

if __name__ == "__main__":
    import doctest
    (failure_count, test_count) = doctest.testmod()
    if failure_count:
        sys.exit(-1)