#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
PySpark supports custom serializers for transferring data; this can improve
performance.

By default, PySpark uses :class:`PickleSerializer` to serialize objects using Python's
`cPickle` serializer, which can serialize nearly any Python object.
Other serializers, like :class:`MarshalSerializer`, support fewer datatypes but can be
faster.

The serializer is chosen when creating :class:`SparkContext`:

>>> from pyspark.context import SparkContext
>>> from pyspark.serializers import MarshalSerializer
>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
>>> sc.stop()

PySpark serializes objects in batches; by default, the batch size is chosen based
on the size of objects and is also configurable by SparkContext's `batchSize`
parameter:

>>> sc = SparkContext('local', 'test', batchSize=2)
>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)

Behind the scenes, this creates a JavaRDD with four partitions, each of
which contains two batches of two objects:

>>> rdd.glom().collect()
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
>>> int(rdd._jrdd.count())
8
>>> sc.stop()
"""

import sys
from itertools import chain, product
import marshal
import struct
import types
import collections
import zlib
import itertools

if sys.version < '3':
    import cPickle as pickle
    from itertools import izip as zip, imap as map
else:
    import pickle
    basestring = unicode = str
    xrange = range
pickle_protocol = pickle.HIGHEST_PROTOCOL

from pyspark import cloudpickle
from pyspark.util import _exception_message, print_exec


__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"]


class SpecialLengths(object):
    END_OF_DATA_SECTION = -1
    PYTHON_EXCEPTION_THROWN = -2
    TIMING_DATA = -3
    END_OF_STREAM = -4
    NULL = -5
    START_ARROW_STREAM = -6


class Serializer(object):

    def dump_stream(self, iterator, stream):
        """
        Serialize an iterator of objects to the output stream.
        """
        raise NotImplementedError

    def load_stream(self, stream):
        """
        Return an iterator of deserialized objects from the input stream.
        """
        raise NotImplementedError

    def _load_stream_without_unbatching(self, stream):
        """
        Return an iterator of deserialized batches (iterable) of objects from the input stream.
        If the serializer does not operate on batches the default implementation returns an
        iterator of single element lists.
        """
        return map(lambda x: [x], self.load_stream(stream))

    # Note: our notion of "equality" is that output generated by
    # equal serializers can be deserialized using the same serializer.

    # This default implementation handles the simple cases;
    # subclasses should override __eq__ as appropriate.

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return "%s()" % self.__class__.__name__

    def __hash__(self):
        return hash(str(self))


class FramedSerializer(Serializer):

    """
    Serializer that writes objects as a stream of (length, data) pairs,
    where `length` is a 32-bit integer and data is `length` bytes.
    """

    def __init__(self):
        # On Python 2.6, we can't write bytearrays to streams, so we need to convert them
        # to strings first. Check if the version number is that old.
        self._only_write_strings = sys.version_info[0:2] <= (2, 6)

    def dump_stream(self, iterator, stream):
        for obj in iterator:
            self._write_with_length(obj, stream)

    def load_stream(self, stream):
        while True:
            try:
                yield self._read_with_length(stream)
            except EOFError:
                return

    def _write_with_length(self, obj, stream):
        serialized = self.dumps(obj)
        if serialized is None:
            raise ValueError("serialized value should not be None")
        if len(serialized) > (1 << 31):
            raise ValueError("can not serialize object larger than 2G")
        write_int(len(serialized), stream)
        if self._only_write_strings:
            stream.write(str(serialized))
        else:
            stream.write(serialized)

    def _read_with_length(self, stream):
        length = read_int(stream)
        if length == SpecialLengths.END_OF_DATA_SECTION:
            raise EOFError
        elif length == SpecialLengths.NULL:
            return None
        obj = stream.read(length)
        if len(obj) < length:
            raise EOFError
        return self.loads(obj)

    def dumps(self, obj):
        """
        Serialize an object into a byte array.
        When batching is used, this will be called with an array of objects.
        """
        raise NotImplementedError

    def loads(self, obj):
        """
        Deserialize an object from a byte array.
        """
        raise NotImplementedError


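# Illustrative sketch only (not part of the serializer API): a FramedSerializer
# subclass writes each record as a 4-byte big-endian length followed by the
# serialized payload, so a round trip through an in-memory stream (io.BytesIO and
# MarshalSerializer are chosen here just as examples) looks like:
#
#   import io
#   ser = MarshalSerializer()
#   buf = io.BytesIO()
#   ser.dump_stream(["a", "bb"], buf)             # two (length, data) frames
#   buf.seek(0)
#   assert list(ser.load_stream(buf)) == ["a", "bb"]

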
class BatchedSerializer(Serializer):

    """
    Serializes a stream of objects in batches by calling its wrapped
    Serializer with streams of objects.
    """

    UNLIMITED_BATCH_SIZE = -1
    UNKNOWN_BATCH_SIZE = 0

    def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
        self.serializer = serializer
        self.batchSize = batchSize

    def _batched(self, iterator):
        if self.batchSize == self.UNLIMITED_BATCH_SIZE:
            yield list(iterator)
        elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"):
            n = len(iterator)
            for i in xrange(0, n, self.batchSize):
                yield iterator[i: i + self.batchSize]
        else:
            items = []
            count = 0
            for item in iterator:
                items.append(item)
                count += 1
                if count == self.batchSize:
                    yield items
                    items = []
                    count = 0
            if items:
                yield items

    def dump_stream(self, iterator, stream):
        self.serializer.dump_stream(self._batched(iterator), stream)

    def load_stream(self, stream):
        return chain.from_iterable(self._load_stream_without_unbatching(stream))

    def _load_stream_without_unbatching(self, stream):
        return self.serializer.load_stream(stream)

    def __repr__(self):
        return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize)


class FlattenedValuesSerializer(BatchedSerializer):

    """
    Serializes a stream of (key, values) pairs, splitting lists of values that
    contain more than a certain number of objects into chunks so that the
    serialized batches have similar sizes.
    """
    def __init__(self, serializer, batchSize=10):
        BatchedSerializer.__init__(self, serializer, batchSize)

    def _batched(self, iterator):
        n = self.batchSize
        for key, values in iterator:
            for i in range(0, len(values), n):
                yield key, values[i:i + n]

    def load_stream(self, stream):
        return self.serializer.load_stream(stream)

    def __repr__(self):
        return "FlattenedValuesSerializer(%s, %d)" % (self.serializer, self.batchSize)


class AutoBatchedSerializer(BatchedSerializer):
    """
    Chooses the batch size automatically based on the size of the serialized objects
    """

    def __init__(self, serializer, bestSize=1 << 16):
        BatchedSerializer.__init__(self, serializer, self.UNKNOWN_BATCH_SIZE)
        self.bestSize = bestSize

    def dump_stream(self, iterator, stream):
        batch, best = 1, self.bestSize
        iterator = iter(iterator)
        while True:
            vs = list(itertools.islice(iterator, batch))
            if not vs:
                break

            bytes = self.serializer.dumps(vs)
            write_int(len(bytes), stream)
            stream.write(bytes)

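            # Adapt the batch size: grow it while serialized frames stay under bestSize,
            # and shrink it when a frame overshoots bestSize by more than 10x, so frame
            # sizes settle around bestSize (64 KiB by default).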
            size = len(bytes)
            if size < best:
                batch *= 2
            elif size > best * 10 and batch > 1:
                batch //= 2

    def __repr__(self):
        return "AutoBatchedSerializer(%s)" % self.serializer


class CartesianDeserializer(Serializer):

    """
    Deserializes the JavaRDD cartesian() of two PythonRDDs.
    Due to pyspark batching we cannot simply use the result of the Java RDD cartesian;
    we additionally need to do the cartesian within each pair of batches.
    """

    def __init__(self, key_ser, val_ser):
        self.key_ser = key_ser
        self.val_ser = val_ser

    def _load_stream_without_unbatching(self, stream):
        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
            # for correctness with repeated cartesian/zip this must be returned as one batch
            yield product(key_batch, val_batch)

    def load_stream(self, stream):
        return chain.from_iterable(self._load_stream_without_unbatching(stream))

    def __repr__(self):
        return "CartesianDeserializer(%s, %s)" % \
               (str(self.key_ser), str(self.val_ser))


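# Illustrative sketch of the per-batch cartesian above (plain Python lists stand in
# for the deserialized key/value batches; the names are illustrative only):
#
#   from itertools import chain, product
#   key_batches = [[1, 2], [3, 4]]
#   val_batches = [["a"], ["b", "c"]]
#   pairs = chain.from_iterable(
#       product(k_batch, v_batch)
#       for k_batch, v_batch in zip(key_batches, val_batches))
#   # -> (1, 'a'), (2, 'a'), (3, 'b'), (3, 'c'), (4, 'b'), (4, 'c')

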
class PairDeserializer(Serializer):

    """
    Deserializes the JavaRDD zip() of two PythonRDDs.
    Due to pyspark batching we cannot simply use the result of the Java RDD zip;
    we additionally need to do the zip within each pair of batches.
    """

    def __init__(self, key_ser, val_ser):
        self.key_ser = key_ser
        self.val_ser = val_ser

    def _load_stream_without_unbatching(self, stream):
        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
            # For double-zipped RDDs, the batches can be iterators from another PairDeserializer,
            # instead of lists. We need to convert them to lists if needed.
            key_batch = key_batch if hasattr(key_batch, '__len__') else list(key_batch)
            val_batch = val_batch if hasattr(val_batch, '__len__') else list(val_batch)
            if len(key_batch) != len(val_batch):
                raise ValueError("Can not deserialize PairRDD with different number of items"
                                 " in batches: (%d, %d)" % (len(key_batch), len(val_batch)))
            # for correctness with repeated cartesian/zip this must be returned as one batch
            yield zip(key_batch, val_batch)

    def load_stream(self, stream):
        return chain.from_iterable(self._load_stream_without_unbatching(stream))

    def __repr__(self):
        return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser))


class NoOpSerializer(FramedSerializer):

    def loads(self, obj):
        return obj

    def dumps(self, obj):
        return obj


# Hack namedtuple, make it picklable

__cls = {}


def _restore(name, fields, value):
    """ Restore a namedtuple object """
    k = (name, fields)
    cls = __cls.get(k)
    if cls is None:
        cls = collections.namedtuple(name, fields)
        __cls[k] = cls
    return cls(*value)


def _hack_namedtuple(cls):
    """ Make a class generated by namedtuple picklable """
    name = cls.__name__
    fields = cls._fields

    def __reduce__(self):
        return (_restore, (name, fields, tuple(self)))
    cls.__reduce__ = __reduce__
    cls._is_namedtuple_ = True
    return cls


def _hijack_namedtuple():
    """ Hack namedtuple() to make it picklable """
    # hijack only one time
    if hasattr(collections.namedtuple, "__hijack"):
        return

    global _old_namedtuple  # or it will be put into the closure
    global _old_namedtuple_kwdefaults  # or it will be put into the closure too

    def _copy_func(f):
        return types.FunctionType(f.__code__, f.__globals__, f.__name__,
                                  f.__defaults__, f.__closure__)

    def _kwdefaults(f):
        # __kwdefaults__ contains the default values of keyword-only arguments, which were
        # introduced in Python 3. The possible cases for __kwdefaults__ in namedtuple
        # are as below:
        #
        # - Does not exist in Python 2.
        # - Returns None in <= Python 3.5.x.
        # - Returns a dictionary mapping keys to their default values from Python 3.6.x
        #    (See https://bugs.python.org/issue25628).
        kargs = getattr(f, "__kwdefaults__", None)
        if kargs is None:
            return {}
        else:
            return kargs

    _old_namedtuple = _copy_func(collections.namedtuple)
    _old_namedtuple_kwdefaults = _kwdefaults(collections.namedtuple)

    def namedtuple(*args, **kwargs):
        for k, v in _old_namedtuple_kwdefaults.items():
            kwargs[k] = kwargs.get(k, v)
        cls = _old_namedtuple(*args, **kwargs)
        return _hack_namedtuple(cls)

    # replace namedtuple with the new one
    collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults
    collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple
    collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple
    collections.namedtuple.__code__ = namedtuple.__code__
    collections.namedtuple.__hijack = 1

    # hack the classes already generated by namedtuple.
    # Those created in other modules can be pickled as normal,
    # so only hack those in the __main__ module
    for n, o in sys.modules["__main__"].__dict__.items():
        if (type(o) is type and o.__base__ is tuple
                and hasattr(o, "_fields")
                and "__reduce__" not in o.__dict__):
            _hack_namedtuple(o)  # hack in place


_hijack_namedtuple()


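# Illustrative sketch of the effect of the hijack above (the `Point` class is a
# hypothetical example): namedtuples created after this module is imported carry a
# __reduce__ that routes through _restore, so plain pickle can rebuild them even if
# they were defined interactively in __main__:
#
#   Point = collections.namedtuple("Point", ["x", "y"])
#   p = pickle.loads(pickle.dumps(Point(1, 2)))
#   assert p == Point(1, 2) and p._is_namedtuple_

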
class PickleSerializer(FramedSerializer):

    """
    Serializes objects using Python's pickle serializer:

        http://docs.python.org/2/library/pickle.html

    This serializer supports nearly any Python object, but may
    not be as fast as more specialized serializers.
    """

    def dumps(self, obj):
        return pickle.dumps(obj, pickle_protocol)

    if sys.version >= '3':
        def loads(self, obj, encoding="bytes"):
            return pickle.loads(obj, encoding=encoding)
    else:
        def loads(self, obj, encoding=None):
            return pickle.loads(obj)


class CloudPickleSerializer(PickleSerializer):

    def dumps(self, obj):
        try:
            return cloudpickle.dumps(obj, pickle_protocol)
        except pickle.PickleError:
            raise
        except Exception as e:
            emsg = _exception_message(e)
            if "'i' format requires" in emsg:
                msg = "Object too large to serialize: %s" % emsg
            else:
                msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
            print_exec(sys.stderr)
            raise pickle.PicklingError(msg)


class MarshalSerializer(FramedSerializer):

    """
    Serializes objects using Python's Marshal serializer:

        http://docs.python.org/2/library/marshal.html

    This serializer is faster than PickleSerializer but supports fewer datatypes.
    """

    def dumps(self, obj):
        return marshal.dumps(obj)

    def loads(self, obj):
        return marshal.loads(obj)


class AutoSerializer(FramedSerializer):

    """
    Chooses marshal or pickle as the serialization protocol automatically
    """

    def __init__(self):
        FramedSerializer.__init__(self)
        self._type = None

    def dumps(self, obj):
        if self._type is not None:
            return b'P' + pickle.dumps(obj, -1)
        try:
            return b'M' + marshal.dumps(obj)
        except Exception:
            self._type = b'P'
            return b'P' + pickle.dumps(obj, -1)

    def loads(self, obj):
        # slice rather than index so the tag stays a bytes object on Python 3
        _type = obj[0:1]
        if _type == b'M':
            return marshal.loads(obj[1:])
        elif _type == b'P':
            return pickle.loads(obj[1:])
        else:
            raise ValueError("invalid serialization type: %s" % _type)


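# Illustrative sketch: the first byte of each frame records which protocol produced
# it, and the serializer sticks with pickle once marshal has failed:
#
#   ser = AutoSerializer()
#   assert ser.dumps([1, 2, 3])[:1] == b'M'   # marshal handles simple containers
#   assert ser.dumps(object())[:1] == b'P'    # marshal fails, falls back to pickle
#   assert ser.dumps([4, 5, 6])[:1] == b'P'   # and pickle is used from then on

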
class CompressedSerializer(FramedSerializer):
    """
    Compresses the serialized data with zlib
    """
    def __init__(self, serializer):
        FramedSerializer.__init__(self)
        assert isinstance(serializer, FramedSerializer), "serializer must be a FramedSerializer"
        self.serializer = serializer

    def dumps(self, obj):
        return zlib.compress(self.serializer.dumps(obj), 1)

    def loads(self, obj):
        return self.serializer.loads(zlib.decompress(obj))

    def __repr__(self):
        return "CompressedSerializer(%s)" % self.serializer


class UTF8Deserializer(Serializer):

    """
    Deserializes streams written by String.getBytes.
    """

    def __init__(self, use_unicode=True):
        self.use_unicode = use_unicode

    def loads(self, stream):
        length = read_int(stream)
        if length == SpecialLengths.END_OF_DATA_SECTION:
            raise EOFError
        elif length == SpecialLengths.NULL:
            return None
        s = stream.read(length)
        return s.decode("utf-8") if self.use_unicode else s

    def load_stream(self, stream):
        try:
            while True:
                yield self.loads(stream)
        except struct.error:
            return
        except EOFError:
            return

    def __repr__(self):
        return "UTF8Deserializer(%s)" % self.use_unicode


def read_long(stream):
    length = stream.read(8)
    if not length:
        raise EOFError
    return struct.unpack("!q", length)[0]


def write_long(value, stream):
    stream.write(struct.pack("!q", value))


def pack_long(value):
    return struct.pack("!q", value)


def read_int(stream):
    length = stream.read(4)
    if not length:
        raise EOFError
    return struct.unpack("!i", length)[0]


def write_int(value, stream):
    stream.write(struct.pack("!i", value))


def read_bool(stream):
    length = stream.read(1)
    if not length:
        raise EOFError
    return struct.unpack("!?", length)[0]


def write_with_length(obj, stream):
    write_int(len(obj), stream)
    stream.write(obj)


class ChunkedStream(object):

    """
    This is a file-like object that takes a stream of data, of unknown length, and breaks it into
    fixed-length frames.  The intended use case is serializing large data and sending it
    immediately over a socket -- we do not want to buffer the entire data before sending it, but
    the receiving end needs to know whether or not there is more data coming.

    It works by buffering the incoming data in some fixed-size chunks.  If the buffer is full, it
    first sends the buffer size, then the data.  This repeats as long as there is more data to send.
    When this is closed, it sends the length of whatever data is in the buffer, then that data, and
    finally a "length" of -1 to indicate the stream has completed.
    """

    def __init__(self, wrapped, buffer_size):
        self.buffer_size = buffer_size
        self.buffer = bytearray(buffer_size)
        self.current_pos = 0
        self.wrapped = wrapped

    def write(self, bytes):
        byte_pos = 0
        byte_remaining = len(bytes)
        while byte_remaining > 0:
            new_pos = byte_remaining + self.current_pos
            if new_pos < self.buffer_size:
                # just put it in our buffer
                self.buffer[self.current_pos:new_pos] = bytes[byte_pos:]
                self.current_pos = new_pos
                byte_remaining = 0
            else:
                # fill the buffer, send the length then the contents, and start filling again
                space_left = self.buffer_size - self.current_pos
                new_byte_pos = byte_pos + space_left
                self.buffer[self.current_pos:self.buffer_size] = bytes[byte_pos:new_byte_pos]
                write_int(self.buffer_size, self.wrapped)
                self.wrapped.write(self.buffer)
                byte_remaining -= space_left
                byte_pos = new_byte_pos
                self.current_pos = 0

    def close(self):
        # if there is anything left in the buffer, write it out first
        if self.current_pos > 0:
            write_int(self.current_pos, self.wrapped)
            self.wrapped.write(self.buffer[:self.current_pos])
        # -1 length indicates to the receiving end that we're done.
        write_int(-1, self.wrapped)
        self.wrapped.close()

    @property
    def closed(self):
        """
        Return True if the `wrapped` object has been closed.
        NOTE: this property is required by pyarrow to be used as a file-like object in
        pyarrow.RecordBatchStreamWriter from ArrowStreamSerializer
        """
        return self.wrapped.closed


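# Illustrative sketch of the framing (io.BytesIO stands in for the wrapped
# socket-like stream; buffer_size=4 is chosen only to keep the example small):
#
#   import io
#   out = io.BytesIO()
#   cs = ChunkedStream(out, buffer_size=4)
#   cs.write(b"abcdef")   # flushes one full frame: int 4, then b"abcd"; buffers b"ef"
#   cs.close()            # flushes the rest: int 2, then b"ef"; writes int -1; closes out
#   # bytes written, in order: [4]b"abcd" [2]b"ef" [-1]

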
if __name__ == '__main__':
    import doctest
    (failure_count, test_count) = doctest.testmod()
    if failure_count:
        sys.exit(-1)