0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import os
0019 import re
0020 import shutil
0021 import subprocess
0022 import tempfile
0023 import unittest
0024 import zipfile
0025
0026
0027 class SparkSubmitTests(unittest.TestCase):
0028
0029 def setUp(self):
0030 self.programDir = tempfile.mkdtemp()
0031 tmp_dir = tempfile.gettempdir()
0032 self.sparkSubmit = [
0033 os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"),
0034 "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
0035 "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
0036 ]
0037
0038 def tearDown(self):
0039 shutil.rmtree(self.programDir)
0040
0041 def createTempFile(self, name, content, dir=None):
0042 """
0043 Create a temp file with the given name and content and return its path.
0044 Strips leading spaces from content up to the first '|' in each line.
0045 """
0046 pattern = re.compile(r'^ *\|', re.MULTILINE)
0047 content = re.sub(pattern, '', content.strip())
0048 if dir is None:
0049 path = os.path.join(self.programDir, name)
0050 else:
0051 os.makedirs(os.path.join(self.programDir, dir))
0052 path = os.path.join(self.programDir, dir, name)
0053 with open(path, "w") as f:
0054 f.write(content)
0055 return path
0056
0057 def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None):
0058 """
0059 Create a zip archive containing a file with the given content and return its path.
0060 Strips leading spaces from content up to the first '|' in each line.
0061 """
0062 pattern = re.compile(r'^ *\|', re.MULTILINE)
0063 content = re.sub(pattern, '', content.strip())
0064 if dir is None:
0065 path = os.path.join(self.programDir, name + ext)
0066 else:
0067 path = os.path.join(self.programDir, dir, zip_name + ext)
0068 zip = zipfile.ZipFile(path, 'w')
0069 zip.writestr(name, content)
0070 zip.close()
0071 return path
0072
0073 def create_spark_package(self, artifact_name):
0074 group_id, artifact_id, version = artifact_name.split(":")
0075 self.createTempFile("%s-%s.pom" % (artifact_id, version), ("""
0076 |<?xml version="1.0" encoding="UTF-8"?>
0077 |<project xmlns="http://maven.apache.org/POM/4.0.0"
0078 | xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
0079 | xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
0080 | http://maven.apache.org/xsd/maven-4.0.0.xsd">
0081 | <modelVersion>4.0.0</modelVersion>
0082 | <groupId>%s</groupId>
0083 | <artifactId>%s</artifactId>
0084 | <version>%s</version>
0085 |</project>
0086 """ % (group_id, artifact_id, version)).lstrip(),
0087 os.path.join(group_id, artifact_id, version))
0088 self.createFileInZip("%s.py" % artifact_id, """
0089 |def myfunc(x):
0090 | return x + 1
0091 """, ".jar", os.path.join(group_id, artifact_id, version),
0092 "%s-%s" % (artifact_id, version))
0093
0094 def test_single_script(self):
0095 """Submit and test a single script file"""
0096 script = self.createTempFile("test.py", """
0097 |from pyspark import SparkContext
0098 |
0099 |sc = SparkContext()
0100 |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect())
0101 """)
0102 proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE)
0103 out, err = proc.communicate()
0104 self.assertEqual(0, proc.returncode)
0105 self.assertIn("[2, 4, 6]", out.decode('utf-8'))
0106
0107 def test_script_with_local_functions(self):
0108 """Submit and test a single script file calling a global function"""
0109 script = self.createTempFile("test.py", """
0110 |from pyspark import SparkContext
0111 |
0112 |def foo(x):
0113 | return x * 3
0114 |
0115 |sc = SparkContext()
0116 |print(sc.parallelize([1, 2, 3]).map(foo).collect())
0117 """)
0118 proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE)
0119 out, err = proc.communicate()
0120 self.assertEqual(0, proc.returncode)
0121 self.assertIn("[3, 6, 9]", out.decode('utf-8'))
0122
0123 def test_module_dependency(self):
0124 """Submit and test a script with a dependency on another module"""
0125 script = self.createTempFile("test.py", """
0126 |from pyspark import SparkContext
0127 |from mylib import myfunc
0128 |
0129 |sc = SparkContext()
0130 |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
0131 """)
0132 zip = self.createFileInZip("mylib.py", """
0133 |def myfunc(x):
0134 | return x + 1
0135 """)
0136 proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script],
0137 stdout=subprocess.PIPE)
0138 out, err = proc.communicate()
0139 self.assertEqual(0, proc.returncode)
0140 self.assertIn("[2, 3, 4]", out.decode('utf-8'))
0141
0142 def test_module_dependency_on_cluster(self):
0143 """Submit and test a script with a dependency on another module on a cluster"""
0144 script = self.createTempFile("test.py", """
0145 |from pyspark import SparkContext
0146 |from mylib import myfunc
0147 |
0148 |sc = SparkContext()
0149 |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
0150 """)
0151 zip = self.createFileInZip("mylib.py", """
0152 |def myfunc(x):
0153 | return x + 1
0154 """)
0155 proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master",
0156 "local-cluster[1,1,1024]", script],
0157 stdout=subprocess.PIPE)
0158 out, err = proc.communicate()
0159 self.assertEqual(0, proc.returncode)
0160 self.assertIn("[2, 3, 4]", out.decode('utf-8'))
0161
0162 def test_package_dependency(self):
0163 """Submit and test a script with a dependency on a Spark Package"""
0164 script = self.createTempFile("test.py", """
0165 |from pyspark import SparkContext
0166 |from mylib import myfunc
0167 |
0168 |sc = SparkContext()
0169 |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
0170 """)
0171 self.create_spark_package("a:mylib:0.1")
0172 proc = subprocess.Popen(
0173 self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories",
0174 "file:" + self.programDir, script],
0175 stdout=subprocess.PIPE)
0176 out, err = proc.communicate()
0177 self.assertEqual(0, proc.returncode)
0178 self.assertIn("[2, 3, 4]", out.decode('utf-8'))
0179
0180 def test_package_dependency_on_cluster(self):
0181 """Submit and test a script with a dependency on a Spark Package on a cluster"""
0182 script = self.createTempFile("test.py", """
0183 |from pyspark import SparkContext
0184 |from mylib import myfunc
0185 |
0186 |sc = SparkContext()
0187 |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
0188 """)
0189 self.create_spark_package("a:mylib:0.1")
0190 proc = subprocess.Popen(
0191 self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories",
0192 "file:" + self.programDir, "--master", "local-cluster[1,1,1024]",
0193 script],
0194 stdout=subprocess.PIPE)
0195 out, err = proc.communicate()
0196 self.assertEqual(0, proc.returncode)
0197 self.assertIn("[2, 3, 4]", out.decode('utf-8'))
0198
0199 def test_single_script_on_cluster(self):
0200 """Submit and test a single script on a cluster"""
0201 script = self.createTempFile("test.py", """
0202 |from pyspark import SparkContext
0203 |
0204 |def foo(x):
0205 | return x * 2
0206 |
0207 |sc = SparkContext()
0208 |print(sc.parallelize([1, 2, 3]).map(foo).collect())
0209 """)
0210
0211
0212 proc = subprocess.Popen(
0213 self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script],
0214 stdout=subprocess.PIPE)
0215 out, err = proc.communicate()
0216 self.assertEqual(0, proc.returncode)
0217 self.assertIn("[2, 4, 6]", out.decode('utf-8'))
0218
0219 def test_user_configuration(self):
0220 """Make sure user configuration is respected (SPARK-19307)"""
0221 script = self.createTempFile("test.py", """
0222 |from pyspark import SparkConf, SparkContext
0223 |
0224 |conf = SparkConf().set("spark.test_config", "1")
0225 |sc = SparkContext(conf = conf)
0226 |try:
0227 | if sc._conf.get("spark.test_config") != "1":
0228 | raise Exception("Cannot find spark.test_config in SparkContext's conf.")
0229 |finally:
0230 | sc.stop()
0231 """)
0232 proc = subprocess.Popen(
0233 self.sparkSubmit + ["--master", "local", script],
0234 stdout=subprocess.PIPE,
0235 stderr=subprocess.STDOUT)
0236 out, err = proc.communicate()
0237 self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out))
0238
0239
if __name__ == "__main__":
    # Re-import the test classes so unittest.main can discover them by name.
    from pyspark.tests.test_appsubmit import *

    runner = None
    try:
        # Emit XML reports (used by Spark's CI) when xmlrunner is installed;
        # otherwise fall back to the default text runner.
        import xmlrunner
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=runner, verbosity=2)