"""
Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
"""

import sys
if sys.version < '3':
    from itertools import izip as zip
else:
    basestring = unicode = str
    xrange = range

from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer


class SpecialLengths(object):
    # Negative sentinel values written with write_int to frame the worker protocol
    # stream; non-negative values denote actual payload lengths.
    END_OF_DATA_SECTION = -1
    PYTHON_EXCEPTION_THROWN = -2
    TIMING_DATA = -3
    END_OF_STREAM = -4
    NULL = -5
    START_ARROW_STREAM = -6


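# The collect path below (ArrowCollectSerializer) reads: an Arrow IPC stream of
# record batches, followed by either the number of batch-order indices and the
# indices themselves, or -1 and a UTF-8 error message if the JVM side failed.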
class ArrowCollectSerializer(Serializer):
    """
    Deserialize a stream of batches followed by batch order information. Used in
    PandasConversionMixin._collect_as_arrow() after invoking Dataset.collectAsArrowToPython()
    in the JVM.
    """

    def __init__(self):
        self.serializer = ArrowStreamSerializer()

    def dump_stream(self, iterator, stream):
        return self.serializer.dump_stream(iterator, stream)

    def load_stream(self, stream):
        """
        Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
        a list of indices that can be used to put the RecordBatches in the correct order.
        """
        # load the batches
        for batch in self.serializer.load_stream(stream):
            yield batch

        # load the batch order indices, or propagate any error that occurred in the JVM
        num = read_int(stream)
        if num == -1:
            error_msg = UTF8Deserializer().loads(stream)
            raise RuntimeError("An error occurred while calling "
                               "ArrowCollectSerializer.load_stream: {}".format(error_msg))
        batch_order = []
        for i in xrange(num):
            index = read_int(stream)
            batch_order.append(index)
        yield batch_order

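    # A sketch of how callers consume this stream (cf. the reordering done in
    # PandasConversionMixin._collect_as_arrow):
    #
    #     results = list(ser.load_stream(stream))
    #     batches, batch_order = results[:-1], results[-1]
    #     ordered_batches = [batches[i] for i in batch_order]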
    def __repr__(self):
        return "ArrowCollectSerializer(%s)" % self.serializer


class ArrowStreamSerializer(Serializer):
    """
    Serializes Arrow record batches as a stream.
    """

    def dump_stream(self, iterator, stream):
        import pyarrow as pa
        writer = None
        try:
            for batch in iterator:
                if writer is None:
                    # Defer creating the writer until the first batch arrives, since
                    # the stream header needs that batch's schema
                    writer = pa.RecordBatchStreamWriter(stream, batch.schema)
                writer.write_batch(batch)
        finally:
            if writer is not None:
                writer.close()

    def load_stream(self, stream):
        import pyarrow as pa
        reader = pa.ipc.open_stream(stream)
        for batch in reader:
            yield batch

    def __repr__(self):
        return "ArrowStreamSerializer"


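# A minimal round-trip sketch for ArrowStreamSerializer (the buffer and batch
# below are illustrative, not part of the module):
#
#     import io
#     import pyarrow as pa
#     ser = ArrowStreamSerializer()
#     batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ["v"])
#     buf = io.BytesIO()
#     ser.dump_stream([batch], buf)
#     buf.seek(0)
#     assert next(ser.load_stream(buf)).equals(batch)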
class ArrowStreamPandasSerializer(ArrowStreamSerializer):
    """
    Serializes pandas.Series as Arrow data with the Arrow streaming format.

    :param timezone: A timezone to respect when handling timestamp values
    :param safecheck: If True, conversion from Arrow to pandas checks for overflow/truncation
    :param assign_cols_by_name: If True, then pandas DataFrames will get columns by name
    """

    def __init__(self, timezone, safecheck, assign_cols_by_name):
        super(ArrowStreamPandasSerializer, self).__init__()
        self._timezone = timezone
        self._safecheck = safecheck
        self._assign_cols_by_name = assign_cols_by_name

    def arrow_to_pandas(self, arrow_column):
        from pyspark.sql.pandas.types import _check_series_localize_timestamps
        import pyarrow

        # Create datetime.date objects for date columns directly, rather than going
        # through datetime64[ns] as intermediate data, which can overflow for dates
        # outside the datetime64[ns] range
        s = arrow_column.to_pandas(date_as_object=True)

        if pyarrow.types.is_timestamp(arrow_column.type):
            return _check_series_localize_timestamps(s, self._timezone)
        else:
            return s

    def _create_batch(self, series):
        """
        Create an Arrow record batch from the given pandas.Series or list of Series,
        with optional type.

        :param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
        :return: Arrow RecordBatch
        """
        import pandas as pd
        import pyarrow as pa
        from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal

        # Make the input conform to [(series1, type1), (series2, type2), ...]
        if not isinstance(series, (list, tuple)) or \
                (len(series) == 2 and isinstance(series[1], pa.DataType)):
            series = [series]
        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

        def create_array(s, t):
            mask = s.isnull()
            # Ensure timestamp series are in the form expected by Spark's internal
            # representation before converting
            if t is not None and pa.types.is_timestamp(t):
                s = _check_series_convert_timestamps_internal(s, self._timezone)
            try:
                array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
            except pa.ArrowException as e:
                error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
                            "Array (%s). It can be caused by overflows or other unsafe " + \
                            "conversions warned by Arrow. Arrow safe type check can be " + \
                            "disabled by using SQL config " + \
                            "`spark.sql.execution.pandas.convertToArrowArraySafely`."
                raise RuntimeError(error_msg % (s.dtype, t), e)
            return array

        arrs = []
        for s, t in series:
            if t is not None and pa.types.is_struct(t):
                if not isinstance(s, pd.DataFrame):
                    raise ValueError("A field of type StructType expects a pandas.DataFrame, "
                                     "but got: %s" % str(type(s)))

                # Both the input partition and the result pandas.DataFrame are empty,
                # so make empty Arrays with the struct's field types
                if len(s) == 0 and len(s.columns) == 0:
                    arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
                # Assign result columns by schema name if the user labeled them with strings
                elif self._assign_cols_by_name and any(isinstance(name, basestring)
                                                       for name in s.columns):
                    arrs_names = [(create_array(s[field.name], field.type), field.name)
                                  for field in t]
                # Otherwise assign result columns by position
                else:
                    arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
                                  for i, field in enumerate(t)]

                struct_arrs, struct_names = zip(*arrs_names)
                arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
            else:
                arrs.append(create_array(s, t))

        return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])

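    # A sketch (hypothetical data) of what _create_batch produces; the "_0", "_1",
    # ... column names are positional placeholders:
    #
    #     import pandas as pd
    #     batch = self._create_batch([pd.Series([1, 2]), pd.Series(["a", "b"])])
    #     assert batch.schema.names == ["_0", "_1"]
    #     assert batch.num_rows == 2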
    def dump_stream(self, iterator, stream):
        """
        Make Arrow RecordBatches from pandas.Series and serialize. Input is a single series or
        a list of series accompanied by an optional pyarrow type to coerce the data to.
        """
        batches = (self._create_batch(series) for series in iterator)
        super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)

    def load_stream(self, stream):
        """
        Deserialize Arrow RecordBatches to an Arrow table and return as a list of pandas.Series.
        """
        batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
        import pyarrow as pa
        for batch in batches:
            yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]

    def __repr__(self):
        return "ArrowStreamPandasSerializer"


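# Shape contract, stated informally: each item ArrowStreamPandasSerializer sends
# or receives corresponds to one Arrow record batch, and on load every iteration
# yields a list of pandas.Series, one per column of that batch.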
class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
    """
    Serializer used by Python worker to evaluate Pandas UDFs.

    :param df_for_struct: If True, deserialize struct columns into pandas.DataFrames with one
        column per struct field, instead of pandas.Series
    """

    def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
        super(ArrowStreamPandasUDFSerializer, self) \
            .__init__(timezone, safecheck, assign_cols_by_name)
        self._df_for_struct = df_for_struct

    def arrow_to_pandas(self, arrow_column):
        import pyarrow.types as types

        if self._df_for_struct and types.is_struct(arrow_column.type):
            import pandas as pd
            # Flatten the struct column into one Series per field, then reassemble
            # them as a DataFrame, preserving the field names
            series = [super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(column)
                      .rename(field.name)
                      for column, field in zip(arrow_column.flatten(), arrow_column.type)]
            s = pd.concat(series, axis=1)
        else:
            s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
        return s

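    # A sketch with a hypothetical schema: with _df_for_struct=True, an Arrow
    # column of type struct<a: int64, b: string> deserializes to a pandas.DataFrame
    # with columns "a" and "b"; with the flag off it stays a single pandas.Series.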
    def dump_stream(self, iterator, stream):
        """
        Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is
        sent. This should be sent after creating the first record batch so that, in case of an
        error, it can be sent back to the JVM before the Arrow stream starts.
        """

        def init_stream_yield_batches():
            should_write_start_length = True
            for series in iterator:
                batch = self._create_batch(series)
                if should_write_start_length:
                    # Signal the JVM that the Arrow stream is about to begin
                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
                    should_write_start_length = False
                yield batch

        return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)

    def __repr__(self):
        return "ArrowStreamPandasUDFSerializer"


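# Cogrouped wire format, as read below: an int giving the number of dataframes in
# the group (2 for a populated group), followed by one Arrow stream per dataframe;
# a count of 0 marks the end of the cogrouped data.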
class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):

    def load_stream(self, stream):
        """
        Deserialize cogrouped Arrow RecordBatches into a tuple of Arrow tables and yield them
        as two lists of pandas.Series.
        """
        import pyarrow as pa
        dataframes_in_group = None

        while dataframes_in_group is None or dataframes_in_group > 0:
            dataframes_in_group = read_int(stream)

            if dataframes_in_group == 2:
                batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
                batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
                yield (
                    [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()],
                    [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()]
                )

            elif dataframes_in_group != 0:
                raise ValueError(
                    'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group))