#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import shutil
import stat
import tempfile
import threading
import time
import unittest
from collections import namedtuple

from pyspark import SparkConf, SparkFiles, SparkContext
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME


class CheckpointTests(ReusedPySparkTestCase):

    def setUp(self):
        self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
        os.unlink(self.checkpointDir.name)
        self.sc.setCheckpointDir(self.checkpointDir.name)

    def tearDown(self):
        shutil.rmtree(self.checkpointDir.name)

    def test_basic_checkpointing(self):
        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))

        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertTrue(flatMappedRDD.getCheckpointFile() is None)

        flatMappedRDD.checkpoint()
        result = flatMappedRDD.collect()
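        # The checkpoint is materialized by the first job that computes the RDD;
        # sleep briefly so the files can be written before we assert on them.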
        time.sleep(1)
        self.assertTrue(flatMappedRDD.isCheckpointed())
        self.assertEqual(flatMappedRDD.collect(), result)
        self.assertEqual("file:" + self.checkpointDir.name,
                         os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile())))

    def test_checkpoint_and_restore(self):
        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: [x])

        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertTrue(flatMappedRDD.getCheckpointFile() is None)

        flatMappedRDD.checkpoint()
        flatMappedRDD.count()  # forces the checkpoint to be computed
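        # Give the checkpoint files a moment to be fully written.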
        time.sleep(1)

        self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
                                            flatMappedRDD._jrdd_deserializer)
        self.assertEqual([1, 2, 3, 4], recovered.collect())

class LocalCheckpointTests(ReusedPySparkTestCase):

    def test_basic_localcheckpointing(self):
        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))

        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertFalse(flatMappedRDD.isLocallyCheckpointed())

        flatMappedRDD.localCheckpoint()
        result = flatMappedRDD.collect()
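        # As with reliable checkpointing, the local checkpoint is written by
        # the first job, so wait briefly before asserting.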
        time.sleep(1)
        self.assertTrue(flatMappedRDD.isCheckpointed())
        self.assertTrue(flatMappedRDD.isLocallyCheckpointed())
        self.assertEqual(flatMappedRDD.collect(), result)

class AddFileTests(PySparkTestCase):

    def test_add_py_file(self):
        # To ensure that we're actually testing addPyFile's effects, check that
        # this job fails because `userlibrary` is not yet on the Python path
        # (QuietTest temporarily silences log4j while the expected failure runs):
        def func(x):
            from userlibrary import UserClass
            return UserClass().hello()
        with QuietTest(self.sc):
            self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first)

        # Add the file; the same job should now succeed:
        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
        self.sc.addPyFile(path)
        res = self.sc.parallelize(range(2)).map(func).first()
        self.assertEqual("Hello World!", res)

    def test_add_file_locally(self):
        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
        self.sc.addFile(path)
        download_path = SparkFiles.get("hello.txt")
        self.assertNotEqual(path, download_path)
        with open(download_path) as test_file:
            self.assertEqual("Hello World!\n", test_file.readline())

    def test_add_file_recursively_locally(self):
        path = os.path.join(SPARK_HOME, "python/test_support/hello")
        self.sc.addFile(path, recursive=True)
        download_path = SparkFiles.get("hello")
        self.assertNotEqual(path, download_path)
        with open(download_path + "/hello.txt") as test_file:
            self.assertEqual("Hello World!\n", test_file.readline())
        with open(download_path + "/sub_hello/sub_hello.txt") as test_file:
            self.assertEqual("Sub Hello World!\n", test_file.readline())

    def test_add_py_file_locally(self):
        # To ensure that we're actually testing addPyFile's effects, check that
        # this import fails because `userlibrary` is not yet on the Python path:
        def func():
            from userlibrary import UserClass
        self.assertRaises(ImportError, func)
        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
        self.sc.addPyFile(path)
        from userlibrary import UserClass
        self.assertEqual("Hello World!", UserClass().hello())

    def test_add_egg_file_locally(self):
        # Check that the import fails because `userlib` (a zipped package) is
        # not yet on the Python path:
        def func():
            from userlib import UserClass
        self.assertRaises(ImportError, func)
        path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip")
        self.sc.addPyFile(path)
        from userlib import UserClass
        self.assertEqual("Hello World from inside a package!", UserClass().hello())

    def test_overwrite_system_module(self):
        # A module shipped via addPyFile should shadow an existing module of
        # the same name, both on the driver and on the executors.
        self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py"))

        import SimpleHTTPServer
        self.assertEqual("My Server", SimpleHTTPServer.__name__)

        def func(x):
            import SimpleHTTPServer
            return SimpleHTTPServer.__name__

        self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())

class ContextTests(unittest.TestCase):

    def test_failed_sparkcontext_creation(self):
        # Constructing a SparkContext with an invalid master name should raise.
        self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))

    def test_get_or_create(self):
        with SparkContext.getOrCreate() as sc:
            self.assertTrue(SparkContext.getOrCreate() is sc)

    def test_parallelize_eager_cleanup(self):
        # parallelize() ships the data to the JVM through a temporary file,
        # which should be removed eagerly rather than at context shutdown.
        with SparkContext() as sc:
            temp_files = os.listdir(sc._temp_dir)
            rdd = sc.parallelize([0, 1, 2])
            post_parallelize_temp_files = os.listdir(sc._temp_dir)
            self.assertEqual(temp_files, post_parallelize_temp_files)

    def test_set_conf(self):
        # This mirrors an internal use case: when there is an existing
        # SparkContext, SparkSession's builder needs to set configs directly
        # into the SparkContext's conf.
        sc = SparkContext()
        sc._conf.set("spark.test.SPARK16224", "SPARK16224")
        self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224")
        sc.stop()

    def test_stop(self):
        sc = SparkContext()
        self.assertNotEqual(SparkContext._active_spark_context, None)
        sc.stop()
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_with(self):
        with SparkContext() as sc:
            self.assertNotEqual(SparkContext._active_spark_context, None)
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_with_exception(self):
        # The context manager should stop the SparkContext even when the body
        # raises an exception.
        try:
            with SparkContext() as sc:
                self.assertNotEqual(SparkContext._active_spark_context, None)
                raise Exception()
        except Exception:
            pass
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_with_stop(self):
        with SparkContext() as sc:
            self.assertNotEqual(SparkContext._active_spark_context, None)
            sc.stop()
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_progress_api(self):
        with SparkContext() as sc:
            sc.setJobGroup('test_progress_api', '', True)
            rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))
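            # Each task sleeps for a long time, keeping the job alive for the
            # status tracker assertions below.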

            def run():
                # When pinned-thread mode is enabled, the job group has to be
                # set in this thread as well, since thread-local properties are
                # not inherited from the parent thread.
                if os.environ.get("PYSPARK_PIN_THREAD", "false").lower() == "true":
                    sc.setJobGroup('test_progress_api', '', True)
                try:
                    rdd.count()
                except Exception:
                    pass
            t = threading.Thread(target=run)
            t.daemon = True
            t.start()
            # Wait for the scheduler to actually start the job.
            time.sleep(1)

            tracker = sc.statusTracker()
            jobIds = tracker.getJobIdsForGroup('test_progress_api')
            self.assertEqual(1, len(jobIds))
            job = tracker.getJobInfo(jobIds[0])
            self.assertEqual(1, len(job.stageIds))
            stage = tracker.getStageInfo(job.stageIds[0])
            self.assertEqual(rdd.getNumPartitions(), stage.numTasks)

            sc.cancelAllJobs()
            t.join()
            # Wait for the event listener to update the job status.
            time.sleep(1)

            job = tracker.getJobInfo(jobIds[0])
            self.assertEqual('FAILED', job.status)
            self.assertEqual([], tracker.getActiveJobsIds())
            self.assertEqual([], tracker.getActiveStageIds())

            sc.stop()

    def test_startTime(self):
        with SparkContext() as sc:
            self.assertGreater(sc.startTime, 0)

    def test_forbid_insecure_gateway(self):
        # Creating a SparkContext against an insecure (auth-less) Py4J gateway
        # must fail immediately with a clear error.
        parameters = namedtuple('MockGatewayParameters', 'auth_token')(None)
        mock_insecure_gateway = namedtuple('MockJavaGateway', 'gateway_parameters')(parameters)
        with self.assertRaises(ValueError) as context:
            SparkContext(gateway=mock_insecure_gateway)
        self.assertIn("insecure Py4j gateway", str(context.exception))

    def test_resources(self):
        """Test the resources are empty by default."""
        with SparkContext() as sc:
            resources = sc.resources
            self.assertEqual(len(resources), 0)


class ContextTestsWithResources(unittest.TestCase):

    def setUp(self):
        class_name = self.__class__.__name__
        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
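        # The discovery script is a one-line shell script that echoes a static
        # JSON payload describing a single GPU with address "0".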
        self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
        self.tempFile.close()
        # Make the discovery script executable.
        os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP |
                 stat.S_IROTH | stat.S_IXOTH)
        conf = SparkConf().set("spark.test.home", SPARK_HOME)
        conf = conf.set("spark.driver.resource.gpu.amount", "1")
        conf = conf.set("spark.driver.resource.gpu.discoveryScript", self.tempFile.name)
        self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf)

    def test_resources(self):
        """Test the resources are available."""
        resources = self.sc.resources
        self.assertEqual(len(resources), 1)
        self.assertTrue('gpu' in resources)
        self.assertEqual(resources['gpu'].name, 'gpu')
        self.assertEqual(resources['gpu'].addresses, ['0'])

    def tearDown(self):
        os.unlink(self.tempFile.name)
        self.sc.stop()


if __name__ == "__main__":
    from pyspark.tests.test_context import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)