0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 """
0019 Type-specific codes between pandas and PyArrow. Also contains some utils to correct
0020 pandas instances during the type conversion.
0021 """
0022
0023 from pyspark.sql.types import *
0024
0025
def to_arrow_type(dt):
    """ Convert Spark data type to pyarrow type
    """
    import pyarrow as pa

    # Exact-type dispatch for the simple scalar types. type(dt) is compared
    # exactly (not isinstance) so subclasses do not silently match.
    _SIMPLE_TYPES = {
        BooleanType: pa.bool_,
        ByteType: pa.int8,
        ShortType: pa.int16,
        IntegerType: pa.int32,
        LongType: pa.int64,
        FloatType: pa.float32,
        DoubleType: pa.float64,
        StringType: pa.string,
        BinaryType: pa.binary,
        DateType: pa.date32,
    }

    dt_class = type(dt)
    if dt_class in _SIMPLE_TYPES:
        return _SIMPLE_TYPES[dt_class]()
    if dt_class == DecimalType:
        return pa.decimal128(dt.precision, dt.scale)
    if dt_class == TimestampType:
        # Spark timestamps map to microsecond-precision, UTC-normalized
        # Arrow timestamps.
        return pa.timestamp('us', tz='UTC')
    if dt_class == ArrayType:
        # Nested structs and timestamp elements are not supported inside arrays.
        if type(dt.elementType) in [StructType, TimestampType]:
            raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
        return pa.list_(to_arrow_type(dt.elementType))
    if dt_class == StructType:
        if any(type(field.dataType) == StructType for field in dt):
            raise TypeError("Nested StructType not supported in conversion to Arrow")
        struct_fields = [
            pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
            for field in dt]
        return pa.struct(struct_fields)
    raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
0068
0069
def to_arrow_schema(schema):
    """ Convert a schema from Spark to Arrow
    """
    import pyarrow as pa
    arrow_fields = []
    for field in schema:
        arrow_fields.append(
            pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable))
    return pa.schema(arrow_fields)
0077
0078
def from_arrow_type(at):
    """ Convert pyarrow type to Spark data type.
    """
    import pyarrow.types as types

    # (predicate, Spark-type factory) pairs for the non-nested types,
    # checked in order against the Arrow type.
    scalar_conversions = [
        (types.is_boolean, lambda t: BooleanType()),
        (types.is_int8, lambda t: ByteType()),
        (types.is_int16, lambda t: ShortType()),
        (types.is_int32, lambda t: IntegerType()),
        (types.is_int64, lambda t: LongType()),
        (types.is_float32, lambda t: FloatType()),
        (types.is_float64, lambda t: DoubleType()),
        (types.is_decimal, lambda t: DecimalType(precision=t.precision, scale=t.scale)),
        (types.is_string, lambda t: StringType()),
        (types.is_binary, lambda t: BinaryType()),
        (types.is_date32, lambda t: DateType()),
        (types.is_timestamp, lambda t: TimestampType()),
    ]
    for predicate, make_spark_type in scalar_conversions:
        if predicate(at):
            return make_spark_type(at)

    if types.is_list(at):
        # Timestamp elements are not supported inside arrays.
        if types.is_timestamp(at.value_type):
            raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
        return ArrayType(from_arrow_type(at.value_type))
    if types.is_struct(at):
        if any(types.is_struct(field.type) for field in at):
            raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at))
        return StructType(
            [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
             for field in at])
    raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
0120
0121
def from_arrow_schema(arrow_schema):
    """ Convert schema from Arrow to Spark.
    """
    struct_fields = []
    for field in arrow_schema:
        struct_fields.append(
            StructField(field.name, from_arrow_type(field.type), nullable=field.nullable))
    return StructType(struct_fields)
