#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

0018 """
0019 Type-specific codes between pandas and PyArrow. Also contains some utils to correct
0020 pandas instances during the type conversion.
0021 """
0022 
from pyspark.sql.types import *


def to_arrow_type(dt):
    """ Convert Spark data type to pyarrow type
    """
    import pyarrow as pa
    if type(dt) == BooleanType:
        arrow_type = pa.bool_()
    elif type(dt) == ByteType:
        arrow_type = pa.int8()
    elif type(dt) == ShortType:
        arrow_type = pa.int16()
    elif type(dt) == IntegerType:
        arrow_type = pa.int32()
    elif type(dt) == LongType:
        arrow_type = pa.int64()
    elif type(dt) == FloatType:
        arrow_type = pa.float32()
    elif type(dt) == DoubleType:
        arrow_type = pa.float64()
    elif type(dt) == DecimalType:
        arrow_type = pa.decimal128(dt.precision, dt.scale)
    elif type(dt) == StringType:
        arrow_type = pa.string()
    elif type(dt) == BinaryType:
        arrow_type = pa.binary()
    elif type(dt) == DateType:
        arrow_type = pa.date32()
    elif type(dt) == TimestampType:
        # Timestamps should be in UTC; JVM Arrow timestamps require a timezone to be read
        arrow_type = pa.timestamp('us', tz='UTC')
    elif type(dt) == ArrayType:
        if type(dt.elementType) in [StructType, TimestampType]:
            raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
        arrow_type = pa.list_(to_arrow_type(dt.elementType))
    elif type(dt) == StructType:
        if any(type(field.dataType) == StructType for field in dt):
            raise TypeError("Nested StructType not supported in conversion to Arrow")
        fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
                  for field in dt]
        arrow_type = pa.struct(fields)
    else:
        raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
    return arrow_type
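
# Illustrative usage (editorial sketch, not part of the upstream module): the
# doctest-style lines below show the mapping implemented above; exact reprs
# depend on the installed pyarrow version.
# >>> to_arrow_type(DecimalType(10, 2))
# Decimal128Type(decimal128(10, 2))
# >>> to_arrow_type(ArrayType(StringType()))
# ListType(list<item: string>)
# >>> to_arrow_type(TimestampType())
# TimestampType(timestamp[us, tz=UTC])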


def to_arrow_schema(schema):
    """ Convert a schema from Spark to Arrow
    """
    import pyarrow as pa
    fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
              for field in schema]
    return pa.schema(fields)
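
# Illustrative usage (editorial sketch): converting a small Spark schema,
# with the printed form indicative only.
# >>> to_arrow_schema(StructType([StructField("id", LongType(), False),
# ...                             StructField("name", StringType(), True)]))
# id: int64 not null
# name: string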


def from_arrow_type(at):
    """ Convert pyarrow type to Spark data type.
    """
    import pyarrow.types as types
    if types.is_boolean(at):
        spark_type = BooleanType()
    elif types.is_int8(at):
        spark_type = ByteType()
    elif types.is_int16(at):
        spark_type = ShortType()
    elif types.is_int32(at):
        spark_type = IntegerType()
    elif types.is_int64(at):
        spark_type = LongType()
    elif types.is_float32(at):
        spark_type = FloatType()
    elif types.is_float64(at):
        spark_type = DoubleType()
    elif types.is_decimal(at):
        spark_type = DecimalType(precision=at.precision, scale=at.scale)
    elif types.is_string(at):
        spark_type = StringType()
    elif types.is_binary(at):
        spark_type = BinaryType()
    elif types.is_date32(at):
        spark_type = DateType()
    elif types.is_timestamp(at):
        spark_type = TimestampType()
    elif types.is_list(at):
        if types.is_timestamp(at.value_type):
            raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
        spark_type = ArrayType(from_arrow_type(at.value_type))
    elif types.is_struct(at):
        if any(types.is_struct(field.type) for field in at):
            raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at))
        return StructType(
            [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
             for field in at])
    else:
        raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
    return spark_type
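
# Illustrative usage (editorial sketch): mapping a few Arrow types back to
# Spark types; reprs indicative only and vary across Spark versions.
# >>> import pyarrow as pa
# >>> from_arrow_type(pa.decimal128(10, 2))
# DecimalType(10,2)
# >>> from_arrow_type(pa.list_(pa.int32()))
# ArrayType(IntegerType,true)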


def from_arrow_schema(arrow_schema):
    """ Convert schema from Arrow to Spark.
    """
    return StructType(
        [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
         for field in arrow_schema])
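
# Illustrative usage (editorial sketch): a Spark -> Arrow -> Spark round trip
# should preserve the schema for supported types.
# >>> schema = StructType([StructField("x", DoubleType(), True)])
# >>> from_arrow_schema(to_arrow_schema(schema)) == schema
# True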


def _get_local_timezone():
    """ Get the local timezone using pytz via the 'TZ' environment variable, or dateutil.

    If the 'TZ' environment variable is set, return it so pandas treats it as a pytz
    timezone string; otherwise return the special value 'dateutil/:', which tells pandas
    to use dateutil and read the system configuration to determine the local timezone.

    See also:
    - https://github.com/pandas-dev/pandas/blob/0.19.x/pandas/tslib.pyx#L1753
    - https://github.com/dateutil/dateutil/blob/2.6.1/dateutil/tz/tz.py#L1338
    """
    import os
    return os.environ.get('TZ', 'dateutil/:')
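
# Illustrative usage (editorial sketch): with TZ set the pytz name is returned,
# otherwise the dateutil fallback is used.
# >>> import os
# >>> os.environ['TZ'] = 'America/New_York'
# >>> _get_local_timezone()
# 'America/New_York'
# >>> del os.environ['TZ']
# >>> _get_local_timezone()
# 'dateutil/:'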


def _check_series_localize_timestamps(s, timezone):
    """
    Convert timezone-aware timestamps to timezone-naive in the specified timezone or local
    timezone.

    If the input series is not a timestamp series, the same series is returned. If it is a
    timestamp series, a converted series is returned.

    :param s: pandas.Series
    :param timezone: the timezone to convert to. If None, use the local timezone
    :return: pandas.Series that has been converted to tz-naive
    """
    from pyspark.sql.pandas.utils import require_minimum_pandas_version
    require_minimum_pandas_version()

    from pandas.api.types import is_datetime64tz_dtype
    tz = timezone or _get_local_timezone()
    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
    if is_datetime64tz_dtype(s.dtype):
        return s.dt.tz_convert(tz).dt.tz_localize(None)
    else:
        return s
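
# Illustrative usage (editorial sketch): a tz-aware series is converted to the
# target timezone and made naive; a non-timestamp series passes through
# unchanged. Output shown is indicative.
# >>> import pandas as pd
# >>> s = pd.Series(pd.to_datetime(['2015-11-01 06:30:00'])).dt.tz_localize('UTC')
# >>> _check_series_localize_timestamps(s, 'America/New_York')
# 0   2015-11-01 01:30:00
# dtype: datetime64[ns]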


def _check_series_convert_timestamps_internal(s, timezone):
    """
    Convert a tz-naive timestamp in the specified timezone or local timezone to UTC,
    normalized for Spark internal storage.

    :param s: a pandas.Series
    :param timezone: the timezone to convert from. If None, use the local timezone
    :return: pandas.Series where, if it is a timestamp, it has been UTC normalized without
        a time zone
    """
    from pyspark.sql.pandas.utils import require_minimum_pandas_version
    require_minimum_pandas_version()

    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
    if is_datetime64_dtype(s.dtype):
        # When tz_localize is applied to a tz-naive timestamp, the result is ambiguous if
        # the timestamp falls within the hour when the clock is adjusted backward due to
        # daylight saving time (DST).
        # E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 from
        # 2:00 to 1:00 (DST to standard time). Therefore, when tz_localize is applied to
        # the tz-naive timestamp 2015-11-01 1:30 with the America/New_York timezone, it
        # can be either DST time (2015-11-01 1:30-0400) or standard time
        # (2015-11-01 1:30-0500).
        #
        # Here we explicitly choose to use standard time. This matches the default
        # behavior of pytz.
        #
        # Here is some code to help understand this behavior:
        # >>> import datetime
        # >>> import pandas as pd
        # >>> import pytz
        # >>>
        # >>> t = datetime.datetime(2015, 11, 1, 1, 30)
        # >>> ts = pd.Series([t])
        # >>> tz = pytz.timezone('America/New_York')
        # >>>
        # >>> ts.dt.tz_localize(tz, ambiguous=True)
        # 0   2015-11-01 01:30:00-04:00
        # dtype: datetime64[ns, America/New_York]
        # >>>
        # >>> ts.dt.tz_localize(tz, ambiguous=False)
        # 0   2015-11-01 01:30:00-05:00
        # dtype: datetime64[ns, America/New_York]
        # >>>
        # >>> str(tz.localize(t))
        # '2015-11-01 01:30:00-05:00'
        tz = timezone or _get_local_timezone()
        return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC')
    elif is_datetime64tz_dtype(s.dtype):
        return s.dt.tz_convert('UTC')
    else:
        return s
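
# Illustrative usage (editorial sketch): a tz-naive series is interpreted in
# the given timezone (standard time for ambiguous values, per the comment
# above) and normalized to UTC. Output shown is indicative.
# >>> import pandas as pd
# >>> s = pd.Series(pd.to_datetime(['2015-11-01 01:30:00']))
# >>> _check_series_convert_timestamps_internal(s, 'America/New_York')
# 0   2015-11-01 06:30:00+00:00
# dtype: datetime64[ns, UTC]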


def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
    """
    Convert timestamps to timezone-naive in the specified timezone or local timezone.

    :param s: a pandas.Series
    :param from_timezone: the timezone to convert from. If None, use the local timezone
    :param to_timezone: the timezone to convert to. If None, use the local timezone
    :return: pandas.Series where, if it is a timestamp, it has been converted to tz-naive
    """
    from pyspark.sql.pandas.utils import require_minimum_pandas_version
    require_minimum_pandas_version()

    import pandas as pd
    from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
    from_tz = from_timezone or _get_local_timezone()
    to_tz = to_timezone or _get_local_timezone()
    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
    if is_datetime64tz_dtype(s.dtype):
        return s.dt.tz_convert(to_tz).dt.tz_localize(None)
    elif is_datetime64_dtype(s.dtype) and from_tz != to_tz:
        # `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT.
        return s.apply(
            lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None)
            if ts is not pd.NaT else pd.NaT)
    else:
        return s
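
# Illustrative usage (editorial sketch): a tz-naive series is reinterpreted
# from one timezone to another and returned tz-naive; NaT values are
# preserved. Output shown is indicative.
# >>> import pandas as pd
# >>> s = pd.Series(pd.to_datetime(['2015-11-01 01:30:00', 'NaT']))
# >>> _check_series_convert_timestamps_localize(s, 'America/New_York', 'UTC')
# 0   2015-11-01 06:30:00
# 1                   NaT
# dtype: datetime64[ns]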


def _check_series_convert_timestamps_local_tz(s, timezone):
    """
    Convert timestamps to timezone-naive in the specified timezone, converting from the
    local timezone.

    :param s: a pandas.Series
    :param timezone: the timezone to convert to. If None, use the local timezone
    :return: pandas.Series where, if it is a timestamp, it has been converted to tz-naive
    """
    return _check_series_convert_timestamps_localize(s, None, timezone)


def _check_series_convert_timestamps_tz_local(s, timezone):
    """
    Convert timestamps to timezone-naive in the local timezone, converting from the
    specified timezone.

    :param s: a pandas.Series
    :param timezone: the timezone to convert from. If None, use the local timezone
    :return: pandas.Series where, if it is a timestamp, it has been converted to tz-naive
    """
    return _check_series_convert_timestamps_localize(s, timezone, None)
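
# Illustrative usage (editorial sketch): the two wrappers above are mirror
# images; one treats the argument as the target timezone, the other as the
# source, with the local timezone filling the other side. Outputs depend on
# the machine's local timezone and are omitted here.
# >>> import pandas as pd
# >>> s = pd.Series(pd.to_datetime(['2015-11-01 01:30:00']))
# >>> _check_series_convert_timestamps_local_tz(s, 'UTC')  # local -> UTC, tz-naive
# >>> _check_series_convert_timestamps_tz_local(s, 'UTC')  # UTC -> local, tz-naive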