#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
import os
import threading
import time
import unittest
import warnings

from pyspark import SparkContext, SparkConf
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import *
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
    pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
from pyspark.util import _exception_message

if have_pandas:
    import pandas as pd
    from pandas.testing import assert_frame_equal

if have_pyarrow:
    import pyarrow as pa


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class ArrowTests(ReusedSQLTestCase):

    @classmethod
    def setUpClass(cls):
        from datetime import date, datetime
        from decimal import Decimal
        super(ArrowTests, cls).setUpClass()
        cls.warnings_lock = threading.Lock()

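        # Synchronize the default timezone between Python and Java.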
        cls.tz_prev = os.environ.get("TZ", None)
        tz = "America/Los_Angeles"
        os.environ["TZ"] = tz
        time.tzset()

        cls.spark.conf.set("spark.sql.session.timeZone", tz)

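        # The legacy 'spark.sql.execution.arrow.*' confs should alias the new
        # 'spark.sql.execution.arrow.pyspark.*' confs; setting the old name
        # should be reflected when reading the new one.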
        cls.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
        assert cls.spark.conf.get("spark.sql.execution.arrow.pyspark.enabled") == "false"
        cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
        assert cls.spark.conf.get("spark.sql.execution.arrow.pyspark.enabled") == "true"

        cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "true")
        assert cls.spark.conf.get("spark.sql.execution.arrow.pyspark.fallback.enabled") == "true"
        cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
        assert cls.spark.conf.get("spark.sql.execution.arrow.pyspark.fallback.enabled") == "false"

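        # Enable Arrow optimization for all tests in this class.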
        cls.spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
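        # Disable fallback so unsupported-type failures surface immediately.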
        cls.spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")

        cls.schema = StructType([
            StructField("1_str_t", StringType(), True),
            StructField("2_int_t", IntegerType(), True),
            StructField("3_long_t", LongType(), True),
            StructField("4_float_t", FloatType(), True),
            StructField("5_double_t", DoubleType(), True),
            StructField("6_decimal_t", DecimalType(38, 18), True),
            StructField("7_date_t", DateType(), True),
            StructField("8_timestamp_t", TimestampType(), True),
            StructField("9_binary_t", BinaryType(), True)])
        cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
                     date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1), bytearray(b"a")),
                    (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
                     date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2), bytearray(b"bb")),
                    (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
                     date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3), bytearray(b"ccc")),
                    (u"d", 4, 40, 1.0, 8.0, Decimal("8.0"),
                     date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3), bytearray(b"dddd"))]

    @classmethod
    def tearDownClass(cls):
        del os.environ["TZ"]
        if cls.tz_prev is not None:
            os.environ["TZ"] = cls.tz_prev
        time.tzset()
        super(ArrowTests, cls).tearDownClass()

    def create_pandas_data_frame(self):
        import numpy as np
        data_dict = {}
        for j, name in enumerate(self.schema.names):
            data_dict[name] = [self.data[i][j] for i in range(len(self.data))]

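        # Cast to the NumPy dtypes that match the Spark schema (int32, float32).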
        data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
        data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
        return pd.DataFrame(data=data_dict)

    def test_toPandas_fallback_enabled(self):
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
            schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
            df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
            with QuietTest(self.sc):
                with self.warnings_lock:
                    with warnings.catch_warnings(record=True) as warns:
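                        # Make warnings visible even when this test runs in a subclass.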
                        warnings.simplefilter("always")
                        pdf = df.toPandas()
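                        # Catch and check that the last UserWarning mentions the fallback.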
                        user_warns = [
                            warn.message for warn in warns if isinstance(warn.message, UserWarning)]
                        self.assertTrue(len(user_warns) > 0)
                        self.assertTrue(
                            "Attempting non-optimization" in _exception_message(user_warns[-1]))
                        assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))

    def test_toPandas_fallback_disabled(self):
        schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
        df = self.spark.createDataFrame([(None,)], schema=schema)
        with QuietTest(self.sc):
            with self.warnings_lock:
                with self.assertRaisesRegexp(Exception, 'Unsupported type'):
                    df.toPandas()

    def test_null_conversion(self):
        df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
                                             self.data)
        pdf = df_null.toPandas()
        null_counts = pdf.isnull().sum().tolist()
        self.assertTrue(all([c == 1 for c in null_counts]))

    def _toPandas_arrow_toggle(self, df):
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
            pdf = df.toPandas()

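        # Once the conf context exits, Arrow is enabled again (see setUpClass).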
        pdf_arrow = df.toPandas()

        return pdf, pdf_arrow

    def test_toPandas_arrow_toggle(self):
        df = self.spark.createDataFrame(self.data, schema=self.schema)
        pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
        expected = self.create_pandas_data_frame()
        assert_frame_equal(expected, pdf)
        assert_frame_equal(expected, pdf_arrow)

    def test_toPandas_respect_session_timezone(self):
        df = self.spark.createDataFrame(self.data, schema=self.schema)

        timezone = "America/Los_Angeles"
        with self.sql_conf({"spark.sql.session.timeZone": timezone}):
            pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
            assert_frame_equal(pdf_arrow_la, pdf_la)

        timezone = "America/New_York"
        with self.sql_conf({"spark.sql.session.timeZone": timezone}):
            pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
            assert_frame_equal(pdf_arrow_ny, pdf_ny)

        self.assertFalse(pdf_ny.equals(pdf_la))

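        # Converting the LA timestamps to the NY session timezone should reproduce pdf_ny.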
        from pyspark.sql.pandas.types import _check_series_convert_timestamps_local_tz
        pdf_la_corrected = pdf_la.copy()
        for field in self.schema:
            if isinstance(field.dataType, TimestampType):
                pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
                    pdf_la_corrected[field.name], timezone)
        assert_frame_equal(pdf_ny, pdf_la_corrected)

    def test_pandas_round_trip(self):
        pdf = self.create_pandas_data_frame()
        df = self.spark.createDataFrame(self.data, schema=self.schema)
        pdf_arrow = df.toPandas()
        assert_frame_equal(pdf_arrow, pdf)

    def test_filtered_frame(self):
        df = self.spark.range(3).toDF("i")
        pdf = df.filter("i < 0").toPandas()
        self.assertEqual(len(pdf.columns), 1)
        self.assertEqual(pdf.columns[0], "i")
        self.assertTrue(pdf.empty)

    def test_no_partition_frame(self):
        schema = StructType([StructField("field1", StringType(), True)])
        df = self.spark.createDataFrame(self.sc.emptyRDD(), schema)
        pdf = df.toPandas()
        self.assertEqual(len(pdf.columns), 1)
        self.assertEqual(pdf.columns[0], "field1")
        self.assertTrue(pdf.empty)

    def test_propagates_spark_exception(self):
        df = self.spark.range(3).toDF("i")

        def raise_exception():
            raise Exception("My error")
        exception_udf = udf(raise_exception, IntegerType())
        df = df.withColumn("error", exception_udf())
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(Exception, 'My error'):
                df.toPandas()

    def _createDataFrame_toggle(self, pdf, schema=None):
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
            df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)

        df_arrow = self.spark.createDataFrame(pdf, schema=schema)

        return df_no_arrow, df_arrow

    def test_createDataFrame_toggle(self):
        pdf = self.create_pandas_data_frame()
        df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema)
        self.assertEqual(df_no_arrow.collect(), df_arrow.collect())

    def test_createDataFrame_respect_session_timezone(self):
        from datetime import timedelta
        pdf = self.create_pandas_data_frame()
        timezone = "America/Los_Angeles"
        with self.sql_conf({"spark.sql.session.timeZone": timezone}):
            df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
            result_la = df_no_arrow_la.collect()
            result_arrow_la = df_arrow_la.collect()
            self.assertEqual(result_la, result_arrow_la)

        timezone = "America/New_York"
        with self.sql_conf({"spark.sql.session.timeZone": timezone}):
            df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema)
            result_ny = df_no_arrow_ny.collect()
            result_arrow_ny = df_arrow_ny.collect()
            self.assertEqual(result_ny, result_arrow_ny)

        self.assertNotEqual(result_ny, result_la)

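        # Correct result_la by adjusting for the 3-hour offset between Los Angeles and New York.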
        result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v
                                      for k, v in row.asDict().items()})
                               for row in result_la]
        self.assertEqual(result_ny, result_la_corrected)

    def test_createDataFrame_with_schema(self):
        pdf = self.create_pandas_data_frame()
        df = self.spark.createDataFrame(pdf, schema=self.schema)
        self.assertEqual(self.schema, df.schema)
        pdf_arrow = df.toPandas()
        assert_frame_equal(pdf_arrow, pdf)

    def test_createDataFrame_with_incorrect_schema(self):
        pdf = self.create_pandas_data_frame()
        fields = list(self.schema)
        fields[0], fields[1] = fields[1], fields[0]
        wrong_schema = StructType(fields)
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(Exception, "integer.*required"):
                self.spark.createDataFrame(pdf, schema=wrong_schema)

    def test_createDataFrame_with_names(self):
        pdf = self.create_pandas_data_frame()
        new_names = list(map(str, range(len(self.schema.fieldNames()))))

        df = self.spark.createDataFrame(pdf, schema=list(new_names))
        self.assertEqual(df.schema.fieldNames(), new_names)

        df = self.spark.createDataFrame(pdf, schema=tuple(new_names))
        self.assertEqual(df.schema.fieldNames(), new_names)

    def test_createDataFrame_column_name_encoding(self):
        pdf = pd.DataFrame({u'a': [1]})
        columns = self.spark.createDataFrame(pdf).columns
        self.assertTrue(isinstance(columns[0], str))
        self.assertEqual(columns[0], 'a')
        columns = self.spark.createDataFrame(pdf, [u'b']).columns
        self.assertTrue(isinstance(columns[0], str))
        self.assertEqual(columns[0], 'b')

    def test_createDataFrame_with_single_data_type(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"):
                self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")

    def test_createDataFrame_does_not_modify_input(self):
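        # Some series get converted for Spark to consume; make sure the input is unchanged.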
        pdf = self.create_pandas_data_frame()
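        # Use a nanosecond timestamp value to make sure it is not truncated.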
        pdf.iloc[0, 7] = pd.Timestamp(1)
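        # A null in an integer column becomes NaN, changing the dtype on the pandas side.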
        pdf.iloc[1, 1] = None
        pdf_copy = pdf.copy(deep=True)
        self.spark.createDataFrame(pdf, schema=self.schema)
        self.assertTrue(pdf.equals(pdf_copy))

    def test_schema_conversion_roundtrip(self):
        from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
        arrow_schema = to_arrow_schema(self.schema)
        schema_rt = from_arrow_schema(arrow_schema)
        self.assertEqual(self.schema, schema_rt)

    def test_createDataFrame_with_array_type(self):
        pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
        df, df_arrow = self._createDataFrame_toggle(pdf)
        result = df.collect()
        result_arrow = df_arrow.collect()
        expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
        for r in range(len(expected)):
            for e in range(len(expected[r])):
                self.assertTrue(expected[r][e] == result_arrow[r][e] and
                                result[r][e] == result_arrow[r][e])

    def test_toPandas_with_array_type(self):
        expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])]
        array_schema = StructType([StructField("a", ArrayType(IntegerType())),
                                   StructField("b", ArrayType(StringType()))])
        df = self.spark.createDataFrame(expected, schema=array_schema)
        pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
        result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
        result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)]
        for r in range(len(expected)):
            for e in range(len(expected[r])):
                self.assertTrue(expected[r][e] == result_arrow[r][e] and
                                result[r][e] == result_arrow[r][e])

    def test_createDataFrame_with_int_col_names(self):
        import numpy as np
        pdf = pd.DataFrame(np.random.rand(4, 2))
        df, df_arrow = self._createDataFrame_toggle(pdf)
        pdf_col_names = [str(c) for c in pdf.columns]
        self.assertEqual(pdf_col_names, df.columns)
        self.assertEqual(pdf_col_names, df_arrow.columns)

    def test_createDataFrame_fallback_enabled(self):
        with QuietTest(self.sc):
            with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
                with warnings.catch_warnings(record=True) as warns:
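                    # Make warnings visible even when this test runs in a subclass.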
                    warnings.simplefilter("always")
                    df = self.spark.createDataFrame(
                        pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
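                    # Catch and check that the last UserWarning mentions the fallback.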
                    user_warns = [
                        warn.message for warn in warns if isinstance(warn.message, UserWarning)]
                    self.assertTrue(len(user_warns) > 0)
                    self.assertTrue(
                        "Attempting non-optimization" in _exception_message(user_warns[-1]))
                    self.assertEqual(df.collect(), [Row(a={u'a': 1})])

    def test_createDataFrame_fallback_disabled(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
                self.spark.createDataFrame(
                    pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")

    def test_timestamp_dst(self):
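        # Daylight saving time in Los Angeles for 2015 ended Sun, Nov 1 at 2:00 am.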
        dt = [datetime.datetime(2015, 11, 1, 0, 30),
              datetime.datetime(2015, 11, 1, 1, 30),
              datetime.datetime(2015, 11, 1, 2, 30)]
        pdf = pd.DataFrame({'time': dt})

        df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
        df_from_pandas = self.spark.createDataFrame(pdf)

        assert_frame_equal(pdf, df_from_python.toPandas())
        assert_frame_equal(pdf, df_from_pandas.toPandas())

    def test_timestamp_nat(self):
        dt = [pd.NaT, pd.Timestamp('2019-06-11'), None] * 100
        pdf = pd.DataFrame({'time': dt})
        df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf)

        assert_frame_equal(pdf, df_no_arrow.toPandas())
        assert_frame_equal(pdf, df_arrow.toPandas())

    def test_toPandas_batch_order(self):

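        # Delaying the first partition makes its batches arrive after later partitions'.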
        def delay_first_part(partition_index, iterator):
            if partition_index == 0:
                time.sleep(0.1)
            return iterator

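        # Collects Arrow record batches out of order in the driver JVM,
        # then re-orders them in Python.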
        def run_test(num_records, num_parts, max_records, use_delay=False):
            df = self.spark.range(num_records, numPartitions=num_parts).toDF("a")
            if use_delay:
                df = df.rdd.mapPartitionsWithIndex(delay_first_part).toDF()
            with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}):
                pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
                assert_frame_equal(pdf, pdf_arrow)

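        # Each case is (num_records, num_parts, max_records[, use_delay]).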
        cases = [
            (1024, 512, 2),
            (64, 8, 2, True),
            (64, 64, 1),
            (64, 1, 64),
            (64, 1, 8),
            (30, 7, 2),
        ]

        for case in cases:
            run_test(*case)


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class MaxResultArrowTests(unittest.TestCase):
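    # These tests are kept separate because 'spark.driver.maxResultSize' is a
    # static configuration that must be set when the SparkContext is created.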

    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession(SparkContext(
            'local[4]', cls.__name__, conf=SparkConf().set("spark.driver.maxResultSize", "10k")))

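        # Explicitly enable Arrow and disable fallback.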
        cls.spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
        cls.spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")

    @classmethod
    def tearDownClass(cls):
        if hasattr(cls, "spark"):
            cls.spark.stop()

    def test_exception_by_max_results(self):
        with self.assertRaisesRegexp(Exception, "is bigger than"):
            self.spark.range(0, 10000, 1, 100).toPandas()


class EncryptionArrowTests(ArrowTests):
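    # Re-runs all ArrowTests with Spark I/O encryption enabled.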

    @classmethod
    def conf(cls):
        return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true")


if __name__ == "__main__":
    from pyspark.sql.tests.test_arrow import *

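    # Produce JUnit-style XML test reports when xmlrunner is available.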
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)