0128
0129
0130 def _get_local_timezone():
0131 """ Get local timezone using pytz with environment variable, or dateutil.
0132
0133 If there is a 'TZ' environment variable, pass it to pandas to use pytz and use it as timezone
0134 string, otherwise use the special word 'dateutil/:' which means that pandas uses dateutil and
0135 it reads system configuration to know the system local timezone.
0136
0137 See also:
0138 - https://github.com/pandas-dev/pandas/blob/0.19.x/pandas/tslib.pyx#L1753
0139 - https://github.com/dateutil/dateutil/blob/2.6.1/dateutil/tz/tz.py#L1338
0140 """
0141 import os
0142 return os.environ.get('TZ', 'dateutil/:')
0143
0144
def _check_series_localize_timestamps(s, timezone):
    """
    Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.

    If the input series is not a timestamp series, then the same series is returned. If the input
    series is a timestamp series, then a converted series is returned.

    :param s: pandas.Series
    :param timezone: the timezone to convert. if None then use local timezone
    :return pandas.Series that have been converted to tz-naive
    """
    from pyspark.sql.pandas.utils import require_minimum_pandas_version
    require_minimum_pandas_version()

    from pandas.api.types import is_datetime64tz_dtype

    # Non-timestamp (or already tz-naive) series pass through unchanged.
    if not is_datetime64tz_dtype(s.dtype):
        return s
    target_tz = timezone or _get_local_timezone()
    return s.dt.tz_convert(target_tz).dt.tz_localize(None)
0166
0167
def _check_series_convert_timestamps_internal(s, timezone):
    """
    Convert a tz-naive timestamp in the specified timezone or local timezone to UTC normalized for
    Spark internal storage

    :param s: a pandas.Series
    :param timezone: the timezone to convert. if None then use local timezone
    :return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
    """
    from pyspark.sql.pandas.utils import require_minimum_pandas_version
    require_minimum_pandas_version()

    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype

    if is_datetime64_dtype(s.dtype):
        # tz-naive values: interpret them in the requested (or local) timezone,
        # then normalize to UTC. ambiguous=False makes pandas treat wall-clock
        # times that are ambiguous around a DST fall-back as non-DST time
        # instead of raising — presumably chosen to match Spark's own
        # resolution of such timestamps; confirm against upstream if in doubt.
        local_tz = timezone or _get_local_timezone()
        return s.dt.tz_localize(local_tz, ambiguous=False).dt.tz_convert('UTC')
    elif is_datetime64tz_dtype(s.dtype):
        # Already tz-aware: just normalize to UTC.
        return s.dt.tz_convert('UTC')
    else:
        # Not a timestamp series: return unchanged.
        return s
0219
0220
def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
    """
    Convert timestamp to timezone-naive in the specified timezone or local timezone

    :param s: a pandas.Series
    :param from_timezone: the timezone to convert from. if None then use local timezone
    :param to_timezone: the timezone to convert to. if None then use local timezone
    :return pandas.Series where if it is a timestamp, has been converted to tz-naive
    """
    from pyspark.sql.pandas.utils import require_minimum_pandas_version
    require_minimum_pandas_version()

    import pandas as pd
    from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype

    source_tz = from_timezone or _get_local_timezone()
    target_tz = to_timezone or _get_local_timezone()

    if is_datetime64tz_dtype(s.dtype):
        # tz-aware series: convert to the target zone and drop the tz info.
        return s.dt.tz_convert(target_tz).dt.tz_localize(None)
    if is_datetime64_dtype(s.dtype) and source_tz != target_tz:
        # tz-naive series: re-interpret each value element-wise, keeping NaT as-is.
        def _shift(ts):
            if ts is pd.NaT:
                return pd.NaT
            return ts.tz_localize(
                source_tz, ambiguous=False).tz_convert(target_tz).tz_localize(None)
        return s.apply(_shift)
    return s
0247
0248
def _check_series_convert_timestamps_local_tz(s, timezone):
    """
    Convert timestamp to timezone-naive in the specified timezone or local timezone

    :param s: a pandas.Series
    :param timezone: the timezone to convert to. if None then use local timezone
    :return pandas.Series where if it is a timestamp, has been converted to tz-naive
    """
    # Delegate with from_timezone=None so the source zone is the local timezone.
    return _check_series_convert_timestamps_localize(s, from_timezone=None, to_timezone=timezone)
0258
0259
def _check_series_convert_timestamps_tz_local(s, timezone):
    """
    Convert timestamp to timezone-naive in the specified timezone or local timezone

    :param s: a pandas.Series
    :param timezone: the timezone to convert from. if None then use local timezone
    :return pandas.Series where if it is a timestamp, has been converted to tz-naive
    """
    # Delegate with to_timezone=None so the target zone is the local timezone.
    return _check_series_convert_timestamps_localize(s, from_timezone=timezone, to_timezone=None)