"""
PySpark supports custom serializers for transferring data; this can improve
performance.

By default, PySpark uses :class:`PickleSerializer` to serialize objects using Python's
`cPickle` serializer, which can serialize nearly any Python object.
Other serializers, like :class:`MarshalSerializer`, support fewer datatypes but can be
faster.

The serializer is chosen when creating :class:`SparkContext`:

>>> from pyspark.context import SparkContext
>>> from pyspark.serializers import MarshalSerializer
>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
>>> sc.stop()

PySpark serializes objects in batches; by default, the batch size is chosen based
on the size of objects and is also configurable by SparkContext's `batchSize`
parameter:

>>> sc = SparkContext('local', 'test', batchSize=2)
>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)

Behind the scenes, this creates a JavaRDD with four partitions, each of
which contains two batches of two objects:

>>> rdd.glom().collect()
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
>>> int(rdd._jrdd.count())
8
>>> sc.stop()
"""

import sys
from itertools import chain, product
import marshal
import struct
import types
import collections
import zlib
import itertools

if sys.version < '3':
    import cPickle as pickle
    from itertools import izip as zip, imap as map
else:
    import pickle
    basestring = unicode = str
    xrange = range
pickle_protocol = pickle.HIGHEST_PROTOCOL

from pyspark import cloudpickle
from pyspark.util import _exception_message, print_exec


__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"]


class SpecialLengths(object):
    END_OF_DATA_SECTION = -1
    PYTHON_EXCEPTION_THROWN = -2
    TIMING_DATA = -3
    END_OF_STREAM = -4
    NULL = -5
    START_ARROW_STREAM = -6


class Serializer(object):

    def dump_stream(self, iterator, stream):
        """
        Serialize an iterator of objects to the output stream.
        """
        raise NotImplementedError

    def load_stream(self, stream):
        """
        Return an iterator of deserialized objects from the input stream.
        """
        raise NotImplementedError

    def _load_stream_without_unbatching(self, stream):
        """
        Return an iterator of deserialized batches (iterable) of objects from the input stream.
        If the serializer does not operate on batches the default implementation returns an
        iterator of single element lists.
        """
        return map(lambda x: [x], self.load_stream(stream))

    # Note: our notion of "equality" is that output generated by
    # equal serializers can be deserialized using the same serializer.

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return "%s()" % self.__class__.__name__

    def __hash__(self):
        return hash(str(self))


class FramedSerializer(Serializer):

    """
    Serializer that writes objects as a stream of (length, data) pairs,
    where `length` is a 32-bit integer and `data` is `length` bytes.
    """

    def __init__(self):
        # On Python 2.6, we can't write bytearrays to streams, so we need to
        # convert them to strings first.
        self._only_write_strings = sys.version_info[0:2] <= (2, 6)

    def dump_stream(self, iterator, stream):
        for obj in iterator:
            self._write_with_length(obj, stream)

    def load_stream(self, stream):
        while True:
            try:
                yield self._read_with_length(stream)
            except EOFError:
                return

    def _write_with_length(self, obj, stream):
        serialized = self.dumps(obj)
        if serialized is None:
            raise ValueError("serialized value should not be None")
        if len(serialized) > (1 << 31):
            raise ValueError("can not serialize object larger than 2G")
        write_int(len(serialized), stream)
        if self._only_write_strings:
            stream.write(str(serialized))
        else:
            stream.write(serialized)

    def _read_with_length(self, stream):
        length = read_int(stream)
        if length == SpecialLengths.END_OF_DATA_SECTION:
            raise EOFError
        elif length == SpecialLengths.NULL:
            return None
        obj = stream.read(length)
        if len(obj) < length:
            raise EOFError
        return self.loads(obj)

    def dumps(self, obj):
        """
        Serialize an object into a byte array.
        When batching is used, this will be called with an array of objects.
        """
        raise NotImplementedError

    def loads(self, obj):
        """
        Deserialize an object from a byte array.
        """
        raise NotImplementedError

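
# Illustrative sketch (not part of Spark's API): each framed record is a
# big-endian 32-bit length followed by that many payload bytes, which is how
# the reader recovers record boundaries from a raw byte stream.
def _example_framed_format():
    import io

    class _RawBytesSerializer(FramedSerializer):
        # trivial concrete subclass: the objects are already byte strings
        def dumps(self, obj):
            return obj

        def loads(self, obj):
            return obj

    buf = io.BytesIO()
    _RawBytesSerializer().dump_stream([b"spark", b"pyspark"], buf)
    # buf now holds: \x00\x00\x00\x05 spark \x00\x00\x00\x07 pyspark
    buf.seek(0)
    return list(_RawBytesSerializer().load_stream(buf))
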

class BatchedSerializer(Serializer):

    """
    Serializes a stream of objects in batches by calling its wrapped
    Serializer with streams of objects.
    """

    UNLIMITED_BATCH_SIZE = -1
    UNKNOWN_BATCH_SIZE = 0

    def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
        self.serializer = serializer
        self.batchSize = batchSize

    def _batched(self, iterator):
        if self.batchSize == self.UNLIMITED_BATCH_SIZE:
            yield list(iterator)
        elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"):
            n = len(iterator)
            for i in xrange(0, n, self.batchSize):
                yield iterator[i: i + self.batchSize]
        else:
            items = []
            count = 0
            for item in iterator:
                items.append(item)
                count += 1
                if count == self.batchSize:
                    yield items
                    items = []
                    count = 0
            if items:
                yield items

    def dump_stream(self, iterator, stream):
        self.serializer.dump_stream(self._batched(iterator), stream)

    def load_stream(self, stream):
        return chain.from_iterable(self._load_stream_without_unbatching(stream))

    def _load_stream_without_unbatching(self, stream):
        return self.serializer.load_stream(stream)

    def __repr__(self):
        return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize)

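
# Illustrative sketch (not part of Spark's API): BatchedSerializer groups
# incoming objects into lists and hands each list to the wrapped serializer,
# so a single framed record on the stream carries a whole batch of objects.
def _example_batched_round_trip():
    import io
    ser = BatchedSerializer(PickleSerializer(), batchSize=3)
    buf = io.BytesIO()
    ser.dump_stream(range(7), buf)  # written as batches [0, 1, 2], [3, 4, 5], [6]
    buf.seek(0)
    assert list(ser.load_stream(buf)) == list(range(7))
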

class FlattenedValuesSerializer(BatchedSerializer):

    """
    Serializes a stream of (key, values) pairs, splitting value lists that
    contain more than a certain number of objects so that the resulting
    batches have similar sizes.
    """
    def __init__(self, serializer, batchSize=10):
        BatchedSerializer.__init__(self, serializer, batchSize)

    def _batched(self, iterator):
        n = self.batchSize
        for key, values in iterator:
            for i in range(0, len(values), n):
                yield key, values[i:i + n]

    def load_stream(self, stream):
        return self.serializer.load_stream(stream)

    def __repr__(self):
        return "FlattenedValuesSerializer(%s, %d)" % (self.serializer, self.batchSize)


class AutoBatchedSerializer(BatchedSerializer):
    """
    Choose the batch size automatically based on the size of the serialized objects.
    """

    def __init__(self, serializer, bestSize=1 << 16):
        BatchedSerializer.__init__(self, serializer, self.UNKNOWN_BATCH_SIZE)
        self.bestSize = bestSize

    def dump_stream(self, iterator, stream):
        batch, best = 1, self.bestSize
        iterator = iter(iterator)
        while True:
            vs = list(itertools.islice(iterator, batch))
            if not vs:
                break

            bytes = self.serializer.dumps(vs)
            write_int(len(bytes), stream)
            stream.write(bytes)

            size = len(bytes)
            if size < best:
                batch *= 2
            elif size > best * 10 and batch > 1:
                batch //= 2

    def __repr__(self):
        return "AutoBatchedSerializer(%s)" % self.serializer

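
# Illustrative sketch (not part of Spark's API): AutoBatchedSerializer starts
# with one object per batch and doubles the batch size while serialized batches
# stay under `bestSize` bytes, shrinking again if a batch grows far too large.
def _example_auto_batched():
    import io
    ser = AutoBatchedSerializer(PickleSerializer(), bestSize=1 << 16)
    buf = io.BytesIO()
    ser.dump_stream(({"id": i} for i in range(1000)), buf)
    buf.seek(0)
    assert len(list(ser.load_stream(buf))) == 1000
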

class CartesianDeserializer(Serializer):

    """
    Deserializes the JavaRDD cartesian() of two PythonRDDs.
    Due to pyspark batching we cannot simply use the result of the Java RDD
    cartesian; we additionally need to do the cartesian within each pair of batches.
    """

    def __init__(self, key_ser, val_ser):
        self.key_ser = key_ser
        self.val_ser = val_ser

    def _load_stream_without_unbatching(self, stream):
        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
            # for correctness with repeated cartesian/zip this must be returned as one batch
            yield product(key_batch, val_batch)

    def load_stream(self, stream):
        return chain.from_iterable(self._load_stream_without_unbatching(stream))

    def __repr__(self):
        return "CartesianDeserializer(%s, %s)" % \
               (str(self.key_ser), str(self.val_ser))


class PairDeserializer(Serializer):

    """
    Deserializes the JavaRDD zip() of two PythonRDDs.
    Due to pyspark batching we cannot simply use the result of the Java RDD
    zip; we additionally need to do the zip within each pair of batches.
    """

    def __init__(self, key_ser, val_ser):
        self.key_ser = key_ser
        self.val_ser = val_ser

    def _load_stream_without_unbatching(self, stream):
        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
            # For double-zipped RDDs, the batches can be iterators from another
            # PairDeserializer instead of lists, so convert them to lists if needed.
            key_batch = key_batch if hasattr(key_batch, '__len__') else list(key_batch)
            val_batch = val_batch if hasattr(val_batch, '__len__') else list(val_batch)
            if len(key_batch) != len(val_batch):
                raise ValueError("Can not deserialize PairRDD with different number of items"
                                 " in batches: (%d, %d)" % (len(key_batch), len(val_batch)))
            # for correctness with repeated cartesian/zip this must be returned as one batch
            yield zip(key_batch, val_batch)

    def load_stream(self, stream):
        return chain.from_iterable(self._load_stream_without_unbatching(stream))

    def __repr__(self):
        return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser))

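
# Illustrative sketch (not part of Spark's API): a zipped pair RDD arrives as
# alternating key batches and value batches on the same stream, and
# PairDeserializer zips each key batch with the matching value batch, which is
# why the two batches in every pair must have the same number of items.
def _example_pair_deserializer():
    import io
    key_ser = BatchedSerializer(PickleSerializer(), 2)
    val_ser = BatchedSerializer(PickleSerializer(), 2)
    buf = io.BytesIO()
    # interleave one key batch and one value batch, as the JVM side would
    key_ser.serializer.dump_stream([[1, 2]], buf)
    val_ser.serializer.dump_stream([["a", "b"]], buf)
    buf.seek(0)
    assert list(PairDeserializer(key_ser, val_ser).load_stream(buf)) == [(1, "a"), (2, "b")]
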

class NoOpSerializer(FramedSerializer):

    def loads(self, obj):
        return obj

    def dumps(self, obj):
        return obj


# Hack namedtuple, make it picklable

__cls = {}


def _restore(name, fields, value):
    """ Restore an object of namedtuple"""
    k = (name, fields)
    cls = __cls.get(k)
    if cls is None:
        cls = collections.namedtuple(name, fields)
        __cls[k] = cls
    return cls(*value)


def _hack_namedtuple(cls):
    """ Make class generated by namedtuple picklable """
    name = cls.__name__
    fields = cls._fields

    def __reduce__(self):
        return (_restore, (name, fields, tuple(self)))
    cls.__reduce__ = __reduce__
    cls._is_namedtuple_ = True
    return cls


def _hijack_namedtuple():
    """ Hack namedtuple() to make it picklable """
    # hijack only once
    if hasattr(collections.namedtuple, "__hijack"):
        return

    global _old_namedtuple
    global _old_namedtuple_kwdefaults

    def _copy_func(f):
        return types.FunctionType(f.__code__, f.__globals__, f.__name__,
                                  f.__defaults__, f.__closure__)

    def _kwdefaults(f):
        # `__kwdefaults__` holds the default values of keyword-only arguments
        # (Python 3 only). _copy_func does not carry them over, so capture them
        # here and re-apply them inside the replacement namedtuple() below.
        kargs = getattr(f, "__kwdefaults__", None)
        if kargs is None:
            return {}
        else:
            return kargs

    _old_namedtuple = _copy_func(collections.namedtuple)
    _old_namedtuple_kwdefaults = _kwdefaults(collections.namedtuple)

    def namedtuple(*args, **kwargs):
        for k, v in _old_namedtuple_kwdefaults.items():
            kwargs[k] = kwargs.get(k, v)
        cls = _old_namedtuple(*args, **kwargs)
        return _hack_namedtuple(cls)

    # replace namedtuple with the hacked version
    collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults
    collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple
    collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple
    collections.namedtuple.__code__ = namedtuple.__code__
    collections.namedtuple.__hijack = 1

    # hack the classes already generated by namedtuple.
    # Those created in other modules can be pickled as normal,
    # so only hack those defined in the main module.
    for n, o in sys.modules["__main__"].__dict__.items():
        if (type(o) is type and o.__base__ is tuple
                and hasattr(o, "_fields")
                and "__reduce__" not in o.__dict__):
            _hack_namedtuple(o)  # hack in place


_hijack_namedtuple()

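
# Illustrative sketch (not part of Spark's API): after _hijack_namedtuple(),
# namedtuple classes defined in __main__ (e.g. in the REPL) can be pickled and
# rebuilt on executors via _restore instead of failing with an import error.
def _example_namedtuple_pickling():
    Point = collections.namedtuple("Point", ["x", "y"])
    data = pickle.dumps(Point(1, 2), pickle_protocol)
    assert pickle.loads(data) == Point(1, 2)
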

class PickleSerializer(FramedSerializer):

    """
    Serializes objects using Python's pickle serializer:

        http://docs.python.org/2/library/pickle.html

    This serializer supports nearly any Python object, but may
    not be as fast as more specialized serializers.
    """

    def dumps(self, obj):
        return pickle.dumps(obj, pickle_protocol)

    if sys.version >= '3':
        def loads(self, obj, encoding="bytes"):
            return pickle.loads(obj, encoding=encoding)
    else:
        def loads(self, obj, encoding=None):
            return pickle.loads(obj)


class CloudPickleSerializer(PickleSerializer):

    def dumps(self, obj):
        try:
            return cloudpickle.dumps(obj, pickle_protocol)
        except pickle.PickleError:
            raise
        except Exception as e:
            emsg = _exception_message(e)
            if "'i' format requires" in emsg:
                msg = "Object too large to serialize: %s" % emsg
            else:
                msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
            print_exec(sys.stderr)
            raise pickle.PicklingError(msg)


class MarshalSerializer(FramedSerializer):

    """
    Serializes objects using Python's Marshal serializer:

        http://docs.python.org/2/library/marshal.html

    This serializer is faster than PickleSerializer but supports fewer datatypes.
    """

    def dumps(self, obj):
        return marshal.dumps(obj)

    def loads(self, obj):
        return marshal.loads(obj)

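
# Illustrative sketch (not part of Spark's API): both serializers round-trip
# plain built-in Python values; marshal is typically faster, but anything
# beyond built-in types (e.g. a user-defined class) needs PickleSerializer.
def _example_marshal_vs_pickle():
    value = {"key": [1, 2.5, u"text", (True, None)]}
    for ser in (MarshalSerializer(), PickleSerializer()):
        assert ser.loads(ser.dumps(value)) == value
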

class AutoSerializer(FramedSerializer):

    """
    Choose marshal or pickle as serialization protocol automatically
    """

    def __init__(self):
        FramedSerializer.__init__(self)
        self._type = None

    def dumps(self, obj):
        if self._type is not None:
            return b'P' + pickle.dumps(obj, -1)
        try:
            return b'M' + marshal.dumps(obj)
        except Exception:
            self._type = b'P'
            return b'P' + pickle.dumps(obj, -1)

    def loads(self, obj):
        _type = obj[0]
        if _type == b'M':
            return marshal.loads(obj[1:])
        elif _type == b'P':
            return pickle.loads(obj[1:])
        else:
            raise ValueError("invalid serialization type: %s" % _type)


class CompressedSerializer(FramedSerializer):
    """
    Compress the serialized data
    """
    def __init__(self, serializer):
        FramedSerializer.__init__(self)
        assert isinstance(serializer, FramedSerializer), "serializer must be a FramedSerializer"
        self.serializer = serializer

    def dumps(self, obj):
        return zlib.compress(self.serializer.dumps(obj), 1)

    def loads(self, obj):
        return self.serializer.loads(zlib.decompress(obj))

    def __repr__(self):
        return "CompressedSerializer(%s)" % self.serializer

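
# Illustrative sketch (not part of Spark's API): CompressedSerializer simply
# zlib-compresses every framed payload produced by the wrapped serializer,
# which helps for large, repetitive records at the cost of some CPU.
def _example_compressed_round_trip():
    import io
    ser = CompressedSerializer(PickleSerializer())
    buf = io.BytesIO()
    ser.dump_stream(["x" * 10000, "y" * 10000], buf)
    buf.seek(0)
    assert list(ser.load_stream(buf)) == ["x" * 10000, "y" * 10000]
    assert len(buf.getvalue()) < 20000  # far smaller than the raw strings
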

class UTF8Deserializer(Serializer):

    """
    Deserializes streams written by String.getBytes.
    """

    def __init__(self, use_unicode=True):
        self.use_unicode = use_unicode

    def loads(self, stream):
        length = read_int(stream)
        if length == SpecialLengths.END_OF_DATA_SECTION:
            raise EOFError
        elif length == SpecialLengths.NULL:
            return None
        s = stream.read(length)
        return s.decode("utf-8") if self.use_unicode else s

    def load_stream(self, stream):
        try:
            while True:
                yield self.loads(stream)
        except struct.error:
            return
        except EOFError:
            return

    def __repr__(self):
        return "UTF8Deserializer(%s)" % self.use_unicode

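
# Illustrative sketch (not part of Spark's API): each string arrives as a
# 32-bit byte length followed by its UTF-8 bytes; UTF8Deserializer turns that
# framing back into unicode strings (or raw bytes with use_unicode=False).
def _example_utf8_deserializer():
    import io
    buf = io.BytesIO()
    for s in (u"spark", u"caf\u00e9"):
        write_with_length(s.encode("utf-8"), buf)
    buf.seek(0)
    assert list(UTF8Deserializer().load_stream(buf)) == [u"spark", u"caf\u00e9"]
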

def read_long(stream):
    length = stream.read(8)
    if not length:
        raise EOFError
    return struct.unpack("!q", length)[0]


def write_long(value, stream):
    stream.write(struct.pack("!q", value))


def pack_long(value):
    return struct.pack("!q", value)


def read_int(stream):
    length = stream.read(4)
    if not length:
        raise EOFError
    return struct.unpack("!i", length)[0]


def write_int(value, stream):
    stream.write(struct.pack("!i", value))


def read_bool(stream):
    length = stream.read(1)
    if not length:
        raise EOFError
    return struct.unpack("!?", length)[0]


def write_with_length(obj, stream):
    write_int(len(obj), stream)
    stream.write(obj)

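
# Illustrative sketch (not part of Spark's API): the helpers above implement
# the byte-level protocol shared with the JVM: big-endian ("network order")
# 4-byte ints and 8-byte longs, plus length-prefixed byte strings.
def _example_int_long_framing():
    import io
    buf = io.BytesIO()
    write_int(7, buf)
    write_long(1 << 40, buf)
    buf.seek(0)
    assert read_int(buf) == 7
    assert read_long(buf) == 1 << 40
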

class ChunkedStream(object):

    """
    This is a file-like object that takes a stream of data, of unknown length, and breaks it
    into fixed-length frames. The intended use case is serializing large data and sending it
    immediately over a socket -- we do not want to buffer the entire data before sending it,
    but the receiving end needs to know whether or not there is more data coming.

    It works by buffering the incoming data in some fixed-size chunks. If the buffer is full, it
    first sends the buffer size, then the data. This repeats as long as there is more data to
    send. When this is closed, it sends the length of whatever data is in the buffer, then that
    data, and finally a "length" of -1 to indicate the stream has completed.
    """

    def __init__(self, wrapped, buffer_size):
        self.buffer_size = buffer_size
        self.buffer = bytearray(buffer_size)
        self.current_pos = 0
        self.wrapped = wrapped

    def write(self, bytes):
        byte_pos = 0
        byte_remaining = len(bytes)
        while byte_remaining > 0:
            new_pos = byte_remaining + self.current_pos
            if new_pos < self.buffer_size:
                # just put it in our buffer
                self.buffer[self.current_pos:new_pos] = bytes[byte_pos:]
                self.current_pos = new_pos
                byte_remaining = 0
            else:
                # fill the buffer, send the length then the contents, and start filling again
                space_left = self.buffer_size - self.current_pos
                new_byte_pos = byte_pos + space_left
                self.buffer[self.current_pos:self.buffer_size] = bytes[byte_pos:new_byte_pos]
                write_int(self.buffer_size, self.wrapped)
                self.wrapped.write(self.buffer)
                byte_remaining -= space_left
                byte_pos = new_byte_pos
                self.current_pos = 0

    def close(self):
        # if there is anything in the buffer, write it out first
        if self.current_pos > 0:
            write_int(self.current_pos, self.wrapped)
            self.wrapped.write(self.buffer[:self.current_pos])
        # -1 length indicates to the receiving end that we're done.
        write_int(-1, self.wrapped)
        self.wrapped.close()

    @property
    def closed(self):
        """
        Return True if the `wrapped` object has been closed.
        NOTE: this property is required by pyarrow to be used as a file-like object in
        pyarrow.RecordBatchStreamWriter from ArrowStreamSerializer
        """
        return self.wrapped.closed

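
# Illustrative sketch (not part of Spark's API): ChunkedStream emits
# (length, chunk) frames as its internal buffer fills, plus a final -1 length
# on close so the reader knows a stream of unknown total size has ended.
def _example_chunked_stream():
    import io
    target = io.BytesIO()
    chunked = ChunkedStream(target, buffer_size=4)
    chunked.write(b"abcdefghij")
    # two full 4-byte frames were flushed; b"ij" stays buffered until close()
    assert target.getvalue() == (struct.pack("!i", 4) + b"abcd" +
                                 struct.pack("!i", 4) + b"efgh")
    chunked.close()  # flushes (2, b"ij") and then the -1 end-of-stream marker
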

if __name__ == '__main__':
    import doctest
    (failure_count, test_count) = doctest.testmod()
    if failure_count:
        sys.exit(-1)