0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 import os
0018 import random
0019 import time
0020 import tempfile
0021 import unittest
0022 
0023 from pyspark import SparkConf, SparkContext
0024 from pyspark.java_gateway import launch_gateway
0025 from pyspark.serializers import ChunkedStream
0026 
0027 
class BroadcastTest(unittest.TestCase):
    """Tests that broadcast variables reach executors intact, with and without encryption."""

    def tearDown(self):
        # Stop any SparkContext a test created; a test that fails before
        # creating one never sets self.sc, hence the getattr guard.
        sc = getattr(self, "sc", None)
        if sc is not None:
            sc.stop()
            self.sc = None

    def _test_encryption_helper(self, vs):
        """
        Broadcast each value in vs, then run a small job to check that every
        executor reads the same values back.  Also asserts no tasks failed.
        """
        broadcasts = [self.sc.broadcast(value=v) for v in vs]
        observed = (
            self.sc.parallelize(range(2))
            .map(lambda _: [b.value for b in broadcasts])
            .collect()
        )
        for row in observed:
            self.assertEqual(row, vs)
        # A failed-then-retried task could still produce correct values, so
        # additionally verify that no task failures occurred at all.
        tracker = self.sc.statusTracker()
        for job_id in tracker.getJobIdsForGroup():
            for stage_id in tracker.getJobInfo(job_id).stageIds:
                self.assertEqual(0, tracker.getStageInfo(stage_id).numFailedTasks)

    def _test_multiple_broadcasts(self, *extra_confs):
        """
        Check broadcast variables make it to the executors, across multiple
        broadcast variables and multiple jobs.
        """
        conf = SparkConf()
        for k, v in extra_confs:
            conf.set(k, v)
        conf.setMaster("local-cluster[2,1,1024]")
        self.sc = SparkContext(conf=conf)
        self._test_encryption_helper([5])
        self._test_encryption_helper([5, 10, 20])

    def test_broadcast_with_encryption(self):
        self._test_multiple_broadcasts(("spark.io.encryption.enabled", "true"))

    def test_broadcast_no_encryption(self):
        self._test_multiple_broadcasts()

    def _test_broadcast_on_driver(self, *extra_confs):
        # Reading .value on the driver exercises a different path than
        # reading it inside an executor task.
        conf = SparkConf()
        for k, v in extra_confs:
            conf.set(k, v)
        conf.setMaster("local-cluster[2,1,1024]")
        self.sc = SparkContext(conf=conf)
        broadcast_var = self.sc.broadcast(value=5)
        self.assertEqual(5, broadcast_var.value)

    def test_broadcast_value_driver_no_encryption(self):
        self._test_broadcast_on_driver()

    def test_broadcast_value_driver_encryption(self):
        self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true"))

    def test_broadcast_value_against_gc(self):
        """Broadcast values must survive a driver-side JVM GC."""
        conf = SparkConf()
        conf.setMaster("local[1,1]")
        # Tiny memory fraction to create memory pressure around the GC.
        conf.set("spark.memory.fraction", "0.0001")
        self.sc = SparkContext(conf=conf)
        b = self.sc.broadcast([100])
        try:
            # First job never dereferences b.value on the executor side.
            res = self.sc.parallelize([0], 1).map(lambda x: 0 if x == 0 else b.value[0]).collect()
            self.assertEqual([0], res)
            # Force a JVM GC, then confirm the broadcast is still readable.
            self.sc._jvm.java.lang.System.gc()
            time.sleep(5)
            res = self.sc.parallelize([1], 1).map(lambda x: 0 if x == 0 else b.value[0]).collect()
            self.assertEqual([100], res)
        finally:
            b.destroy()
0102 
0103 
class BroadcastFrameProtocolTest(unittest.TestCase):
    """Round-trips data through the Python chunked-stream encoder and the JVM decoder."""

    @classmethod
    def setUpClass(cls):
        gateway = launch_gateway(SparkConf())
        cls._jvm = gateway.jvm
        cls.longMessage = True  # show custom assertion messages alongside the default ones
        random.seed(42)  # deterministic payloads across runs

    def _test_chunked_stream(self, data, py_buf_size):
        """
        Write data with the python-side chunked protocol, decode it on the JVM
        side, and verify the decoded bytes match the original exactly.
        """
        chunked_file = tempfile.NamedTemporaryFile(delete=False)
        dechunked_file = tempfile.NamedTemporaryFile(delete=False)
        dechunked_file.close()
        try:
            out = ChunkedStream(chunked_file, py_buf_size)
            out.write(data)
            out.close()
            # Have the JVM dechunk the file into the second temp file.
            jin = self._jvm.java.io.FileInputStream(chunked_file.name)
            jout = self._jvm.java.io.FileOutputStream(dechunked_file.name)
            self._jvm.DechunkedInputStream.dechunkAndCopyToOutput(jin, jout)
            # Decoded output must match the original payload byte for byte.
            self.assertEqual(len(data), os.stat(dechunked_file.name).st_size)
            with open(dechunked_file.name, "rb") as f:
                # Sizes are already known equal, so iterating the expected
                # payload covers the whole decoded file.
                for idx, expected in enumerate(data):
                    actual = f.read(1)
                    self.assertEqual(expected, bytearray(actual)[0], msg="idx = " + str(idx))
        finally:
            os.unlink(chunked_file.name)
            os.unlink(dechunked_file.name)

    def test_chunked_stream(self):
        def random_bytes(n):
            return bytearray(random.getrandbits(8) for _ in range(n))

        # Cross several payload sizes with buffer sizes that straddle
        # chunk boundaries.
        for data_length in [1, 10, 100, 10000]:
            for buffer_length in [1, 2, 5, 8192]:
                self._test_chunked_stream(random_bytes(data_length), buffer_length)
0145 
0146 
if __name__ == '__main__':
    from pyspark.tests.test_broadcast import *

    # Emit JUnit-style XML reports when xmlrunner is available; otherwise
    # fall back to the default text runner.
    try:
        import xmlrunner
    except ImportError:
        testRunner = None
    else:
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=testRunner, verbosity=2)