#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import re
import shutil
import subprocess
import tempfile
import unittest
import zipfile


class SparkSubmitTests(unittest.TestCase):

    def setUp(self):
        self.programDir = tempfile.mkdtemp()
        tmp_dir = tempfile.gettempdir()
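        # Base spark-submit invocation shared by every test; both driver and
        # executor JVMs get java.io.tmpdir pointed at the system temp dir so
        # temporary files land in a predictable location.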
        self.sparkSubmit = [
            os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"),
            "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
            "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
        ]

    def tearDown(self):
        shutil.rmtree(self.programDir)

    def createTempFile(self, name, content, dir=None):
        """
        Create a temp file with the given name and content and return its path.
        Strips leading spaces from content up to the first '|' in each line.
        """
        pattern = re.compile(r'^ *\|', re.MULTILINE)
        content = re.sub(pattern, '', content.strip())
        if dir is None:
            path = os.path.join(self.programDir, name)
        else:
            os.makedirs(os.path.join(self.programDir, dir))
            path = os.path.join(self.programDir, dir, name)
        with open(path, "w") as f:
            f.write(content)
        return path

    def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None):
        """
        Create a zip archive containing a file with the given content and return its path.
        Strips leading spaces from content up to the first '|' in each line.
        """
        pattern = re.compile(r'^ *\|', re.MULTILINE)
        content = re.sub(pattern, '', content.strip())
        if dir is None:
            path = os.path.join(self.programDir, name + ext)
        else:
            path = os.path.join(self.programDir, dir, zip_name + ext)
        with zipfile.ZipFile(path, 'w') as zip_file:
            zip_file.writestr(name, content)
        return path

    def create_spark_package(self, artifact_name):
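        """
        Lay out a minimal local Maven repository containing the given artifact.

        Writes a POM plus a ".jar" that is really a zip holding a small Python
        module, so the artifact can be resolved through --packages together
        with --repositories pointing at this directory.
        """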
        group_id, artifact_id, version = artifact_name.split(":")
        self.createTempFile("%s-%s.pom" % (artifact_id, version), ("""
            |<?xml version="1.0" encoding="UTF-8"?>
            |<project xmlns="http://maven.apache.org/POM/4.0.0"
            |       xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
            |       xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
            |       http://maven.apache.org/xsd/maven-4.0.0.xsd">
            |   <modelVersion>4.0.0</modelVersion>
            |   <groupId>%s</groupId>
            |   <artifactId>%s</artifactId>
            |   <version>%s</version>
            |</project>
            """ % (group_id, artifact_id, version)).lstrip(),
            os.path.join(group_id, artifact_id, version))
        self.createFileInZip("%s.py" % artifact_id, """
            |def myfunc(x):
            |    return x + 1
            """, ".jar", os.path.join(group_id, artifact_id, version),
            "%s-%s" % (artifact_id, version))

    def test_single_script(self):
        """Submit and test a single script file"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkContext
            |
            |sc = SparkContext()
            |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect())
            """)
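        # Run spark-submit as a child process and assert on its exit code and
        # captured stdout; the same pattern is used by the tests below.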
        proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("[2, 4, 6]", out.decode('utf-8'))

    def test_script_with_local_functions(self):
        """Submit and test a single script file calling a global function"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkContext
            |
            |def foo(x):
            |    return x * 3
            |
            |sc = SparkContext()
            |print(sc.parallelize([1, 2, 3]).map(foo).collect())
            """)
        proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("[3, 6, 9]", out.decode('utf-8'))

    def test_module_dependency(self):
        """Submit and test a script with a dependency on another module"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkContext
            |from mylib import myfunc
            |
            |sc = SparkContext()
            |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
            """)
        zip_path = self.createFileInZip("mylib.py", """
            |def myfunc(x):
            |    return x + 1
            """)
        proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip_path, script],
                                stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("[2, 3, 4]", out.decode('utf-8'))

    def test_module_dependency_on_cluster(self):
        """Submit and test a script with a dependency on another module on a cluster"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkContext
            |from mylib import myfunc
            |
            |sc = SparkContext()
            |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
            """)
        zip_path = self.createFileInZip("mylib.py", """
            |def myfunc(x):
            |    return x + 1
            """)
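        # local-cluster[1,1,1024] starts an in-process standalone cluster with
        # one worker offering one core and 1024 MB of memory.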
        proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip_path, "--master",
                                "local-cluster[1,1,1024]", script],
                                stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("[2, 3, 4]", out.decode('utf-8'))

    def test_package_dependency(self):
        """Submit and test a script with a dependency on a Spark Package"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkContext
            |from mylib import myfunc
            |
            |sc = SparkContext()
            |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
            """)
        self.create_spark_package("a:mylib:0.1")
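        # Resolve the package from the local file: repository laid out by
        # create_spark_package above, rather than from a remote repository.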
        proc = subprocess.Popen(
            self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories",
                                "file:" + self.programDir, script],
            stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("[2, 3, 4]", out.decode('utf-8'))

    def test_package_dependency_on_cluster(self):
        """Submit and test a script with a dependency on a Spark Package on a cluster"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkContext
            |from mylib import myfunc
            |
            |sc = SparkContext()
            |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
            """)
        self.create_spark_package("a:mylib:0.1")
        proc = subprocess.Popen(
            self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories",
                                "file:" + self.programDir, "--master", "local-cluster[1,1,1024]",
                                script],
            stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("[2, 3, 4]", out.decode('utf-8'))

    def test_single_script_on_cluster(self):
        """Submit and test a single script on a cluster"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkContext
            |
            |def foo(x):
            |    return x * 2
            |
            |sc = SparkContext()
            |print(sc.parallelize([1, 2, 3]).map(foo).collect())
            """)
        # This will fail if conf/spark-defaults.conf sets spark.executor.memory
        # to more than the 1024 MB offered by the local-cluster worker.
        proc = subprocess.Popen(
            self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script],
            stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("[2, 4, 6]", out.decode('utf-8'))

    def test_user_configuration(self):
        """Make sure user configuration is respected (SPARK-19307)"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkConf, SparkContext
            |
            |conf = SparkConf().set("spark.test_config", "1")
            |sc = SparkContext(conf=conf)
            |try:
            |    if sc._conf.get("spark.test_config") != "1":
            |        raise Exception("Cannot find spark.test_config in SparkContext's conf.")
            |finally:
            |    sc.stop()
            """)
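        # Merge stderr into stdout so that, on failure, the assertion message
        # below includes the full spark-submit output.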
        proc = subprocess.Popen(
            self.sparkSubmit + ["--master", "local", script],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out))


if __name__ == "__main__":
    from pyspark.tests.test_appsubmit import *

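    # Produce XML test reports when the xmlrunner package is available;
    # otherwise fall back to the default text test runner.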
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)