Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 import datetime
0019 import os
0020 import threading
0021 import time
0022 import unittest
0023 import warnings
0024 
0025 from pyspark import SparkContext, SparkConf
0026 from pyspark.sql import Row, SparkSession
0027 from pyspark.sql.functions import udf
0028 from pyspark.sql.types import *
0029 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
0030     pandas_requirement_message, pyarrow_requirement_message
0031 from pyspark.testing.utils import QuietTest
0032 from pyspark.util import _exception_message
0033 
0034 if have_pandas:
0035     import pandas as pd
0036     from pandas.util.testing import assert_frame_equal
0037 
0038 if have_pyarrow:
0039     import pyarrow as pa
0040 
0041 
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class ArrowTests(ReusedSQLTestCase):
    """Checks Arrow-optimized conversions between Spark and pandas DataFrames.

    ``setUpClass`` enables Arrow and disables fallback for the whole class, so a
    failure in the Arrow path surfaces as an error instead of silently falling back.
    Many tests compare the Arrow result with the non-Arrow result by way of
    ``_toPandas_arrow_toggle`` / ``_createDataFrame_toggle``.
    """

    @classmethod
    def setUpClass(cls):
        """Pin Python and Spark to one timezone, enable Arrow, and build shared test data."""
        from datetime import date, datetime
        from decimal import Decimal
        super(ArrowTests, cls).setUpClass()
        # ``warnings.catch_warnings`` swaps process-global warning state, so tests that
        # record warnings serialize on this lock.
        cls.warnings_lock = threading.Lock()

        # Synchronize default timezone between Python and Java
        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
        tz = "America/Los_Angeles"
        os.environ["TZ"] = tz
        time.tzset()

        cls.spark.conf.set("spark.sql.session.timeZone", tz)

        # Test fallback: setting the legacy Arrow keys must be reflected by the
        # corresponding "pyspark"-prefixed keys.
        cls.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
        assert cls.spark.conf.get("spark.sql.execution.arrow.pyspark.enabled") == "false"
        cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
        assert cls.spark.conf.get("spark.sql.execution.arrow.pyspark.enabled") == "true"

        cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "true")
        assert cls.spark.conf.get("spark.sql.execution.arrow.pyspark.fallback.enabled") == "true"
        cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
        assert cls.spark.conf.get("spark.sql.execution.arrow.pyspark.fallback.enabled") == "false"

        # Enable Arrow optimization in these tests.
        cls.spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
        # Disable fallback by default to easily detect the failures.
        cls.spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")

        # One nullable column per data type exercised by the conversion tests.
        cls.schema = StructType([
            StructField("1_str_t", StringType(), True),
            StructField("2_int_t", IntegerType(), True),
            StructField("3_long_t", LongType(), True),
            StructField("4_float_t", FloatType(), True),
            StructField("5_double_t", DoubleType(), True),
            StructField("6_decimal_t", DecimalType(38, 18), True),
            StructField("7_date_t", DateType(), True),
            StructField("8_timestamp_t", TimestampType(), True),
            StructField("9_binary_t", BinaryType(), True)])
        cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
                     date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1), bytearray(b"a")),
                    (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
                     date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2), bytearray(b"bb")),
                    (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
                     date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3), bytearray(b"ccc")),
                    (u"d", 4, 40, 1.0, 8.0, Decimal("8.0"),
                     date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3), bytearray(b"dddd"))]

    @classmethod
    def tearDownClass(cls):
        """Restore the TZ value saved in ``setUpClass``, then run the base-class teardown."""
        del os.environ["TZ"]
        if cls.tz_prev is not None:
            os.environ["TZ"] = cls.tz_prev
        time.tzset()
        super(ArrowTests, cls).tearDownClass()

    def create_pandas_data_frame(self):
        """Build the pandas equivalent of ``self.data`` with ``self.schema`` column names.

        The int and float columns are converted to ``np.int32`` / ``np.float32`` so their
        dtypes line up with Spark's IntegerType / FloatType.
        """
        import numpy as np
        data_dict = {}
        for j, name in enumerate(self.schema.names):
            data_dict[name] = [row[j] for row in self.data]
        # need to convert these to numpy types first
        data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
        data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
        return pd.DataFrame(data=data_dict)

    def test_toPandas_fallback_enabled(self):
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
            schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
            df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
            with QuietTest(self.sc):
                with self.warnings_lock:
                    with warnings.catch_warnings(record=True) as warns:
                        # we want the warnings to appear even if this test is run from a subclass
                        warnings.simplefilter("always")
                        pdf = df.toPandas()
                        # Catch and check the last UserWarning.
                        user_warns = [
                            warn.message for warn in warns if isinstance(warn.message, UserWarning)]
                        self.assertGreater(len(user_warns), 0)
                        self.assertIn(
                            "Attempting non-optimization", _exception_message(user_warns[-1]))
                        assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))

    def test_toPandas_fallback_disabled(self):
        schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
        df = self.spark.createDataFrame([(None,)], schema=schema)
        with QuietTest(self.sc):
            with self.warnings_lock:
                with self.assertRaisesRegexp(Exception, 'Unsupported type'):
                    df.toPandas()

    def test_null_conversion(self):
        df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
                                             self.data)
        pdf = df_null.toPandas()
        null_counts = pdf.isnull().sum().tolist()
        # Exactly the one all-None row should be null in every column.
        self.assertEqual(null_counts, [1] * len(null_counts))

    def _toPandas_arrow_toggle(self, df):
        """Return ``(pdf_without_arrow, pdf_with_session_setting)`` for ``df``.

        The first frame is produced with Arrow explicitly disabled. The second uses the
        current session setting, which ``setUpClass`` sets to Arrow-enabled.
        """
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
            pdf = df.toPandas()

        pdf_arrow = df.toPandas()

        return pdf, pdf_arrow

    def test_toPandas_arrow_toggle(self):
        df = self.spark.createDataFrame(self.data, schema=self.schema)
        pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
        expected = self.create_pandas_data_frame()
        assert_frame_equal(expected, pdf)
        assert_frame_equal(expected, pdf_arrow)

    def test_toPandas_respect_session_timezone(self):
        df = self.spark.createDataFrame(self.data, schema=self.schema)

        timezone = "America/Los_Angeles"
        with self.sql_conf({"spark.sql.session.timeZone": timezone}):
            pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
            assert_frame_equal(pdf_arrow_la, pdf_la)

        timezone = "America/New_York"
        with self.sql_conf({"spark.sql.session.timeZone": timezone}):
            pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
            assert_frame_equal(pdf_arrow_ny, pdf_ny)

            self.assertFalse(pdf_ny.equals(pdf_la))

            from pyspark.sql.pandas.types import _check_series_convert_timestamps_local_tz
            pdf_la_corrected = pdf_la.copy()
            for field in self.schema:
                if isinstance(field.dataType, TimestampType):
                    pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
                        pdf_la_corrected[field.name], timezone)
            assert_frame_equal(pdf_ny, pdf_la_corrected)

    def test_pandas_round_trip(self):
        pdf = self.create_pandas_data_frame()
        df = self.spark.createDataFrame(self.data, schema=self.schema)
        pdf_arrow = df.toPandas()
        assert_frame_equal(pdf_arrow, pdf)

    def test_filtered_frame(self):
        df = self.spark.range(3).toDF("i")
        pdf = df.filter("i < 0").toPandas()
        self.assertEqual(len(pdf.columns), 1)
        self.assertEqual(pdf.columns[0], "i")
        self.assertTrue(pdf.empty)

    def test_no_partition_frame(self):
        schema = StructType([StructField("field1", StringType(), True)])
        df = self.spark.createDataFrame(self.sc.emptyRDD(), schema)
        pdf = df.toPandas()
        self.assertEqual(len(pdf.columns), 1)
        self.assertEqual(pdf.columns[0], "field1")
        self.assertTrue(pdf.empty)

    def test_propagates_spark_exception(self):
        df = self.spark.range(3).toDF("i")

        def raise_exception():
            raise Exception("My error")
        exception_udf = udf(raise_exception, IntegerType())
        df = df.withColumn("error", exception_udf())
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(Exception, 'My error'):
                df.toPandas()

    def _createDataFrame_toggle(self, pdf, schema=None):
        """Return ``(df_without_arrow, df_with_session_setting)`` built from ``pdf``.

        The first DataFrame is created with Arrow explicitly disabled. The second uses
        the current session setting, which ``setUpClass`` sets to Arrow-enabled.
        """
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
            df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)

        df_arrow = self.spark.createDataFrame(pdf, schema=schema)

        return df_no_arrow, df_arrow

    def test_createDataFrame_toggle(self):
        pdf = self.create_pandas_data_frame()
        df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema)
        self.assertEqual(df_no_arrow.collect(), df_arrow.collect())

    def test_createDataFrame_respect_session_timezone(self):
        from datetime import timedelta
        pdf = self.create_pandas_data_frame()
        timezone = "America/Los_Angeles"
        with self.sql_conf({"spark.sql.session.timeZone": timezone}):
            df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
            result_la = df_no_arrow_la.collect()
            result_arrow_la = df_arrow_la.collect()
            self.assertEqual(result_la, result_arrow_la)

        timezone = "America/New_York"
        with self.sql_conf({"spark.sql.session.timeZone": timezone}):
            df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema)
            result_ny = df_no_arrow_ny.collect()
            result_arrow_ny = df_arrow_ny.collect()
            self.assertEqual(result_ny, result_arrow_ny)

            self.assertNotEqual(result_ny, result_la)

            # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
            result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v
                                          for k, v in row.asDict().items()})
                                   for row in result_la]
            self.assertEqual(result_ny, result_la_corrected)

    def test_createDataFrame_with_schema(self):
        pdf = self.create_pandas_data_frame()
        df = self.spark.createDataFrame(pdf, schema=self.schema)
        self.assertEqual(self.schema, df.schema)
        pdf_arrow = df.toPandas()
        assert_frame_equal(pdf_arrow, pdf)

    def test_createDataFrame_with_incorrect_schema(self):
        pdf = self.create_pandas_data_frame()
        fields = list(self.schema)
        fields[0], fields[1] = fields[1], fields[0]  # swap str with int
        wrong_schema = StructType(fields)
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(Exception, "integer.*required"):
                self.spark.createDataFrame(pdf, schema=wrong_schema)

    def test_createDataFrame_with_names(self):
        pdf = self.create_pandas_data_frame()
        new_names = list(map(str, range(len(self.schema.fieldNames()))))
        # Test that schema as a list of column names gets applied
        df = self.spark.createDataFrame(pdf, schema=list(new_names))
        self.assertEqual(df.schema.fieldNames(), new_names)
        # Test that schema as tuple of column names gets applied
        df = self.spark.createDataFrame(pdf, schema=tuple(new_names))
        self.assertEqual(df.schema.fieldNames(), new_names)

    def test_createDataFrame_column_name_encoding(self):
        pdf = pd.DataFrame({u'a': [1]})
        columns = self.spark.createDataFrame(pdf).columns
        self.assertIsInstance(columns[0], str)
        self.assertEqual(columns[0], 'a')
        columns = self.spark.createDataFrame(pdf, [u'b']).columns
        self.assertIsInstance(columns[0], str)
        self.assertEqual(columns[0], 'b')

    def test_createDataFrame_with_single_data_type(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"):
                self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")

    def test_createDataFrame_does_not_modify_input(self):
        # Some series get converted for Spark to consume, this makes sure input is unchanged
        pdf = self.create_pandas_data_frame()
        # Use a nanosecond value to make sure it is not truncated
        pdf.iloc[0, 7] = pd.Timestamp(1)
        # Integers with nulls will get NaNs filled with 0 and will be casted
        pdf.iloc[1, 1] = None
        pdf_copy = pdf.copy(deep=True)
        self.spark.createDataFrame(pdf, schema=self.schema)
        self.assertTrue(pdf.equals(pdf_copy))

    def test_schema_conversion_roundtrip(self):
        from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
        arrow_schema = to_arrow_schema(self.schema)
        schema_rt = from_arrow_schema(arrow_schema)
        self.assertEqual(self.schema, schema_rt)

    def test_createDataFrame_with_array_type(self):
        pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
        df, df_arrow = self._createDataFrame_toggle(pdf)
        result = df.collect()
        result_arrow = df_arrow.collect()
        expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
        for r in range(len(expected)):
            for e in range(len(expected[r])):
                # Separate assertions so a failure reports the mismatching values.
                self.assertEqual(expected[r][e], result_arrow[r][e])
                self.assertEqual(result[r][e], result_arrow[r][e])

    def test_toPandas_with_array_type(self):
        expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])]
        array_schema = StructType([StructField("a", ArrayType(IntegerType())),
                                   StructField("b", ArrayType(StringType()))])
        df = self.spark.createDataFrame(expected, schema=array_schema)
        pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
        result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
        result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)]
        for r in range(len(expected)):
            for e in range(len(expected[r])):
                # Separate assertions so a failure reports the mismatching values.
                self.assertEqual(expected[r][e], result_arrow[r][e])
                self.assertEqual(result[r][e], result_arrow[r][e])

    def test_createDataFrame_with_int_col_names(self):
        import numpy as np
        pdf = pd.DataFrame(np.random.rand(4, 2))
        df, df_arrow = self._createDataFrame_toggle(pdf)
        pdf_col_names = [str(c) for c in pdf.columns]
        self.assertEqual(pdf_col_names, df.columns)
        self.assertEqual(pdf_col_names, df_arrow.columns)

    def test_createDataFrame_fallback_enabled(self):
        with QuietTest(self.sc):
            with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
                # Same lock as test_toPandas_fallback_enabled: catch_warnings is not thread-safe.
                with self.warnings_lock:
                    with warnings.catch_warnings(record=True) as warns:
                        # we want the warnings to appear even if this test is run from a subclass
                        warnings.simplefilter("always")
                        df = self.spark.createDataFrame(
                            pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
                        # Catch and check the last UserWarning.
                        user_warns = [
                            warn.message for warn in warns if isinstance(warn.message, UserWarning)]
                        self.assertGreater(len(user_warns), 0)
                        self.assertIn(
                            "Attempting non-optimization", _exception_message(user_warns[-1]))
                        self.assertEqual(df.collect(), [Row(a={u'a': 1})])

    def test_createDataFrame_fallback_disabled(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
                self.spark.createDataFrame(
                    pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")

    # Regression test for SPARK-23314
    def test_timestamp_dst(self):
        # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
        dt = [datetime.datetime(2015, 11, 1, 0, 30),
              datetime.datetime(2015, 11, 1, 1, 30),
              datetime.datetime(2015, 11, 1, 2, 30)]
        pdf = pd.DataFrame({'time': dt})

        df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
        df_from_pandas = self.spark.createDataFrame(pdf)

        assert_frame_equal(pdf, df_from_python.toPandas())
        assert_frame_equal(pdf, df_from_pandas.toPandas())

    # Regression test for SPARK-28003
    def test_timestamp_nat(self):
        dt = [pd.NaT, pd.Timestamp('2019-06-11'), None] * 100
        pdf = pd.DataFrame({'time': dt})
        df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf)

        assert_frame_equal(pdf, df_no_arrow.toPandas())
        assert_frame_equal(pdf, df_arrow.toPandas())

    def test_toPandas_batch_order(self):

        def delay_first_part(partition_index, iterator):
            if partition_index == 0:
                time.sleep(0.1)
            return iterator

        # Collects Arrow RecordBatches out of order in driver JVM then re-orders in Python
        def run_test(num_records, num_parts, max_records, use_delay=False):
            df = self.spark.range(num_records, numPartitions=num_parts).toDF("a")
            if use_delay:
                df = df.rdd.mapPartitionsWithIndex(delay_first_part).toDF()
            with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}):
                pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
                assert_frame_equal(pdf, pdf_arrow)

        cases = [
            (1024, 512, 2),    # Use large num partitions for more likely collecting out of order
            (64, 8, 2, True),  # Use delay in first partition to force collecting out of order
            (64, 64, 1),       # Test single batch per partition
            (64, 1, 64),       # Test single partition, single batch
            (64, 1, 8),        # Test single partition, multiple batches
            (30, 7, 2),        # Test different sized partitions
        ]

        for case in cases:
            run_test(*case)
