0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 import math
0018 import sys
0019 import unittest
0020
0021 from pyspark import serializers
0022 from pyspark.serializers import *
0023 from pyspark.serializers import CloudPickleSerializer, CompressedSerializer, \
0024 AutoBatchedSerializer, BatchedSerializer, AutoSerializer, NoOpSerializer, PairDeserializer, \
0025 FlattenedValuesSerializer, CartesianDeserializer
0026 from pyspark.testing.utils import PySparkTestCase, read_int, write_int, ByteArrayOutput, \
0027 have_numpy, have_scipy
0028
0029
class SerializationTestCase(unittest.TestCase):
    """Round-trip and hashability tests for PySpark's serializers."""

    def test_namedtuple(self):
        # A namedtuple instance must survive a stdlib pickle round trip, and
        # the namedtuple *class itself* must survive a cloudpickle round trip.
        from collections import namedtuple
        from pickle import dumps, loads
        Point = namedtuple("P", "x y")
        original = Point(1, 3)
        restored = loads(dumps(original, 2))
        self.assertEqual(original, restored)

        from pyspark.cloudpickle import dumps
        RebuiltClass = loads(dumps(Point))
        rebuilt = RebuiltClass(1, 3)
        self.assertEqual(original, rebuilt)

    def test_itemgetter(self):
        # operator.itemgetter objects must round-trip through cloudpickle
        # and still extract the same items afterwards.
        from operator import itemgetter
        serde = CloudPickleSerializer()
        data = range(10)
        for picker in (itemgetter(1), itemgetter(0, 3)):
            clone = serde.loads(serde.dumps(picker))
            self.assertEqual(picker(data), clone(data))

    def test_function_module_name(self):
        # A deserialized lambda keeps the __module__ of the original.
        serde = CloudPickleSerializer()
        original = lambda x: x
        clone = serde.loads(serde.dumps(original))
        self.assertEqual(original.__module__, clone.__module__)

    def test_attrgetter(self):
        # operator.attrgetter objects (including dotted paths) must
        # round-trip through cloudpickle.
        from operator import attrgetter
        serde = CloudPickleSerializer()

        class Echo(object):
            # Any missing-attribute lookup just returns the attribute name.
            def __getattr__(self, item):
                return item

        target = Echo()
        for picker in (attrgetter("a"), attrgetter("a", "b")):
            clone = serde.loads(serde.dumps(picker))
            self.assertEqual(picker(target), clone(picker.__call__.__self__ if False else target))

        # Dotted attribute paths traverse nested objects.
        target.e = Echo()
        for picker in (attrgetter("e.a"), attrgetter("e.a", "e.b")):
            clone = serde.loads(serde.dumps(picker))
            self.assertEqual(picker(target), clone(target))

    def test_pickling_file_handles(self):
        # NOTE(review): only exercised when xmlrunner is absent — presumably
        # because xmlrunner wraps the standard streams, breaking the equality
        # check below; confirm against the CI setup.
        try:
            import xmlrunner  # noqa: F401
        except ImportError:
            serde = CloudPickleSerializer()
            handle = sys.stderr
            self.assertEqual(handle, serde.loads(serde.dumps(handle)))

    def test_func_globals(self):
        # Serializing a function whose code merely *names* an unpicklable
        # global must succeed, while serializing the global itself must fail.

        class Unpicklable(object):
            def __reduce__(self):
                raise Exception("not picklable")

        global exit
        exit = Unpicklable()

        serde = CloudPickleSerializer()
        self.assertRaises(Exception, lambda: serde.dumps(exit))

        def uses_exit():
            sys.exit(0)

        self.assertTrue("exit" in uses_exit.__code__.co_names)
        serde.dumps(uses_exit)

    def test_compressed_serializer(self):
        # Streams written by CompressedSerializer must read back intact,
        # including when a second batch is appended to the same buffer.
        serde = CompressedSerializer(PickleSerializer())
        try:
            from StringIO import StringIO  # Python 2
        except ImportError:
            from io import BytesIO as StringIO  # Python 3
        buf = StringIO()
        serde.dump_stream(["abc", u"123", range(5)], buf)
        buf.seek(0)
        self.assertEqual(["abc", u"123", range(5)], list(serde.load_stream(buf)))
        serde.dump_stream(range(1000), buf)
        buf.seek(0)
        self.assertEqual(
            ["abc", u"123", range(5)] + list(range(1000)),
            list(serde.load_stream(buf)))
        buf.close()

    def test_hash_serializer(self):
        # Every serializer must be hashable (e.g. usable as a dict key).
        for serializer in (
            NoOpSerializer(),
            UTF8Deserializer(),
            PickleSerializer(),
            MarshalSerializer(),
            AutoSerializer(),
            BatchedSerializer(PickleSerializer()),
            AutoBatchedSerializer(MarshalSerializer()),
            PairDeserializer(NoOpSerializer(), UTF8Deserializer()),
            CartesianDeserializer(NoOpSerializer(), UTF8Deserializer()),
            CompressedSerializer(PickleSerializer()),
            FlattenedValuesSerializer(PickleSerializer()),
        ):
            hash(serializer)
