#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
"""

import sys
if sys.version < '3':
    from itertools import izip as zip
else:
    basestring = unicode = str
    xrange = range

from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer


class SpecialLengths(object):
    END_OF_DATA_SECTION = -1
    PYTHON_EXCEPTION_THROWN = -2
    TIMING_DATA = -3
    END_OF_STREAM = -4
    NULL = -5
    START_ARROW_STREAM = -6

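# A minimal, illustrative sketch of how these markers travel between the JVM
# and the Python worker (not part of the serializer API): frames on the socket
# are length-prefixed, and a negative "length" is one of the markers above
# rather than a payload size.
#
#     import io
#     buf = io.BytesIO()
#     write_int(SpecialLengths.START_ARROW_STREAM, buf)
#     buf.seek(0)
#     assert read_int(buf) == SpecialLengths.START_ARROW_STREAM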

class ArrowCollectSerializer(Serializer):
    """
    Deserialize a stream of batches followed by batch order information. Used in
    PandasConversionMixin._collect_as_arrow() after invoking Dataset.collectAsArrowToPython()
    in the JVM.
    """

    def __init__(self):
        self.serializer = ArrowStreamSerializer()

    def dump_stream(self, iterator, stream):
        return self.serializer.dump_stream(iterator, stream)

    def load_stream(self, stream):
        """
        Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
        a list of indices that can be used to put the RecordBatches in the correct order.
        """
        # load the batches
        for batch in self.serializer.load_stream(stream):
            yield batch

        # load the batch order indices or propagate any error that occurred in the JVM
        num = read_int(stream)
        if num == -1:
            error_msg = UTF8Deserializer().loads(stream)
            raise RuntimeError("An error occurred while calling "
                               "ArrowCollectSerializer.load_stream: {}".format(error_msg))
        batch_order = []
        for i in xrange(num):
            index = read_int(stream)
            batch_order.append(index)
        yield batch_order

    def __repr__(self):
        return "ArrowCollectSerializer(%s)" % self.serializer

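# A rough sketch of how the two-phase stream above is consumed, mirroring the
# load_stream docstring (assumes `stream` is the already-opened socket file;
# the names here are illustrative):
#
#     results = list(ArrowCollectSerializer().load_stream(stream))
#     batches, batch_order = results[:-1], results[-1]
#     ordered = [batches[i] for i in batch_order]   # restore the JVM's order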

class ArrowStreamSerializer(Serializer):
    """
    Serializes Arrow record batches as a stream.
    """

    def dump_stream(self, iterator, stream):
        import pyarrow as pa
        writer = None
        try:
            for batch in iterator:
                if writer is None:
                    writer = pa.RecordBatchStreamWriter(stream, batch.schema)
                writer.write_batch(batch)
        finally:
            if writer is not None:
                writer.close()

    def load_stream(self, stream):
        import pyarrow as pa
        reader = pa.ipc.open_stream(stream)
        for batch in reader:
            yield batch

    def __repr__(self):
        return "ArrowStreamSerializer"

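# A self-contained round trip through the serializer above (a sketch for
# reference; the values are arbitrary):
#
#     import io
#     import pyarrow as pa
#     ser = ArrowStreamSerializer()
#     batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ["x"])
#     buf = io.BytesIO()
#     ser.dump_stream(iter([batch]), buf)
#     buf.seek(0)
#     assert list(ser.load_stream(buf))[0].equals(batch)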

class ArrowStreamPandasSerializer(ArrowStreamSerializer):
    """
    Serializes pandas.Series as Arrow data with the Arrow streaming format.

    :param timezone: A timezone to respect when handling timestamp values
    :param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation
    :param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name
    """

    def __init__(self, timezone, safecheck, assign_cols_by_name):
        super(ArrowStreamPandasSerializer, self).__init__()
        self._timezone = timezone
        self._safecheck = safecheck
        self._assign_cols_by_name = assign_cols_by_name

    def arrow_to_pandas(self, arrow_column):
        from pyspark.sql.pandas.types import _check_series_localize_timestamps
        import pyarrow

        # If the given column is a date type column, create a series of datetime.date directly
        # instead of creating datetime64[ns] as intermediate data, to avoid overflow caused by
        # datetime64[ns] type handling.
        s = arrow_column.to_pandas(date_as_object=True)

        if pyarrow.types.is_timestamp(arrow_column.type):
            return _check_series_localize_timestamps(s, self._timezone)
        else:
            return s

    def _create_batch(self, series):
        """
        Create an Arrow record batch from the given pandas.Series or list of Series,
        with optional type.

        :param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
        :return: Arrow RecordBatch
        """
        import pandas as pd
        import pyarrow as pa
        from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal
        # Make input conform to [(series1, type1), (series2, type2), ...]
        if not isinstance(series, (list, tuple)) or \
                (len(series) == 2 and isinstance(series[1], pa.DataType)):
            series = [series]
        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

        def create_array(s, t):
            mask = s.isnull()
            # Ensure timestamp series are in expected form for Spark internal representation
            if t is not None and pa.types.is_timestamp(t):
                s = _check_series_convert_timestamps_internal(s, self._timezone)
            try:
                array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
            except pa.ArrowException as e:
                error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
                            "Array (%s). It can be caused by overflows or other unsafe " + \
                            "conversions warned by Arrow. Arrow safe type check can be " + \
                            "disabled by using SQL config " + \
                            "`spark.sql.execution.pandas.convertToArrowArraySafely`."
                raise RuntimeError(error_msg % (s.dtype, t), e)
            return array

        arrs = []
        for s, t in series:
            if t is not None and pa.types.is_struct(t):
                if not isinstance(s, pd.DataFrame):
                    raise ValueError("A field of type StructType expects a pandas.DataFrame, "
                                     "but got: %s" % str(type(s)))

                # If the input partition and the resulting pandas.DataFrame are empty,
                # make empty Arrays for each field of the struct
                if len(s) == 0 and len(s.columns) == 0:
                    arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
                # Assign result columns by schema name if user labeled with strings
                elif self._assign_cols_by_name and any(isinstance(name, basestring)
                                                       for name in s.columns):
                    arrs_names = [(create_array(s[field.name], field.type), field.name)
                                  for field in t]
                # Assign result columns by position
                else:
                    arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
                                  for i, field in enumerate(t)]

                struct_arrs, struct_names = zip(*arrs_names)
                arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
            else:
                arrs.append(create_array(s, t))

        return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])

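    # The input shapes accepted by _create_batch, per its docstring above
    # (a sketch; the serializer arguments are example values):
    #
    #     import pandas as pd
    #     import pyarrow as pa
    #     ser = ArrowStreamPandasSerializer("UTC", True, True)
    #     ser._create_batch(pd.Series([1, 2]))                    # a single series
    #     ser._create_batch([(pd.Series([1, 2]), pa.int64())])    # (series, type) pairs
    #     # Columns in the resulting batch are named positionally: _0, _1, ...
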
    def dump_stream(self, iterator, stream):
        """
        Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
        a list of series accompanied by an optional pyarrow type to coerce the data to.
        """
        batches = (self._create_batch(series) for series in iterator)
        super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)

    def load_stream(self, stream):
        """
        Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
        """
        batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
        import pyarrow as pa
        for batch in batches:
            yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]

    def __repr__(self):
        return "ArrowStreamPandasSerializer"


class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
    """
    Serializer used by the Python worker to evaluate Pandas UDFs
    """

    def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
        super(ArrowStreamPandasUDFSerializer, self) \
            .__init__(timezone, safecheck, assign_cols_by_name)
        self._df_for_struct = df_for_struct

    def arrow_to_pandas(self, arrow_column):
        import pyarrow.types as types

        if self._df_for_struct and types.is_struct(arrow_column.type):
            import pandas as pd
            series = [super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(column)
                      .rename(field.name)
                      for column, field in zip(arrow_column.flatten(), arrow_column.type)]
            s = pd.concat(series, axis=1)
        else:
            s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
        return s

    def dump_stream(self, iterator, stream):
        """
        Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
        This marker should be sent after creating the first record batch so that, in case of an
        error, the error can be sent back to the JVM before the Arrow stream starts.
        """

        def init_stream_yield_batches():
            should_write_start_length = True
            for series in iterator:
                batch = self._create_batch(series)
                if should_write_start_length:
                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
                    should_write_start_length = False
                yield batch

        return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)

    def __repr__(self):
        return "ArrowStreamPandasUDFSerializer"

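# Wire layout produced by ArrowStreamPandasUDFSerializer.dump_stream, as a
# sketch inferred from the code above:
#
#     int32: SpecialLengths.START_ARROW_STREAM (-6), written once, just before
#            the first record batch is serialized
#     Arrow IPC stream: schema, batch, batch, ..., end-of-stream marker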

class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):

    def load_stream(self, stream):
        """
        Deserialize cogrouped ArrowRecordBatches and yield, for each group, a tuple of
        two lists of pandas.Series.
        """
        import pyarrow as pa
        dataframes_in_group = None

        while dataframes_in_group is None or dataframes_in_group > 0:
            dataframes_in_group = read_int(stream)

            if dataframes_in_group == 2:
                batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
                batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
                yield (
                    [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()],
                    [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()]
                )

            elif dataframes_in_group != 0:
                raise ValueError(
                    'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group))
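
# The framing consumed by CogroupUDFSerializer.load_stream, sketched from the
# reader above (the JVM side is the producer; `left_batches`/`right_batches`
# are illustrative names):
#
#     write_int(2, stream)                                # one co-group follows
#     ArrowStreamSerializer().dump_stream(iter(left_batches), stream)
#     ArrowStreamSerializer().dump_stream(iter(right_batches), stream)
#     ...                                                 # more co-groups
#     write_int(0, stream)                                # end of co-groups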