#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
import sys
import unittest

from pyspark import serializers
from pyspark.serializers import *
from pyspark.serializers import CloudPickleSerializer, CompressedSerializer, \
    AutoBatchedSerializer, BatchedSerializer, AutoSerializer, NoOpSerializer, PairDeserializer, \
    FlattenedValuesSerializer, CartesianDeserializer
from pyspark.testing.utils import PySparkTestCase, read_int, write_int, ByteArrayOutput, \
    have_numpy, have_scipy


class SerializationTestCase(unittest.TestCase):

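    # namedtuples must survive both a plain pickle round-trip of an instance
    # and cloudpickle's serialization of the generated class itself.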
    def test_namedtuple(self):
        from collections import namedtuple
        from pickle import dumps, loads
        P = namedtuple("P", "x y")
        p1 = P(1, 3)
        p2 = loads(dumps(p1, 2))
        self.assertEqual(p1, p2)

        from pyspark.cloudpickle import dumps
        P2 = loads(dumps(P))
        p3 = P2(1, 3)
        self.assertEqual(p1, p3)

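    # operator.itemgetter, with one index and with several, should round-trip
    # through CloudPickleSerializer and behave identically afterwards.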
    def test_itemgetter(self):
        from operator import itemgetter
        ser = CloudPickleSerializer()
        d = range(10)
        getter = itemgetter(1)
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))

        getter = itemgetter(0, 3)
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))

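    # cloudpickle should preserve a function's __module__ instead of rebinding
    # it to the module that performs the unpickling.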
    def test_function_module_name(self):
        ser = CloudPickleSerializer()
        func = lambda x: x
        func2 = ser.loads(ser.dumps(func))
        self.assertEqual(func.__module__, func2.__module__)

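    # operator.attrgetter, including dotted attribute paths such as "e.a",
    # should round-trip through CloudPickleSerializer.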
    def test_attrgetter(self):
        from operator import attrgetter
        ser = CloudPickleSerializer()

        class C(object):
            def __getattr__(self, item):
                return item
        d = C()
        getter = attrgetter("a")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))
        getter = attrgetter("a", "b")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))

        d.e = C()
        getter = attrgetter("e.a")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))
        getter = attrgetter("e.a", "e.b")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))

    # Regression test for SPARK-3415
    def test_pickling_file_handles(self):
        # to be corrected with SPARK-11160
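        # The check runs only when xmlrunner is absent, presumably because
        # xmlrunner wraps sys.stderr and the round-trip equality below would
        # then fail.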
        try:
            import xmlrunner
        except ImportError:
            ser = CloudPickleSerializer()
            out1 = sys.stderr
            out2 = ser.loads(ser.dumps(out1))
            self.assertEqual(out1, out2)

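    # Dumping a function must not drag in unrelated, unpicklable globals:
    # serializing `foo` should succeed even though the global `exit` that its
    # code names cannot itself be pickled.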
    def test_func_globals(self):

        class Unpicklable(object):
            def __reduce__(self):
                raise Exception("not picklable")

        global exit
        exit = Unpicklable()

        ser = CloudPickleSerializer()
        self.assertRaises(Exception, lambda: ser.dumps(exit))

        def foo():
            sys.exit(0)

        self.assertTrue("exit" in foo.__code__.co_names)
        ser.dumps(foo)

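    # CompressedSerializer must support repeated dump_stream calls appending to
    # the same buffer and read everything back in order.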
    def test_compressed_serializer(self):
        ser = CompressedSerializer(PickleSerializer())
        try:
            from StringIO import StringIO
        except ImportError:
            from io import BytesIO as StringIO
        io = StringIO()
        ser.dump_stream(["abc", u"123", range(5)], io)
        io.seek(0)
        self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
        ser.dump_stream(range(1000), io)
        io.seek(0)
        self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io)))
        io.close()

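    # Serializers implement equality, so each must also be hashable; any of
    # these hash() calls raising TypeError would be a regression.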
    def test_hash_serializer(self):
        hash(NoOpSerializer())
        hash(UTF8Deserializer())
        hash(PickleSerializer())
        hash(MarshalSerializer())
        hash(AutoSerializer())
        hash(BatchedSerializer(PickleSerializer()))
        hash(AutoBatchedSerializer(MarshalSerializer()))
        hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer()))
        hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer()))
        hash(CompressedSerializer(PickleSerializer()))
        hash(FlattenedValuesSerializer(PickleSerializer()))


@unittest.skipIf(not have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):

    """General PySpark tests that depend on SciPy."""

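    # scipy.special.gammaln must be serializable to executors: map it over an
    # RDD and compare against applying it locally.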
    def test_serialize(self):
        from scipy.special import gammaln

        x = range(1, 5)
        expected = list(map(gammaln, x))
        observed = self.sc.parallelize(x).map(gammaln).collect()
        self.assertEqual(expected, observed)


@unittest.skipIf(not have_numpy, "NumPy not installed")
class NumPyTests(PySparkTestCase):

    """General PySpark tests that depend on NumPy."""

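    # StatCounter should aggregate elementwise over NumPy arrays, for both the
    # population statistics and the sample (n - 1) variants.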
    def test_statcounter_array(self):
        import numpy as np

        x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])])
        s = x.stats()
        self.assertSequenceEqual([2.0, 2.0], s.mean().tolist())
        self.assertSequenceEqual([1.0, 1.0], s.min().tolist())
        self.assertSequenceEqual([3.0, 3.0], s.max().tolist())
        self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist())

        stats_dict = s.asDict()
        self.assertEqual(3, stats_dict['count'])
        self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist())
        self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist())
        self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist())
        self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist())
        self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist())
        self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist())

        stats_sample_dict = s.asDict(sample=True)
        self.assertEqual(3, stats_sample_dict['count'])
        self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist())
        self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist())
        self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist())
        self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist())
        self.assertSequenceEqual(
            [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist())
        self.assertSequenceEqual(
            [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist())


class SerializersTest(unittest.TestCase):

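    # ChunkedStream frames its output as a sequence of [int32 length][payload]
    # chunks terminated by a length of -1. Exercise many (data, buffer) size
    # combinations and verify the framing byte-for-byte.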
    def test_chunked_stream(self):
        original_bytes = bytearray(range(100))
        for data_length in [1, 10, 100]:
            for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]:
                dest = ByteArrayOutput()
                stream_out = serializers.ChunkedStream(dest, buffer_length)
                stream_out.write(original_bytes[:data_length])
                stream_out.close()
                num_chunks = int(math.ceil(float(data_length) / buffer_length))
                # length for each chunk, and a final -1 at the very end
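                # e.g. data_length=10, buffer_length=3 gives 4 chunks (3+3+3+1),
                # so (4 + 1) * 4 header bytes + 10 payload bytes = 30 in total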
                exp_size = (num_chunks + 1) * 4 + data_length
                self.assertEqual(len(dest.buffer), exp_size)
                dest_pos = 0
                data_pos = 0
                for chunk_idx in range(num_chunks):
                    chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)])
                    if chunk_idx == num_chunks - 1:
                        exp_length = data_length % buffer_length
                        if exp_length == 0:
                            exp_length = buffer_length
                    else:
                        exp_length = buffer_length
                    self.assertEqual(chunk_length, exp_length)
                    dest_pos += 4
                    dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length]
                    orig_chunk = original_bytes[data_pos:data_pos + chunk_length]
                    self.assertEqual(dest_chunk, orig_chunk)
                    dest_pos += chunk_length
                    data_pos += chunk_length
                # ends with a -1
                self.assertEqual(dest.buffer[-4:], write_int(-1))


if __name__ == "__main__":
    from pyspark.tests.test_serializers import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)