0142
0143
@unittest.skipIf(not have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):

    """General PySpark tests that depend on scipy """

    def test_serialize(self):
        # A SciPy ufunc must serialize correctly so it can be shipped to
        # executors: local results and distributed results must agree.
        from scipy.special import gammaln

        inputs = range(1, 5)
        expected = [gammaln(v) for v in inputs]
        observed = self.sc.parallelize(inputs).map(gammaln).collect()
        self.assertEqual(expected, observed)
0156
0157
@unittest.skipIf(not have_numpy, "NumPy not installed")
class NumPyTests(PySparkTestCase):

    """General PySpark tests that depend on numpy """

    def test_statcounter_array(self):
        # StatCounter must aggregate numpy arrays element-wise, both for
        # population statistics and (via asDict(sample=True)) sample ones.
        import numpy as np

        x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])])
        s = x.stats()
        self.assertSequenceEqual([2.0, 2.0], s.mean().tolist())
        self.assertSequenceEqual([1.0, 1.0], s.min().tolist())
        self.assertSequenceEqual([3.0, 3.0], s.max().tolist())
        self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist())

        stats_dict = s.asDict()
        self.assertEqual(3, stats_dict['count'])
        self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist())
        self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist())
        self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist())
        self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist())
        self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist())
        self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist())

        stats_sample_dict = s.asDict(sample=True)
        # BUG FIX: previously re-checked stats_dict['count'] here (copy-paste);
        # the sample dict is the one under test.
        self.assertEqual(3, stats_sample_dict['count'])
        self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist())
        self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist())
        self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist())
        self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist())
        self.assertSequenceEqual(
            [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist())
        self.assertSequenceEqual(
            [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist())
0192
0193
class SerializersTest(unittest.TestCase):

    def test_chunked_stream(self):
        """ChunkedStream must frame output as [4-byte length][payload]... then a -1 marker."""
        payload = bytearray(range(100))
        for data_length in [1, 10, 100]:
            for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]:
                sink = ByteArrayOutput()
                out = serializers.ChunkedStream(sink, buffer_length)
                out.write(payload[:data_length])
                out.close()
                num_chunks = int(math.ceil(float(data_length) / buffer_length))

                # One 4-byte header per chunk plus the 4-byte end marker.
                self.assertEqual(len(sink.buffer), (num_chunks + 1) * 4 + data_length)

                pos = 0
                consumed = 0
                for chunk_idx in range(num_chunks):
                    chunk_length = read_int(sink.buffer[pos:(pos + 4)])
                    if chunk_idx < num_chunks - 1:
                        expected_length = buffer_length
                    else:
                        # Final chunk carries the remainder, or a full buffer
                        # when the data divides evenly.
                        expected_length = data_length % buffer_length or buffer_length
                    self.assertEqual(chunk_length, expected_length)
                    pos += 4
                    self.assertEqual(
                        sink.buffer[pos:pos + chunk_length],
                        payload[consumed:consumed + chunk_length])
                    pos += chunk_length
                    consumed += chunk_length

                # Stream terminator.
                self.assertEqual(sink.buffer[-4:], write_int(-1))
0227
0228
if __name__ == "__main__":
    from pyspark.tests.test_serializers import *

    # Prefer XML test reports when xmlrunner is available; otherwise fall
    # back to the default text runner.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)