0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 import os
0018 import random
0019 import time
0020 import tempfile
0021 import unittest
0022
0023 from pyspark import SparkConf, SparkContext
0024 from pyspark.java_gateway import launch_gateway
0025 from pyspark.serializers import ChunkedStream
0026
0027
class BroadcastTest(unittest.TestCase):
    """End-to-end checks that broadcast variables reach executors intact."""

    def tearDown(self):
        # Stop the SparkContext a test may have created so no context leaks
        # into the next test.
        sc = getattr(self, "sc", None)
        if sc is not None:
            sc.stop()
            self.sc = None

    def _test_encryption_helper(self, vs):
        """
        Create one broadcast variable per value in ``vs`` and run a small job
        to check that executors read back the same values. Also assert that
        no task failed while doing so.
        """
        broadcasts = [self.sc.broadcast(value=v) for v in vs]
        collected = self.sc.parallelize(range(2)).map(
            lambda x: [bc.value for bc in broadcasts]).collect()
        for observed in collected:
            self.assertEqual(observed, vs)

        # Walk every stage of every job in the group and require zero failed
        # tasks, so a retried-then-succeeded failure cannot slip through.
        tracker = self.sc.statusTracker()
        for job_id in tracker.getJobIdsForGroup():
            for stage_id in tracker.getJobInfo(job_id).stageIds:
                stage = tracker.getStageInfo(stage_id)
                self.assertEqual(0, stage.numFailedTasks)

    def _test_multiple_broadcasts(self, *extra_confs):
        """
        Check broadcast variables make it OK to the executors, covering both
        multiple broadcast variables and multiple jobs.
        """
        conf = SparkConf()
        for key, value in extra_confs:
            conf.set(key, value)
        conf.setMaster("local-cluster[2,1,1024]")
        self.sc = SparkContext(conf=conf)
        self._test_encryption_helper([5])
        self._test_encryption_helper([5, 10, 20])

    def test_broadcast_with_encryption(self):
        self._test_multiple_broadcasts(("spark.io.encryption.enabled", "true"))

    def test_broadcast_no_encryption(self):
        self._test_multiple_broadcasts()

    def _test_broadcast_on_driver(self, *extra_confs):
        """Broadcast a value and read it back on the driver itself."""
        conf = SparkConf()
        for key, value in extra_confs:
            conf.set(key, value)
        conf.setMaster("local-cluster[2,1,1024]")
        self.sc = SparkContext(conf=conf)
        broadcast = self.sc.broadcast(value=5)
        self.assertEqual(5, broadcast.value)

    def test_broadcast_value_driver_no_encryption(self):
        self._test_broadcast_on_driver()

    def test_broadcast_value_driver_encryption(self):
        self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true"))

    def test_broadcast_value_against_gc(self):
        # Shrink the storage memory fraction and force a JVM GC, then check
        # that the broadcast value can still be read by a later job.
        conf = SparkConf()
        conf.setMaster("local[1,1]")
        conf.set("spark.memory.fraction", "0.0001")
        self.sc = SparkContext(conf=conf)
        b = self.sc.broadcast([100])
        try:
            # First job never touches b.value (x == 0 for every element).
            result = self.sc.parallelize([0], 1).map(
                lambda x: 0 if x == 0 else b.value[0]).collect()
            self.assertEqual([0], result)
            self.sc._jvm.java.lang.System.gc()
            time.sleep(5)
            # Second job actually dereferences the broadcast after the GC.
            result = self.sc.parallelize([1], 1).map(
                lambda x: 0 if x == 0 else b.value[0]).collect()
            self.assertEqual([100], result)
        finally:
            b.destroy()
0102
0103
class BroadcastFrameProtocolTest(unittest.TestCase):
    """
    Checks that data framed by the Python ``ChunkedStream`` can be dechunked
    on the JVM side back into the original bytes, byte-for-byte.
    """

    @classmethod
    def setUpClass(cls):
        # One shared JVM gateway for all tests in this class.
        gateway = launch_gateway(SparkConf())
        cls._jvm = gateway.jvm
        cls.longMessage = True  # show custom msg= in addition to the default
        random.seed(42)  # deterministic test data across runs

    def _test_chunked_stream(self, data, py_buf_size):
        """
        Write ``data`` through ``ChunkedStream`` with buffer size
        ``py_buf_size``, have the JVM dechunk the result into a second file,
        and verify the round trip reproduces ``data`` exactly.
        """
        chunked_file = tempfile.NamedTemporaryFile(delete=False)
        dechunked_file = tempfile.NamedTemporaryFile(delete=False)
        dechunked_file.close()
        try:
            out = ChunkedStream(chunked_file, py_buf_size)
            out.write(data)
            out.close()

            jin = self._jvm.java.io.FileInputStream(chunked_file.name)
            jout = self._jvm.java.io.FileOutputStream(dechunked_file.name)
            try:
                self._jvm.DechunkedInputStream.dechunkAndCopyToOutput(jin, jout)
            finally:
                # Close the Java streams explicitly: otherwise the file
                # handles leak on the JVM side, and the unlink below can fail
                # on platforms that forbid deleting open files (Windows).
                jin.close()
                jout.close()

            self.assertEqual(len(data), os.stat(dechunked_file.name).st_size)
            # Compare byte by byte so a mismatch reports the exact offset.
            with open(dechunked_file.name, "rb") as f:
                byte = f.read(1)
                idx = 0
                while byte:
                    self.assertEqual(data[idx], bytearray(byte)[0], msg="idx = " + str(idx))
                    byte = f.read(1)
                    idx += 1
        finally:
            os.unlink(chunked_file.name)
            os.unlink(dechunked_file.name)

    def test_chunked_stream(self):
        """Cross a range of data sizes with a range of chunk-buffer sizes."""
        def random_bytes(n):
            return bytearray(random.getrandbits(8) for _ in range(n))

        for data_length in [1, 10, 100, 10000]:
            for buffer_length in [1, 2, 5, 8192]:
                self._test_chunked_stream(random_bytes(data_length), buffer_length)
0145
0146
if __name__ == '__main__':
    from pyspark.tests.test_broadcast import *

    # Prefer XML test reports when xmlrunner is installed; otherwise fall
    # back to the default text runner.
    try:
        import xmlrunner
        test_runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        test_runner = None
    unittest.main(testRunner=test_runner, verbosity=2)