0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import gc
0019 import os
0020 import sys
0021 from tempfile import NamedTemporaryFile
0022 import threading
0023
0024 from pyspark.java_gateway import local_connect_and_auth
0025 from pyspark.serializers import ChunkedStream, pickle_protocol
0026 from pyspark.util import _exception_message, print_exec
0027
0028 if sys.version < '3':
0029 import cPickle as pickle
0030 else:
0031 import pickle
0032 unicode = str
0033
# Names exported by ``from pyspark.broadcast import *``.
__all__ = ["Broadcast"]


# Maps broadcast id -> Broadcast object; read by _from_id when a pickled
# broadcast reference is resolved. Nothing in this module writes to it --
# NOTE(review): presumably populated by the worker before unpickling tasks; confirm.
_broadcastRegistry = {}
0039
0040
def _from_id(bid):
    """Resolve a pickled broadcast reference back to its Broadcast object.

    This is the reconstructor that ``Broadcast.__reduce__`` returns, so it is
    what runs when a pickled broadcast is loaded.

    :param bid: the broadcast id recorded at pickling time.
    :raises Exception: if no broadcast with ``bid`` is in the registry.
    """
    # Go through the package path rather than the module global: when this
    # file runs as ``__main__`` (see the doctest entry point) its own global
    # would be a different dict from ``pyspark.broadcast._broadcastRegistry``.
    from pyspark.broadcast import _broadcastRegistry as registry
    if bid in registry:
        return registry[bid]
    raise Exception("Broadcast variable '%s' not loaded!" % bid)
