Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 from pyspark.sql.pandas.utils import require_minimum_pandas_version
0018 
0019 
def infer_eval_type(sig):
    """
    Work out which Pandas UDF evaluation type a function signature describes.

    Every parameter and the return value of ``sig`` (an
    :class:`inspect.Signature`) must carry a type hint. The hints are matched
    against the supported shapes and one of ``PandasUDFType.SCALAR``,
    ``PandasUDFType.SCALAR_ITER`` or ``PandasUDFType.GROUPED_AGG`` is returned.

    Raises ``ValueError`` when a parameter or the return value has no type
    hint, and ``NotImplementedError`` when the hints match no supported shape.
    """
    from pyspark.sql.pandas.functions import PandasUDFType

    require_minimum_pandas_version()

    import pandas as pd

    def is_pandas_type(hint):
        # Exactly pandas.Series or pandas.DataFrame.
        return hint == pd.Series or hint == pd.DataFrame

    def is_pandas_or_union(hint):
        # Series, DataFrame, or a Union built only from those two.
        return (
            hint == pd.Series or
            hint == pd.DataFrame or
            check_union_annotation(hint, parameter_check_func=is_pandas_type))

    def is_tuple_member(hint):
        # Element allowed inside Tuple[...]: `...` or anything is_pandas_or_union accepts.
        return hint == Ellipsis or is_pandas_or_union(hint)

    def is_tuple_of_pandas(hint):
        # Tuple[...] whose every element satisfies is_tuple_member.
        return check_tuple_annotation(hint, parameter_check_func=is_tuple_member)

    # Hints of the parameters that actually have one, in declaration order.
    parameters_sig = [
        param.annotation for param in sig.parameters.values()
        if param.annotation is not param.empty]

    # Every parameter must be annotated.
    if len(parameters_sig) != len(sig.parameters):
        raise ValueError(
            "Type hints for all parameters should be specified; however, got %s" % sig)

    # The return value must be annotated as well.
    return_annotation = sig.return_annotation
    if sig.empty is return_annotation:
        raise ValueError(
            "Type hint for the return type should be specified; however, got %s" % sig)

    params_are_pandas = all(is_pandas_or_union(hint) for hint in parameters_sig)
    single_parameter = len(parameters_sig) == 1

    # Series, Frame or Union[DataFrame, Series], ... -> Series or Frame
    is_series_or_frame = params_are_pandas and is_pandas_type(return_annotation)

    # Iterator[Tuple[Series, Frame or Union[DataFrame, Series], ...]] -> Iterator[Series or Frame]
    is_iterator_tuple_series_or_frame = (
        single_parameter and
        check_iterator_annotation(
            parameters_sig[0], parameter_check_func=is_tuple_of_pandas) and
        check_iterator_annotation(
            return_annotation, parameter_check_func=is_pandas_type))

    # Iterator[Series, Frame or Union[DataFrame, Series]] -> Iterator[Series or Frame]
    is_iterator_series_or_frame = (
        single_parameter and
        check_iterator_annotation(
            parameters_sig[0], parameter_check_func=is_pandas_or_union) and
        check_iterator_annotation(
            return_annotation, parameter_check_func=is_pandas_type))

    # Series, Frame or Union[DataFrame, Series], ... -> Any
    # Enumerating every type the pd.Series constructor accepts is impractical, so
    # the return hint is only checked against a few common types that are
    # rejected here (Series, DataFrame, Iterator, Tuple).
    is_series_or_frame_agg = (
        params_are_pandas and
        return_annotation != pd.Series and
        return_annotation != pd.DataFrame and
        not check_iterator_annotation(return_annotation) and
        not check_tuple_annotation(return_annotation))

    if is_series_or_frame:
        return PandasUDFType.SCALAR
    if is_iterator_tuple_series_or_frame or is_iterator_series_or_frame:
        return PandasUDFType.SCALAR_ITER
    if is_series_or_frame_agg:
        return PandasUDFType.GROUPED_AGG
    raise NotImplementedError("Unsupported signature: %s." % sig)
0118 
0119 
def check_tuple_annotation(annotation, parameter_check_func=None):
    """
    Check whether ``annotation`` is a ``typing.Tuple`` type hint.

    :param annotation: the type hint to inspect.
    :param parameter_check_func: optional predicate applied to each entry of
        ``annotation.__args__``; when given, every entry must satisfy it.
    :return: True if the hint is named ``Tuple`` and, when a predicate is
        given, all of its type arguments pass it; False otherwise. A hint named
        ``Tuple`` that exposes no ``__args__`` fails any predicate.
    """
    # Python 3.6 has `__name__`. Python 3.7 and 3.8 have `_name`.
    # Check if the name is Tuple first. After that, check the generic types.
    name = getattr(annotation, "_name", getattr(annotation, "__name__", None))
    if name != "Tuple":
        return False
    if parameter_check_func is None:
        return True

    # A bare, unsubscripted ``typing.Tuple`` exposes no ``__args__`` on newer
    # Python versions; treat it as "not matching" instead of letting an
    # AttributeError escape to the caller.
    args = getattr(annotation, "__args__", None)
    if args is None:
        return False
    return all(map(parameter_check_func, args))
0126 
0127 
def check_iterator_annotation(annotation, parameter_check_func=None):
    """
    Check whether ``annotation`` is a ``typing.Iterator`` type hint.

    :param annotation: the type hint to inspect.
    :param parameter_check_func: optional predicate applied to each entry of
        ``annotation.__args__``; when given, every entry must satisfy it.
    :return: True if the hint is named ``Iterator`` and, when a predicate is
        given, all of its type arguments pass it; False otherwise. A hint named
        ``Iterator`` that exposes no ``__args__`` fails any predicate.
    """
    # The hint's name lives in `_name` or `__name__` depending on the Python version.
    name = getattr(annotation, "_name", getattr(annotation, "__name__", None))
    if name != "Iterator":
        return False
    if parameter_check_func is None:
        return True

    # A bare, unsubscripted ``typing.Iterator`` exposes no ``__args__`` on newer
    # Python versions; treat it as "not matching" instead of letting an
    # AttributeError escape to the caller.
    args = getattr(annotation, "__args__", None)
    if args is None:
        return False
    return all(map(parameter_check_func, args))
0132 
0133 
def check_union_annotation(annotation, parameter_check_func=None):
    """
    Check whether ``annotation`` is a union type hint.

    Both ``typing.Union[X, Y]`` and, where the running Python supports it, the
    PEP 604 spelling ``X | Y`` are recognized.

    :param annotation: the type hint to inspect.
    :param parameter_check_func: optional predicate applied to each entry of
        ``annotation.__args__``; when given, every entry must satisfy it.
    :return: True if the hint is a union and, when a predicate is given, all of
        its member types pass it; False otherwise.
    """
    import types
    import typing

    # Note that we cannot rely on '__origin__' in other type hints as it has changed from version
    # to version. For example, it's abc.Iterator in Python 3.7 but typing.Iterator in Python 3.6.
    origin = getattr(annotation, "__origin__", None)

    # ``X | Y`` produces a ``types.UnionType`` instance, which does not expose
    # ``__origin__`` on every Python version, so detect it by type as well.
    # ``types.UnionType`` is absent on interpreters that predate PEP 604.
    pep604_union = getattr(types, "UnionType", None)
    is_union = origin == typing.Union or (
        pep604_union is not None and isinstance(annotation, pep604_union))

    if not is_union:
        return False
    if parameter_check_func is None:
        return True
    return all(map(parameter_check_func, annotation.__args__))