#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import glob
import os
import struct
import sys
import unittest
from time import time, sleep

from pyspark import SparkContext, SparkConf

have_scipy = False
have_numpy = False
try:
    import scipy.sparse
    have_scipy = True
except ImportError:
    # No SciPy, but that's okay, we'll skip those tests
    pass
try:
    import numpy as np
    have_numpy = True
except ImportError:
    # No NumPy, but that's okay, we'll skip those tests
    pass
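
# Example (illustrative sketch only): tests that need NumPy or SciPy are typically
# guarded with unittest.skipIf on these flags.  The test class and method below
# are hypothetical, not part of this module.
#
#     @unittest.skipIf(not have_numpy, "NumPy is not installed")
#     class NumPyDependentTests(unittest.TestCase):
#         def test_dot(self):
#             self.assertEqual(np.dot([1, 2, 3], [1, 2, 3]), 14)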


SPARK_HOME = os.environ["SPARK_HOME"]


def read_int(b):
    return struct.unpack("!i", b)[0]


def write_int(i):
    return struct.pack("!i", i)
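
# Example (illustrative sketch): write_int/read_int round-trip a 4-byte
# big-endian signed integer.
#
#     assert write_int(2020) == b"\x00\x00\x07\xe4"
#     assert read_int(write_int(2020)) == 2020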


def eventually(condition, timeout=30.0, catch_assertions=False):
    """
    Wait a given amount of time for a condition to pass, else fail with an error.
    This is a helper utility for PySpark tests.

    :param condition: Function that checks for termination conditions.
                      condition() can return:
                       - True: Conditions met. Return without error.
                       - other value: Conditions not met yet. Continue. Upon timeout,
                                      include last such value in error message.
                      Note that this method may be called at any time during
                      streaming execution (e.g., even before any results
                      have been created).
    :param timeout: Number of seconds to wait.  Default 30 seconds.
    :param catch_assertions: If False (default), do not catch AssertionErrors.
                             If True, catch AssertionErrors; continue, but save
                             error to throw upon timeout.
    """
    start_time = time()
    lastValue = None
    while time() - start_time < timeout:
        if catch_assertions:
            try:
                lastValue = condition()
            except AssertionError as e:
                lastValue = e
        else:
            lastValue = condition()
        if lastValue is True:
            return
        sleep(0.01)
    if isinstance(lastValue, AssertionError):
        raise lastValue
    else:
        raise AssertionError(
            "Test failed due to timeout after %g sec, with last condition returning: %s"
            % (timeout, lastValue))
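
# Example (illustrative sketch): assert that a condition becomes true within the
# timeout, retrying every 10 ms.  The accumulator below is hypothetical and
# assumes `sc` is a live SparkContext.
#
#     acc = sc.accumulator(0)
#     sc.parallelize(range(10)).foreach(lambda x: acc.add(1))
#     eventually(lambda: acc.value == 10, timeout=30.0)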


class QuietTest(object):
    """
    Context manager that temporarily raises the root log4j level to FATAL so
    that expected errors do not clutter the test output.
    """
    def __init__(self, sc):
        self.log4j = sc._jvm.org.apache.log4j

    def __enter__(self):
        self.old_level = self.log4j.LogManager.getRootLogger().getLevel()
        self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.log4j.LogManager.getRootLogger().setLevel(self.old_level)
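
# Example (illustrative sketch): silence expected executor errors while asserting
# that an action fails, e.g. inside a test method of one of the cases below,
# assuming `sc` is a live SparkContext.
#
#     with QuietTest(sc):
#         with self.assertRaises(Exception):
#             sc.parallelize([1]).map(lambda x: 1 // 0).count()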


class PySparkTestCase(unittest.TestCase):
    """
    Test case that starts a fresh local SparkContext (as self.sc) for each test
    and restores sys.path afterwards.
    """

    def setUp(self):
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        self.sc = SparkContext('local[4]', class_name)

    def tearDown(self):
        self.sc.stop()
        sys.path = self._old_sys_path
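
# Example (illustrative sketch): a subclass gets a fresh local SparkContext as
# self.sc for every test method.  The test below is hypothetical.
#
#     class WordCountTests(PySparkTestCase):
#         def test_count(self):
#             self.assertEqual(self.sc.parallelize(["a", "b", "a"]).count(), 3)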


class ReusedPySparkTestCase(unittest.TestCase):
    """
    Test case that shares a single SparkContext (as cls.sc) across all tests
    in the class.
    """

    @classmethod
    def conf(cls):
        """
        Override this in subclasses to supply a more specific conf
        """
        return SparkConf()

    @classmethod
    def setUpClass(cls):
        cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf())

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()
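
# Example (illustrative sketch): subclasses share one SparkContext for the whole
# class and can override conf() to tweak it.  The class name and setting below
# are hypothetical.
#
#     class SharedContextTests(ReusedPySparkTestCase):
#         @classmethod
#         def conf(cls):
#             return SparkConf().set("spark.ui.enabled", "false")
#
#         def test_sum(self):
#             self.assertEqual(self.sc.parallelize([1, 2, 3]).sum(), 6)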


class ByteArrayOutput(object):
    """
    Minimal file-like object that accumulates everything written to it in an
    in-memory bytearray.
    """
    def __init__(self):
        self.buffer = bytearray()

    def write(self, b):
        self.buffer += b

    def close(self):
        pass
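
# Example (illustrative sketch): collect serialized bytes in memory instead of
# writing them to a real stream.
#
#     out = ByteArrayOutput()
#     out.write(write_int(7))
#     assert read_int(bytes(out.buffer)) == 7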


def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
    # 'sbt_jar_name_prefix' and 'mvn_jar_name_prefix' are separate parameters because the
    # artifact prefix can differ between the SBT and Maven builds. See also SPARK-26856.
    project_full_path = os.path.join(
        os.environ["SPARK_HOME"], project_relative_path)

    # Ignore auxiliary jars such as javadoc, sources and test jars.
    ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar")

    # Search for the jar under the project dir using the name prefix for both the SBT and
    # Maven builds, because the artifact jars end up in different directories.
    sbt_build = glob.glob(os.path.join(
        project_full_path, "target/scala-*/%s*.jar" % sbt_jar_name_prefix))
    maven_build = glob.glob(os.path.join(
        project_full_path, "target/%s*.jar" % mvn_jar_name_prefix))
    jar_paths = sbt_build + maven_build
    jars = [jar for jar in jar_paths if not jar.endswith(ignored_jar_suffixes)]

    if not jars:
        return None
    elif len(jars) > 1:
        raise Exception("Found multiple JARs: %s; please remove all but one" % (", ".join(jars)))
    else:
        return jars[0]
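
# Example (illustrative sketch): locate a built assembly jar before running tests
# that depend on it, and skip the tests when it has not been built.  The project
# path and prefixes below are hypothetical.
#
#     example_jar = search_jar(
#         "external/example-assembly",
#         "spark-example-assembly-",
#         "spark-example-assembly_")
#     if example_jar is None:
#         raise unittest.SkipTest("example assembly jar not found; skipping these tests")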