0046
0047
class Broadcast(object):

    """
    A broadcast variable created with :meth:`SparkContext.broadcast`.
    Access its value through :attr:`value`.

    Examples:

    >>> from pyspark.context import SparkContext
    >>> sc = SparkContext('local', 'test')
    >>> b = sc.broadcast([1, 2, 3, 4, 5])
    >>> b.value
    [1, 2, 3, 4, 5]
    >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
    [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
    >>> b.unpersist()

    >>> large_broadcast = sc.broadcast(range(10000))
    """

    def __init__(self, sc=None, value=None, pickle_registry=None, path=None,
                 sock_file=None):
        """
        Should not be called directly by users -- use :meth:`SparkContext.broadcast`
        instead.

        Two modes, selected by ``sc``:

        * ``sc`` given (driver): ``value`` is pickled and handed to the JVM.
          It is either written straight into a temp file under
          ``sc._temp_dir``, or, when ``sc._encryption_enabled``, streamed to
          a JVM-side encryption server. ``pickle_registry`` is kept so that
          :meth:`__reduce__` can record each pickling.
        * ``sc`` is None: ``value`` is ignored. The data is read eagerly from
          ``sock_file`` if one is given. Otherwise it is unpickled lazily from
          ``path`` on first access to :attr:`value`.
        """
        if sc is not None:
            # Driver side: the pickled data must end up behind a file path the
            # JVM knows about.
            f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
            self._path = f.name
            self._sc = sc
            self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
            if sc._encryption_enabled:
                # Python never writes to the temp file in this mode; only its
                # path is passed to the JVM. Release our handle now instead of
                # leaving it open until garbage collection (delete=False, so
                # the file itself stays).
                f.close()
                # Send the pickle to the JVM over an authenticated local
                # socket, wrapped in a ChunkedStream (8192 is its chunk size
                # argument, presumably bytes).
                port, auth_secret = self._python_broadcast.setupEncryptionServer()
                (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
                broadcast_out = ChunkedStream(encryption_sock_file, 8192)
            else:
                # No encryption: pickle straight into the temp file.
                broadcast_out = f
            self.dump(value, broadcast_out)
            if sc._encryption_enabled:
                # Wait for the JVM side to finish receiving the stream.
                self._python_broadcast.waitTillDataReceived()
            self._jbroadcast = sc._jsc.broadcast(self._python_broadcast)
            self._pickle_registry = pickle_registry
        else:
            # Not on the driver: there is no JVM broadcast handle here, so
            # unpersist/destroy/pickling are refused (see those methods).
            self._jbroadcast = None
            self._sc = None
            self._python_broadcast = None
            if sock_file is not None:
                # The data is delivered through ``sock_file``; read it now.
                self._value = self.load(sock_file)
            else:
                # The data sits in a pickle file; ``value`` loads it on demand.
                assert(path is not None)
                self._path = path

    def dump(self, value, f):
        """Pickle ``value`` into the writable file-like ``f``, then close ``f``.

        ``f`` is closed whether or not serialization succeeds. This way a
        failed broadcast does not leak the temp-file handle or socket stream.

        :raises pickle.PickleError: re-raised unchanged if pickling raises one.
        :raises pickle.PicklingError: wrapping any other exception raised
            while pickling. The traceback is printed to stderr first.
        """
        try:
            pickle.dump(value, f, pickle_protocol)
        except pickle.PickleError:
            raise
        except Exception as e:
            msg = "Could not serialize broadcast: %s: %s" \
                % (e.__class__.__name__, _exception_message(e))
            print_exec(sys.stderr)
            raise pickle.PicklingError(msg)
        finally:
            f.close()

    def load_from_path(self, path):
        """Unpickle and return the object stored in the file at ``path``."""
        # 1 MiB read buffer for potentially large broadcast files.
        with open(path, 'rb', 1 << 20) as f:
            return self.load(f)

    def load(self, file):
        """Unpickle and return one object from the readable file-like ``file``.

        The cyclic garbage collector is paused while unpickling, presumably
        to avoid repeated collections triggered by the many allocations a
        large unpickle makes. Afterwards it is restored to whatever state it
        was in before the call, so a caller that had already disabled GC is
        not surprised by it being switched back on.
        """
        gc_was_enabled = gc.isenabled()
        gc.disable()
        try:
            return pickle.load(file)
        finally:
            if gc_was_enabled:
                gc.enable()

    @property
    def value(self):
        """ Return the broadcasted value
        """
        if not hasattr(self, "_value") and self._path is not None:
            # On the driver with encryption enabled, the data was handed to
            # the JVM through an encryption server (see __init__). Read it
            # back through a decryption server rather than from _path.
            # NOTE: the result is not cached on this branch, so every access
            # repeats the round trip.
            if self._sc is not None and self._sc._encryption_enabled:
                port, auth_secret = self._python_broadcast.setupDecryptionServer()
                (decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret)
                self._python_broadcast.waitTillBroadcastDataSent()
                try:
                    return self.load(decrypted_sock_file)
                finally:
                    # Release the socket file instead of waiting for GC.
                    decrypted_sock_file.close()
            else:
                self._value = self.load_from_path(self._path)
        return self._value

    def unpersist(self, blocking=False):
        """
        Delete cached copies of this broadcast on the executors. If the
        broadcast is used after this is called, it will need to be
        re-sent to each executor.

        :param blocking: Whether to block until unpersisting has completed
        """
        if self._jbroadcast is None:
            raise Exception("Broadcast can only be unpersisted in driver")
        self._jbroadcast.unpersist(blocking)

    def destroy(self, blocking=False):
        """
        Destroy all data and metadata related to this broadcast variable.
        Use this with caution; once a broadcast variable has been destroyed,
        it cannot be used again.

        .. versionchanged:: 3.0.0
           Added optional argument `blocking` to specify whether to block until all
           blocks are deleted.
        """
        if self._jbroadcast is None:
            raise Exception("Broadcast can only be destroyed in driver")
        self._jbroadcast.destroy(blocking)
        # Also remove the driver-side temp file created in __init__.
        os.unlink(self._path)

    def __reduce__(self):
        """Pickle this broadcast as a reference rather than by value.

        The object is recorded in ``self._pickle_registry``, and only its JVM
        broadcast id is serialized. That id is resolved by ``_from_id`` when
        the pickle is loaded.

        :raises Exception: when called off the driver (no JVM handle).
        """
        if self._jbroadcast is None:
            raise Exception("Broadcast can only be serialized in driver")
        self._pickle_registry.add(self)
        return _from_id, (self._jbroadcast.id(),)
0181
0182
class BroadcastPickleRegistry(threading.local):
    """Per-thread set of Broadcast objects that have been pickled.

    Because this subclasses ``threading.local``, every thread that touches an
    instance gets its own independent set.
    """

    def __init__(self):
        # threading.local re-runs __init__ the first time each new thread uses
        # the instance, so only create the set when this thread lacks one.
        if "_registry" not in self.__dict__:
            self.__dict__["_registry"] = set()

    def __iter__(self):
        """Yield each broadcast recorded on the current thread."""
        registry = self._registry
        for entry in registry:
            yield entry

    def add(self, bcast):
        """Record ``bcast`` for the current thread."""
        self._registry.add(bcast)

    def clear(self):
        """Forget every broadcast recorded on the current thread."""
        self._registry.clear()
0199
0200
if __name__ == "__main__":
    # Run this module's doctests and report any failure via the exit status.
    import doctest
    failures, _attempted = doctest.testmod()
    if failures:
        sys.exit(-1)