#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import shutil
import stat
import tempfile
import threading
import time
import unittest
from collections import namedtuple

from pyspark import SparkConf, SparkFiles, SparkContext
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME


class CheckpointTests(ReusedPySparkTestCase):

    def setUp(self):
        # Create a unique, not-yet-existing path for the checkpoint directory:
        # make a named temp file, delete it, and let Spark create the directory.
        self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
        os.unlink(self.checkpointDir.name)
        self.sc.setCheckpointDir(self.checkpointDir.name)

    def tearDown(self):
        shutil.rmtree(self.checkpointDir.name)

    def test_basic_checkpointing(self):
        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))

        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertTrue(flatMappedRDD.getCheckpointFile() is None)

        flatMappedRDD.checkpoint()
        result = flatMappedRDD.collect()
        time.sleep(1)  # 1 second
        self.assertTrue(flatMappedRDD.isCheckpointed())
        self.assertEqual(flatMappedRDD.collect(), result)
        self.assertEqual("file:" + self.checkpointDir.name,
                         os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile())))

    def test_checkpoint_and_restore(self):
        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: [x])

        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertTrue(flatMappedRDD.getCheckpointFile() is None)

        flatMappedRDD.checkpoint()
        flatMappedRDD.count()  # forces a checkpoint to be computed
        time.sleep(1)  # 1 second

        self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
                                            flatMappedRDD._jrdd_deserializer)
        self.assertEqual([1, 2, 3, 4], recovered.collect())

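# For reference, the reliable-checkpoint pattern exercised above looks roughly like
# this in application code (a minimal sketch; the directory path and the
# expensive_transform function are illustrative, not part of this test suite):
#
#     sc.setCheckpointDir("hdfs:///tmp/checkpoints")  # usually a fault-tolerant filesystem
#     rdd = sc.parallelize(range(100)).map(expensive_transform)
#     rdd.checkpoint()           # mark the RDD for checkpointing
#     rdd.count()                # an action materializes the checkpoint
#     rdd.getCheckpointFile()    # location of the saved checkpoint data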

class LocalCheckpointTests(ReusedPySparkTestCase):

    def test_basic_localcheckpointing(self):
        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))

        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertFalse(flatMappedRDD.isLocallyCheckpointed())

        flatMappedRDD.localCheckpoint()
        result = flatMappedRDD.collect()
        time.sleep(1)  # 1 second
        self.assertTrue(flatMappedRDD.isCheckpointed())
        self.assertTrue(flatMappedRDD.isLocallyCheckpointed())
        self.assertEqual(flatMappedRDD.collect(), result)

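# localCheckpoint(), exercised above, truncates the lineage like checkpoint() but keeps
# the data on executor-local storage instead of a reliable checkpoint directory, so it
# is faster but not fault-tolerant. A minimal sketch of the pattern (names are
# illustrative):
#
#     rdd = sc.parallelize(range(100)).map(lambda x: x * 2)
#     rdd.localCheckpoint()
#     rdd.count()                     # materializes the local checkpoint
#     rdd.isLocallyCheckpointed()     # now True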

class AddFileTests(PySparkTestCase):

    def test_add_py_file(self):
        # To ensure that we're actually testing addPyFile's effects, check that
        # this job fails due to `userlibrary` not being on the Python path:
        # disable logging in log4j temporarily
        def func(x):
            from userlibrary import UserClass
            return UserClass().hello()
        with QuietTest(self.sc):
            self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first)

        # Add the file, so the job should now succeed:
        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
        self.sc.addPyFile(path)
        res = self.sc.parallelize(range(2)).map(func).first()
        self.assertEqual("Hello World!", res)

    def test_add_file_locally(self):
        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
        self.sc.addFile(path)
        download_path = SparkFiles.get("hello.txt")
        self.assertNotEqual(path, download_path)
        with open(download_path) as test_file:
            self.assertEqual("Hello World!\n", test_file.readline())

    def test_add_file_recursively_locally(self):
        path = os.path.join(SPARK_HOME, "python/test_support/hello")
        self.sc.addFile(path, True)
        download_path = SparkFiles.get("hello")
        self.assertNotEqual(path, download_path)
        with open(download_path + "/hello.txt") as test_file:
            self.assertEqual("Hello World!\n", test_file.readline())
        with open(download_path + "/sub_hello/sub_hello.txt") as test_file:
            self.assertEqual("Sub Hello World!\n", test_file.readline())

    def test_add_py_file_locally(self):
        # To ensure that we're actually testing addPyFile's effects, check that
        # this fails due to `userlibrary` not being on the Python path:
        def func():
            from userlibrary import UserClass
        self.assertRaises(ImportError, func)
        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
        self.sc.addPyFile(path)
        from userlibrary import UserClass
        self.assertEqual("Hello World!", UserClass().hello())

    def test_add_egg_file_locally(self):
        # To ensure that we're actually testing addPyFile's effects, check that
        # this fails due to `userlib` not being on the Python path:
        def func():
            from userlib import UserClass
        self.assertRaises(ImportError, func)
        path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip")
        self.sc.addPyFile(path)
        from userlib import UserClass
        self.assertEqual("Hello World from inside a package!", UserClass().hello())

    def test_overwrite_system_module(self):
        self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py"))

        import SimpleHTTPServer
        self.assertEqual("My Server", SimpleHTTPServer.__name__)

        def func(x):
            import SimpleHTTPServer
            return SimpleHTTPServer.__name__

        self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())

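# SparkFiles.get(), tested above, resolves files shipped with SparkContext.addFile()
# on both the driver and the executors. A minimal sketch of the usual pattern inside
# a task (the file name and helper function are illustrative):
#
#     sc.addFile("data/lookup.txt")
#
#     def use_lookup(x):
#         with open(SparkFiles.get("lookup.txt")) as f:
#             return (x, f.readline().strip())
#
#     sc.parallelize(range(10)).map(use_lookup).collect()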

class ContextTests(unittest.TestCase):

    def test_failed_sparkcontext_creation(self):
        # Regression test for SPARK-1550
        self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))

    def test_get_or_create(self):
        with SparkContext.getOrCreate() as sc:
            self.assertTrue(SparkContext.getOrCreate() is sc)

    def test_parallelize_eager_cleanup(self):
        with SparkContext() as sc:
            temp_files = os.listdir(sc._temp_dir)
            rdd = sc.parallelize([0, 1, 2])
            # parallelize() should not leave any new files behind in the temp dir.
            post_parallelize_temp_files = os.listdir(sc._temp_dir)
            self.assertEqual(temp_files, post_parallelize_temp_files)

    def test_set_conf(self):
        # This is for an internal use case. When there is an existing SparkContext,
        # SparkSession's builder needs to set configs into SparkContext's conf.
        sc = SparkContext()
        sc._conf.set("spark.test.SPARK16224", "SPARK16224")
        self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224")
        sc.stop()

    def test_stop(self):
        sc = SparkContext()
        self.assertNotEqual(SparkContext._active_spark_context, None)
        sc.stop()
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_with(self):
        with SparkContext() as sc:
            self.assertNotEqual(SparkContext._active_spark_context, None)
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_with_exception(self):
        try:
            with SparkContext() as sc:
                self.assertNotEqual(SparkContext._active_spark_context, None)
                raise Exception()
        except:
            pass
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_with_stop(self):
        with SparkContext() as sc:
            self.assertNotEqual(SparkContext._active_spark_context, None)
            sc.stop()
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_progress_api(self):
        with SparkContext() as sc:
            sc.setJobGroup('test_progress_api', '', True)
            rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))

            def run():
                # When pinned-thread mode is enabled, the job group has to be set in
                # each thread for now; local properties are not inherited from the
                # parent thread the way they are on the Scala side.
                if os.environ.get("PYSPARK_PIN_THREAD", "false").lower() == "true":
                    sc.setJobGroup('test_progress_api', '', True)
                try:
                    rdd.count()
                except Exception:
                    pass
            t = threading.Thread(target=run)
            t.daemon = True
            t.start()
            # wait for scheduler to start
            time.sleep(1)

            tracker = sc.statusTracker()
            jobIds = tracker.getJobIdsForGroup('test_progress_api')
            self.assertEqual(1, len(jobIds))
            job = tracker.getJobInfo(jobIds[0])
            self.assertEqual(1, len(job.stageIds))
            stage = tracker.getStageInfo(job.stageIds[0])
            self.assertEqual(rdd.getNumPartitions(), stage.numTasks)

            sc.cancelAllJobs()
            t.join()
            # wait for event listener to update the status
            time.sleep(1)

            job = tracker.getJobInfo(jobIds[0])
            self.assertEqual('FAILED', job.status)
            self.assertEqual([], tracker.getActiveJobsIds())
            self.assertEqual([], tracker.getActiveStageIds())

            sc.stop()

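    # For reference, the StatusTracker polling pattern used in test_progress_api above
    # looks roughly like this in a monitoring thread (a minimal sketch; the group name
    # is illustrative):
    #
    #     sc.setJobGroup("my-group", "description")
    #     tracker = sc.statusTracker()
    #     for job_id in tracker.getJobIdsForGroup("my-group"):
    #         job = tracker.getJobInfo(job_id)
    #         for stage_id in job.stageIds:
    #             stage = tracker.getStageInfo(stage_id)
    #             print(stage.numTasks, stage.numActiveTasks, stage.numCompletedTasks)
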
    def test_startTime(self):
        with SparkContext() as sc:
            self.assertGreater(sc.startTime, 0)

    def test_forbid_insecure_gateway(self):
        # Fail immediately if you try to create a SparkContext
        # with an insecure gateway
        parameters = namedtuple('MockGatewayParameters', 'auth_token')(None)
        mock_insecure_gateway = namedtuple('MockJavaGateway', 'gateway_parameters')(parameters)
        with self.assertRaises(ValueError) as context:
            SparkContext(gateway=mock_insecure_gateway)
        self.assertIn("insecure Py4j gateway", str(context.exception))

    def test_resources(self):
        """Test the resources are empty by default."""
        with SparkContext() as sc:
            resources = sc.resources
            self.assertEqual(len(resources), 0)


class ContextTestsWithResources(unittest.TestCase):

    def setUp(self):
        class_name = self.__class__.__name__
        # Write a stub GPU discovery script: a shell script that prints the resource
        # as JSON, which Spark runs to find the addresses of the configured resource.
        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
        self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
        self.tempFile.close()
        # Make the script executable (and readable) so Spark can run it.
        os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP |
                 stat.S_IROTH | stat.S_IXOTH)
        conf = SparkConf().set("spark.test.home", SPARK_HOME)
        conf = conf.set("spark.driver.resource.gpu.amount", "1")
        conf = conf.set("spark.driver.resource.gpu.discoveryScript", self.tempFile.name)
        self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf)

    def test_resources(self):
        """Test the resources are available."""
        resources = self.sc.resources
        self.assertEqual(len(resources), 1)
        self.assertTrue('gpu' in resources)
        self.assertEqual(resources['gpu'].name, 'gpu')
        self.assertEqual(resources['gpu'].addresses, ['0'])

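    # sc.resources above covers driver-side resources; inside a task the analogous API
    # is TaskContext.resources(). A minimal sketch, assuming executor/task GPU resources
    # (spark.executor.resource.gpu.* / spark.task.resource.gpu.*) were also configured:
    #
    #     from pyspark import TaskContext
    #
    #     def which_gpu(x):
    #         return TaskContext.get().resources()["gpu"].addresses
    #
    #     self.sc.parallelize(range(2)).map(which_gpu).collect()
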
    def tearDown(self):
        os.unlink(self.tempFile.name)
        self.sc.stop()


if __name__ == "__main__":
    from pyspark.tests.test_context import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)