#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
import time
import unittest

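# Python 3 has no separate `unicode` type; alias it to `str` so Python 2-era
# code paths keep working.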
if sys.version >= '3':
    unicode = str

from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
    pandas_requirement_message, pyarrow_requirement_message

if have_pandas:
    import pandas as pd


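# mapInPandas requires both pandas and pyarrow; skip the whole suite when
# either is missing rather than failing with an import error.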
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class MapInPandasTests(ReusedSQLTestCase):

    @classmethod
    def setUpClass(cls):
        ReusedSQLTestCase.setUpClass()

        # Synchronize default timezone between Python and Java
        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
        tz = "America/Los_Angeles"
        os.environ["TZ"] = tz
        time.tzset()

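        # Propagate the timezone to the executor Python workers and to the
        # SQL session so timestamps are interpreted consistently on both sides.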
        cls.sc.environment["TZ"] = tz
        cls.spark.conf.set("spark.sql.session.timeZone", tz)

    @classmethod
    def tearDownClass(cls):
        del os.environ["TZ"]
        if cls.tz_prev is not None:
            os.environ["TZ"] = cls.tz_prev
        time.tzset()
        ReusedSQLTestCase.tearDownClass()

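    # mapInPandas hands the function an iterator of pandas DataFrames (one or
    # more per partition) and expects an iterator of pandas DataFrames back;
    # the schema string declares the output columns and types.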
    def test_map_partitions_in_pandas(self):
        def func(iterator):
            for pdf in iterator:
                assert isinstance(pdf, pd.DataFrame)
                assert pdf.columns == ['id']
                yield pdf

        df = self.spark.range(10)
        actual = df.mapInPandas(func, 'id long').collect()
        expected = df.collect()
        self.assertEqual(actual, expected)

    def test_multiple_columns(self):
        data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")]
        df = self.spark.createDataFrame(data, "a int, b string")

        def func(iterator):
            for pdf in iterator:
                assert isinstance(pdf, pd.DataFrame)
                assert [d.name for d in list(pdf.dtypes)] == ['int32', 'object']
                yield pdf

        actual = df.mapInPandas(func, df.schema).collect()
        expected = df.collect()
        self.assertEqual(actual, expected)

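    # The function may return a different number of rows than it received;
    # repartition(1) keeps everything in one partition so the 100-row output
    # is produced exactly once.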
    def test_different_output_length(self):
        def func(iterator):
            for _ in iterator:
                yield pd.DataFrame({'a': list(range(100))})

        df = self.spark.range(10)
        actual = df.repartition(1).mapInPandas(func, 'a long').collect()
        self.assertEqual({r.a for r in actual}, set(range(100)))

    def test_empty_iterator(self):
        def empty_iter(_):
            return iter([])

        self.assertEqual(
            self.spark.range(10).mapInPandas(empty_iter, 'a int, b string').count(), 0)

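    # Yielding empty DataFrames is also allowed; they simply contribute no rows.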
    def test_empty_rows(self):
        def empty_rows(_):
            return iter([pd.DataFrame({'a': []})])

        self.assertEqual(
            self.spark.range(10).mapInPandas(empty_rows, 'a int').count(), 0)

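    # mapInPandas returns an ordinary DataFrame, so calls compose end to end.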
    def test_chain_map_partitions_in_pandas(self):
        def func(iterator):
            for pdf in iterator:
                assert isinstance(pdf, pd.DataFrame)
                assert pdf.columns == ['id']
                yield pdf

        df = self.spark.range(10)
        actual = df.mapInPandas(func, 'id long').mapInPandas(func, 'id long').collect()
        expected = df.collect()
        self.assertEqual(actual, expected)


if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_map import *

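    # Use xmlrunner for JUnit-style XML reports when it is installed;
    # otherwise fall back to unittest's default text runner.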
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)