0417 
0418 
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class MaxResultArrowTests(unittest.TestCase):
    """Arrow ``toPandas`` tests that need a very small ``spark.driver.maxResultSize``.

    ``spark.driver.maxResultSize`` is a static configuration of the SparkContext, so
    these tests cannot reuse a shared session. The class creates its own local
    context in ``setUpClass`` and stops it in ``tearDownClass``.
    """

    @classmethod
    def setUpClass(cls):
        limited_conf = SparkConf().set("spark.driver.maxResultSize", "10k")
        context = SparkContext('local[4]', cls.__name__, conf=limited_conf)
        cls.spark = SparkSession(context)

        # Arrow on, fallback off: an oversized result must fail instead of silently
        # taking the non-Arrow path.
        arrow_settings = (
            ("spark.sql.execution.arrow.pyspark.enabled", "true"),
            ("spark.sql.execution.arrow.pyspark.fallback.enabled", "false"),
        )
        for key, value in arrow_settings:
            cls.spark.conf.set(key, value)

    @classmethod
    def tearDownClass(cls):
        # setUpClass may have failed before the session was assigned.
        if not hasattr(cls, "spark"):
            return
        cls.spark.stop()

    def test_exception_by_max_results(self):
        # The collected result is expected to exceed the 10k cap configured above.
        with self.assertRaisesRegexp(Exception, "is bigger than"):
            self.spark.range(0, 10000, 1, 100).toPandas()
0443 
0444 
class EncryptionArrowTests(ArrowTests):
    """Re-runs every inherited ArrowTests case with ``spark.io.encryption.enabled`` set."""

    @classmethod
    def conf(cls):
        # Start from the parent's configuration and only switch on I/O encryption.
        # NOTE(review): presumably consumed by the base test case when it builds the
        # SparkContext; confirm in ReusedSQLTestCase.
        base_conf = super(EncryptionArrowTests, cls).conf()
        return base_conf.set("spark.io.encryption.enabled", "true")
0450 
0451 
if __name__ == "__main__":
    from pyspark.sql.tests.test_arrow import *

    # Write JUnit-style XML reports when xmlrunner is available; otherwise leave the
    # runner as None so unittest.main uses its default text runner.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass  # optional dependency; fall back to the default runner
    unittest.main(testRunner=testRunner, verbosity=2)