0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 from pyspark.sql.pandas.utils import require_minimum_pandas_version
0018
0019
def infer_eval_type(sig):
    """
    Infers the evaluation type in :class:`pyspark.rdd.PythonEvalType` from
    :class:`inspect.Signature` instance.

    Every parameter and the return value must carry a type hint. Writing "frame-like" for
    ``pandas.Series``, ``pandas.DataFrame`` or a ``Union`` made only of those two, the
    signature is classified as follows (first match wins):

    * all parameters frame-like, returning ``Series`` or ``DataFrame``:
      ``PandasUDFType.SCALAR``
    * a single ``Iterator[Tuple[...]]`` parameter whose tuple members are frame-like or
      ``...``, or a single ``Iterator[<frame-like>]`` parameter, returning
      ``Iterator[Series]`` or ``Iterator[DataFrame]``: ``PandasUDFType.SCALAR_ITER``
    * all parameters frame-like, returning anything other than ``Series``, ``DataFrame``,
      an ``Iterator`` hint or a ``Tuple`` hint: ``PandasUDFType.GROUPED_AGG``

    :param sig: signature of the user-defined function.
    :return: the matching ``PandasUDFType`` constant.
    :raises ValueError: if a parameter or the return value lacks a type hint.
    :raises NotImplementedError: if the signature matches none of the shapes above.
    """
    from pyspark.sql.pandas.functions import PandasUDFType

    require_minimum_pandas_version()

    import pandas as pd

    # Hints of the annotated parameters, in declaration order. A shorter list than the
    # parameter count means at least one parameter is unannotated.
    param_hints = [
        param.annotation
        for param in sig.parameters.values()
        if param.annotation is not param.empty
    ]
    if len(param_hints) != len(sig.parameters):
        raise ValueError(
            "Type hints for all parameters should be specified; however, got %s" % sig)

    ret_hint = sig.return_annotation
    if ret_hint is sig.empty:
        raise ValueError(
            "Type hint for the return type should be specified; however, got %s" % sig)

    def is_frame(hint):
        # A bare pandas Series or DataFrame class.
        return hint == pd.Series or hint == pd.DataFrame

    def is_frame_or_union(hint):
        # A Series, a DataFrame, or a Union whose members are only those two.
        return is_frame(hint) or check_union_annotation(hint, parameter_check_func=is_frame)

    def is_tuple_of_frames(hint):
        # A Tuple whose members are each frame-like or an Ellipsis.
        return check_tuple_annotation(
            hint,
            parameter_check_func=lambda member: member == Ellipsis or is_frame_or_union(member))

    def is_iterator_of_frames(hint):
        # Iterator[Series] or Iterator[DataFrame].
        return check_iterator_annotation(hint, parameter_check_func=is_frame)

    frame_like_params = all(is_frame_or_union(hint) for hint in param_hints)
    single_param = len(param_hints) == 1

    is_series_or_frame = frame_like_params and is_frame(ret_hint)

    is_iterator_tuple_series_or_frame = (
        single_param
        and check_iterator_annotation(param_hints[0], parameter_check_func=is_tuple_of_frames)
        and is_iterator_of_frames(ret_hint))

    is_iterator_series_or_frame = (
        single_param
        and check_iterator_annotation(param_hints[0], parameter_check_func=is_frame_or_union)
        and is_iterator_of_frames(ret_hint))

    # Any return hint not claimed by the scalar / iterator shapes (and not a Tuple) is
    # treated as an aggregate result.
    is_series_or_frame_agg = frame_like_params and (
        ret_hint != pd.Series
        and ret_hint != pd.DataFrame
        and not check_iterator_annotation(ret_hint)
        and not check_tuple_annotation(ret_hint))

    if is_series_or_frame:
        return PandasUDFType.SCALAR
    if is_iterator_tuple_series_or_frame or is_iterator_series_or_frame:
        return PandasUDFType.SCALAR_ITER
    if is_series_or_frame_agg:
        return PandasUDFType.GROUPED_AGG
    raise NotImplementedError("Unsupported signature: %s." % sig)
0118
0119
def check_tuple_annotation(annotation, parameter_check_func=None):
    """
    Checks whether ``annotation`` is a tuple type hint: ``typing.Tuple[...]`` or, on
    Python 3.9+, the builtin generic ``tuple[...]``.

    :param annotation: the type hint to inspect.
    :param parameter_check_func: optional predicate applied to every type argument of the
        tuple; when given, all arguments must satisfy it.
    :return: ``True`` if ``annotation`` is a tuple hint (whose type arguments all pass
        ``parameter_check_func``, if provided), ``False`` otherwise.
    """
    # ``typing`` generics carry their name in ``_name``; other objects fall back to
    # ``__name__``. Builtin ``tuple[...]`` aliases are named "tuple", so they are matched
    # through ``__origin__`` instead. A plain ``tuple`` class has no ``__origin__`` and is
    # deliberately left unmatched, as before.
    name = getattr(annotation, "_name", getattr(annotation, "__name__", None))
    is_tuple = name == "Tuple" or getattr(annotation, "__origin__", None) is tuple
    if not is_tuple:
        return False
    if parameter_check_func is None:
        return True
    # An unsubscripted ``typing.Tuple`` may have no ``__args__`` at all. Its element types
    # cannot be verified, so report no match instead of raising AttributeError.
    args = getattr(annotation, "__args__", None)
    return args is not None and all(map(parameter_check_func, args))
0126
0127
def check_iterator_annotation(annotation, parameter_check_func=None):
    """
    Checks whether ``annotation`` is an iterator type hint such as ``typing.Iterator[...]``.

    :param annotation: the type hint to inspect.
    :param parameter_check_func: optional predicate applied to every type argument of the
        iterator; when given, all arguments must satisfy it.
    :return: ``True`` if ``annotation`` is named ``Iterator`` (and its type arguments all
        pass ``parameter_check_func``, if provided), ``False`` otherwise.
    """
    # ``typing`` generics carry their name in ``_name``; other objects fall back to
    # ``__name__``.
    name = getattr(annotation, "_name", getattr(annotation, "__name__", None))
    if name != "Iterator":
        return False
    if parameter_check_func is None:
        return True
    # Unsubscripted hints (e.g. a bare ``collections.abc.Iterator`` class) may have no
    # ``__args__``. Their element types cannot be verified, so report no match instead of
    # raising AttributeError.
    args = getattr(annotation, "__args__", None)
    return args is not None and all(map(parameter_check_func, args))
0132
0133
def check_union_annotation(annotation, parameter_check_func=None):
    """
    Checks whether ``annotation`` is a union type hint: ``typing.Union[...]`` or, on
    Python 3.10+, the ``X | Y`` form (``types.UnionType``).

    :param annotation: the type hint to inspect.
    :param parameter_check_func: optional predicate applied to every member of the union;
        when given, all members must satisfy it.
    :return: ``True`` if ``annotation`` is a union (whose members all pass
        ``parameter_check_func``, if provided), ``False`` otherwise.
    """
    import types
    import typing

    # ``typing.Union[...]`` reports ``typing.Union`` as its ``__origin__``; ``X | Y`` unions
    # are ``types.UnionType`` instances and are recognised by type instead.
    # ``types.UnionType`` does not exist before Python 3.10.
    union_type = getattr(types, "UnionType", None)
    is_union = getattr(annotation, "__origin__", None) == typing.Union or (
        union_type is not None and isinstance(annotation, union_type))
    return is_union and (
        parameter_check_func is None or all(map(parameter_check_func, annotation.__args__)))