Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 import sys
0018 
0019 from pyspark import since
0020 from pyspark.rdd import PythonEvalType
0021 
0022 
class PandasMapOpsMixin(object):
    """
    Mix-in for pandas map operations. Currently, only :class:`DataFrame`
    can use this class.
    """

    @since(3.0)
    def mapInPandas(self, func, schema):
        """
        Maps an iterator of batches in the current :class:`DataFrame` using a Python native
        function that takes and outputs a pandas DataFrame, and returns the result as a
        :class:`DataFrame`.

        The function should take an iterator of `pandas.DataFrame`\\s and return
        another iterator of `pandas.DataFrame`\\s. All columns are passed
        together as an iterator of `pandas.DataFrame`\\s to the function, and the
        `pandas.DataFrame`\\s yielded by the returned iterator are combined into a
        :class:`DataFrame`. The size of each `pandas.DataFrame` can be controlled by
        `spark.sql.execution.arrow.maxRecordsPerBatch`.

        :param func: a Python native function that takes an iterator of `pandas.DataFrame`\\s, and
            outputs an iterator of `pandas.DataFrame`\\s.
        :param schema: the return type of the `func` in PySpark. The value can be either a
            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
        :return: a new :class:`DataFrame` built from the batches produced by `func`.

        >>> from pyspark.sql.functions import pandas_udf
        >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
        >>> def filter_func(iterator):
        ...     for pdf in iterator:
        ...         yield pdf[pdf.id == 1]
        >>> df.mapInPandas(filter_func, df.schema).show()  # doctest: +SKIP
        +---+---+
        | id|age|
        +---+---+
        |  1| 21|
        +---+---+

        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`

        .. note:: Experimental
        """
        # Imported inside the method rather than at module level.
        # NOTE(review): presumably to avoid a circular import with pyspark.sql; confirm.
        from pyspark.sql import DataFrame
        from pyspark.sql.pandas.functions import pandas_udf

        # This mixin is only meant to be mixed into DataFrame (see class docstring).
        assert isinstance(self, DataFrame)

        # Wrap ``func`` as a map-iterator pandas UDF whose output type is ``schema``.
        udf = pandas_udf(
            func, returnType=schema, functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
        # Apply the UDF to every column of this DataFrame, in column order.
        udf_column = udf(*[self[col] for col in self.columns])
        # Pass the underlying UDF expression to the JVM-side mapInPandas and wrap the result.
        jdf = self._jdf.mapInPandas(udf_column._jc.expr())
        return DataFrame(jdf, self.sql_ctx)
0074 
0075 
def _test():
    """Run this module's doctests against a local SparkSession; exit with -1 on failure."""
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.pandas.map_ops as map_ops_module

    # Doctests run against a shallow copy of the module namespace plus a ``spark`` session.
    test_globals = dict(map_ops_module.__dict__)
    session = (
        SparkSession.builder
        .master("local[4]")
        .appName("sql.pandas.map_ops tests")
        .getOrCreate()
    )
    test_globals['spark'] = session

    flags = doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF
    outcome = doctest.testmod(map_ops_module, globs=test_globals, optionflags=flags)
    session.stop()

    # testmod returns TestResults(failed, attempted); only the failure count matters here.
    if outcome.failed:
        sys.exit(-1)
0092 
0093 
if __name__ == "__main__":
    # Executing this file directly runs its doctests.
    _test()