#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
import time
import unittest

if sys.version_info >= (3,):
    unicode = str

from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
    pandas_requirement_message, pyarrow_requirement_message

if have_pandas:
    import pandas as pd


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class MapInPandasTests(ReusedSQLTestCase):

    @classmethod
    def setUpClass(cls):
        ReusedSQLTestCase.setUpClass()

        # Synchronize default timezone between Python and Java.
        cls.tz_prev = os.environ.get("TZ", None)
        tz = "America/Los_Angeles"
        os.environ["TZ"] = tz
        time.tzset()

        cls.sc.environment["TZ"] = tz
        cls.spark.conf.set("spark.sql.session.timeZone", tz)

    @classmethod
    def tearDownClass(cls):
        del os.environ["TZ"]
        if cls.tz_prev is not None:
            os.environ["TZ"] = cls.tz_prev
        time.tzset()
        ReusedSQLTestCase.tearDownClass()

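    # mapInPandas passes each partition to the function as an iterator of
    # pandas DataFrames; an identity function should round-trip the data unchanged.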
    def test_map_partitions_in_pandas(self):
        def func(iterator):
            for pdf in iterator:
                assert isinstance(pdf, pd.DataFrame)
                assert list(pdf.columns) == ['id']
                yield pdf

        df = self.spark.range(10)
        actual = df.mapInPandas(func, 'id long').collect()
        expected = df.collect()
        self.assertEqual(actual, expected)

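    # Arrow converts the int and string columns to int32 and object dtypes on
    # the pandas side, which the function's assertions verify per batch.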
    def test_multiple_columns(self):
        data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")]
        df = self.spark.createDataFrame(data, "a int, b string")

        def func(iterator):
            for pdf in iterator:
                assert isinstance(pdf, pd.DataFrame)
                assert [d.name for d in list(pdf.dtypes)] == ['int32', 'object']
                yield pdf

        actual = df.mapInPandas(func, df.schema).collect()
        expected = df.collect()
        self.assertEqual(actual, expected)

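    # The function may yield a different number of rows than it receives.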
    def test_different_output_length(self):
        def func(iterator):
            for _ in iterator:
                yield pd.DataFrame({'a': list(range(100))})

        df = self.spark.range(10)
        actual = df.repartition(1).mapInPandas(func, 'a long').collect()
        self.assertEqual({r.a for r in actual}, set(range(100)))

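    # A function that yields nothing should produce an empty result.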
    def test_empty_iterator(self):
        def empty_iter(_):
            return iter([])

        self.assertEqual(
            self.spark.range(10).mapInPandas(empty_iter, 'a int, b string').count(), 0)

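    # Yielding empty DataFrames is allowed and contributes no rows.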
    def test_empty_rows(self):
        def empty_rows(_):
            return iter([pd.DataFrame({'a': []})])

        self.assertEqual(
            self.spark.range(10).mapInPandas(empty_rows, 'a int').count(), 0)

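    # mapInPandas calls can be chained; each stage receives the previous
    # stage's output as its input iterator.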
    def test_chain_map_partitions_in_pandas(self):
        def func(iterator):
            for pdf in iterator:
                assert isinstance(pdf, pd.DataFrame)
                assert list(pdf.columns) == ['id']
                yield pdf

        df = self.spark.range(10)
        actual = df.mapInPandas(func, 'id long').mapInPandas(func, 'id long').collect()
        expected = df.collect()
        self.assertEqual(actual, expected)


0121 if __name__ == "__main__":
0122 from pyspark.sql.tests.test_pandas_map import *
0123
0124 try:
0125 import xmlrunner
0126 testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
0127 except ImportError:
0128 testRunner = None
0129 unittest.main(testRunner=testRunner, verbosity=2)