#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import warnings
if sys.version >= '3':
    basestring = unicode = str
    xrange = range
else:
    from itertools import izip as zip
from collections import Counter

from pyspark import since
from pyspark.rdd import _load_from_socket
from pyspark.sql.pandas.serializers import ArrowCollectSerializer
from pyspark.sql.types import IntegralType
from pyspark.sql.types import *
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.util import _exception_message


class PandasConversionMixin(object):
    """
    Mix-in for the conversion from Spark to pandas. Currently, only :class:`DataFrame`
    can use this class.
    """

    @since(1.3)
    def toPandas(self):
        """
        Returns the contents of this :class:`DataFrame` as a pandas ``pandas.DataFrame``.

        This is only available if pandas is installed and available.

        .. note:: This method should only be used if the resulting pandas :class:`DataFrame` is
            expected to be small, as all the data is loaded into the driver's memory.

        .. note:: Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental.

        >>> df.toPandas()  # doctest: +SKIP
           age   name
        0    2  Alice
        1    5    Bob
        """
        from pyspark.sql.dataframe import DataFrame

        assert isinstance(self, DataFrame)

        from pyspark.sql.pandas.utils import require_minimum_pandas_version
        require_minimum_pandas_version()

        import numpy as np
        import pandas as pd

        timezone = self.sql_ctx._conf.sessionLocalTimeZone()

        if self.sql_ctx._conf.arrowPySparkEnabled():
            use_arrow = True
            try:
                from pyspark.sql.pandas.types import to_arrow_schema
                from pyspark.sql.pandas.utils import require_minimum_pyarrow_version

                require_minimum_pyarrow_version()
                to_arrow_schema(self.schema)
            except Exception as e:
                # Arrow is unusable for this schema or environment; either fall back to the
                # non-Arrow conversion or re-raise, depending on the fallback configuration.
                if self.sql_ctx._conf.arrowPySparkFallbackEnabled():
                    msg = (
                        "toPandas attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
                        "failed by the reason below:\n %s\n"
                        "Attempting non-optimization as "
                        "'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
                        "true." % _exception_message(e))
                    warnings.warn(msg)
                    use_arrow = False
                else:
                    msg = (
                        "toPandas attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
                        "reached the error below and will not continue because automatic fallback "
                        "with 'spark.sql.execution.arrow.pyspark.fallback.enabled' has been set to "
                        "false.\n %s" % _exception_message(e))
                    warnings.warn(msg)
                    raise

            # Try to use Arrow optimization when the schema is supported and the required
            # version of PyArrow is found.
            if use_arrow:
                try:
                    from pyspark.sql.pandas.types import _check_series_localize_timestamps
                    import pyarrow

                    # Rename columns to avoid duplicated column names during the conversion.
                    tmp_column_names = ['col_{}'.format(i) for i in range(len(self.columns))]
                    batches = self.toDF(*tmp_column_names)._collect_as_arrow()
                    if len(batches) > 0:
                        table = pyarrow.Table.from_batches(batches)
                        # A pandas DataFrame created from PyArrow uses datetime64[ns] for date
                        # type values, but we should use datetime.date to match the behavior
                        # when Arrow optimization is disabled.
                        pdf = table.to_pandas(date_as_object=True)
                        # Restore the original column names.
                        pdf.columns = self.columns
                        for field in self.schema:
                            if isinstance(field.dataType, TimestampType):
                                pdf[field.name] = \
                                    _check_series_localize_timestamps(pdf[field.name], timezone)
                        return pdf
                    else:
                        return pd.DataFrame.from_records([], columns=self.columns)
                except Exception as e:
                    # A fallback could be allowed here as well, but multiple Spark jobs may
                    # already have been executed, so simply fail in this case for now.
                    msg = (
                        "toPandas attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
                        "reached the error below and can not continue. Note that "
                        "'spark.sql.execution.arrow.pyspark.fallback.enabled' does not have an "
                        "effect on failures in the middle of "
                        "computation.\n %s" % _exception_message(e))
                    warnings.warn(msg)
                    raise

        # Below is toPandas without Arrow optimization.
        pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
        column_counter = Counter(self.columns)

        dtype = [None] * len(self.schema)
        for fieldIdx, field in enumerate(self.schema):
            # For a duplicated column name, use `iloc` to access it by position.
            if column_counter[field.name] > 1:
                pandas_col = pdf.iloc[:, fieldIdx]
            else:
                pandas_col = pdf[field.name]

            pandas_type = PandasConversionMixin._to_corrected_pandas_type(field.dataType)
            # SPARK-21766: if an integer field is nullable and has null values, pandas infers
            # it as a float column. Converting such a column with NaN back to an integer type
            # (e.g., np.int16) would raise, so keep the inferred float type in that case
            # rather than the corrected type from the schema.
            if pandas_type is not None and \
                not(isinstance(field.dataType, IntegralType) and field.nullable and
                    pandas_col.isnull().any()):
                dtype[fieldIdx] = pandas_type

            if isinstance(field.dataType, IntegralType) and pandas_col.isnull().any():
                dtype[fieldIdx] = np.float64
            if isinstance(field.dataType, BooleanType) and pandas_col.isnull().any():
                dtype[fieldIdx] = np.object

        df = pd.DataFrame()
        for index, t in enumerate(dtype):
            column_name = self.schema[index].name

            # For a duplicated column name, use `iloc` to access it by position.
            if column_counter[column_name] > 1:
                series = pdf.iloc[:, index]
            else:
                series = pdf[column_name]

            if t is not None:
                series = series.astype(t, copy=False)

            # The `insert` API makes a copy of the data; only use it for Series with duplicated
            # column names. Assigning through `pdf.iloc[:, index]` does not always work here
            # because `iloc` can return either a view or a copy depending on the context.
            if column_counter[column_name] > 1:
                df.insert(index, column_name, series, allow_duplicates=True)
            else:
                df[column_name] = series

        pdf = df

        if timezone is None:
            return pdf
        else:
            from pyspark.sql.pandas.types import _check_series_convert_timestamps_local_tz
            for field in self.schema:
                # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
                if isinstance(field.dataType, TimestampType):
                    pdf[field.name] = \
                        _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
            return pdf
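
    # A minimal usage sketch (hedged, not part of the API): with an active SparkSession
    # named `spark`, both conversion paths above can be exercised like this.
    #
    #   spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    #   pdf = spark.range(3).toPandas()   # Arrow path; pandas.DataFrame with column "id"
    #   spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "false")
    #   pdf = spark.range(3).toPandas()   # plain collect() path; same result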

    @staticmethod
    def _to_corrected_pandas_type(dt):
        """
        When converting Spark SQL records to a pandas :class:`DataFrame`, the inferred data type
        may be wrong. This method gets the corrected data type for pandas if that type may be
        inferred incorrectly.
        """
        import numpy as np
        if type(dt) == ByteType:
            return np.int8
        elif type(dt) == ShortType:
            return np.int16
        elif type(dt) == IntegerType:
            return np.int32
        elif type(dt) == LongType:
            return np.int64
        elif type(dt) == FloatType:
            return np.float32
        elif type(dt) == DoubleType:
            return np.float64
        elif type(dt) == BooleanType:
            return np.bool
        elif type(dt) == TimestampType:
            return np.datetime64
        else:
            return None
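
    # For illustration (hedged): the mapping above pins exact numpy dtypes that pandas
    # would otherwise widen when inferring from Python objects, e.g.
    #
    #   PandasConversionMixin._to_corrected_pandas_type(ShortType())   # -> numpy.int16
    #   PandasConversionMixin._to_corrected_pandas_type(StringType())  # -> None (no correction)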

    def _collect_as_arrow(self):
        """
        Returns all records as a list of Arrow record batches. PyArrow must be installed
        and available on both the driver and worker Python environments.

        .. note:: Experimental.
        """
        from pyspark.sql.dataframe import DataFrame

        assert isinstance(self, DataFrame)

        with SCCallSiteSync(self._sc):
            port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()

        # Collect the list of un-ordered batches; the last element is a list of correct
        # order indices.
        try:
            results = list(_load_from_socket((port, auth_secret), ArrowCollectSerializer()))
        finally:
            # Join the serving thread and raise any exceptions from collectAsArrowToPython.
            jsocket_auth_server.getResult()

        # Separate the RecordBatches from the batch order indices in the results.
        batches = results[:-1]
        batch_order = results[-1]

        # Re-order the batch list using the correct order.
        return [batches[i] for i in batch_order]
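
    # A hypothetical consumer sketch (hedged; `_collect_as_arrow` is private): the ordered
    # batches can be stitched back into a single Arrow table, which is what toPandas() does
    # on the Arrow path.
    #
    #   import pyarrow
    #   batches = df._collect_as_arrow()
    #   if batches:
    #       table = pyarrow.Table.from_batches(batches)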


class SparkConversionMixin(object):
    """
    Mix-in for the conversion from pandas to Spark. Currently, only :class:`SparkSession`
    can use this class.
    """
    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
        from pyspark.sql import SparkSession

        assert isinstance(self, SparkSession)

        from pyspark.sql.pandas.utils import require_minimum_pandas_version
        require_minimum_pandas_version()

        timezone = self._wrapped._conf.sessionLocalTimeZone()

        # If no schema was supplied by the user, use only the names of the columns.
        if schema is None:
            schema = [str(x) if not isinstance(x, basestring) else
                      (x.encode('utf-8') if not isinstance(x, str) else x)
                      for x in data.columns]

        if self._wrapped._conf.arrowPySparkEnabled() and len(data) > 0:
            try:
                return self._create_from_pandas_with_arrow(data, schema, timezone)
            except Exception as e:
                from pyspark.util import _exception_message

                if self._wrapped._conf.arrowPySparkFallbackEnabled():
                    msg = (
                        "createDataFrame attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
                        "failed by the reason below:\n %s\n"
                        "Attempting non-optimization as "
                        "'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
                        "true." % _exception_message(e))
                    warnings.warn(msg)
                else:
                    msg = (
                        "createDataFrame attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
                        "reached the error below and will not continue because automatic "
                        "fallback with 'spark.sql.execution.arrow.pyspark.fallback.enabled' "
                        "has been set to false.\n %s" % _exception_message(e))
                    warnings.warn(msg)
                    raise
        data = self._convert_from_pandas(data, schema, timezone)
        return self._create_dataframe(data, schema, samplingRatio, verifySchema)
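
    # A minimal usage sketch (hedged, assumes an active SparkSession `spark`):
    #
    #   import pandas as pd
    #   pdf = pd.DataFrame({"age": [2, 5], "name": ["Alice", "Bob"]})
    #   sdf = spark.createDataFrame(pdf)  # Arrow path if enabled and non-empty, else fallback;
    #                                     # column names are taken from pdf.columns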

    def _convert_from_pandas(self, pdf, schema, timezone):
        """
        Convert a pandas.DataFrame to a list of records that can be used to make a DataFrame.

        :return: list of records
        """
        from pyspark.sql import SparkSession

        assert isinstance(self, SparkSession)

        if timezone is not None:
            from pyspark.sql.pandas.types import _check_series_convert_timestamps_tz_local
            copied = False
            if isinstance(schema, StructType):
                for field in schema:
                    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
                    if isinstance(field.dataType, TimestampType):
                        s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone)
                        if s is not pdf[field.name]:
                            if not copied:
                                # Copy once if a series is modified, to prevent the original
                                # pandas DataFrame from being updated.
                                pdf = pdf.copy()
                                copied = True
                            pdf[field.name] = s
            else:
                for column, series in pdf.iteritems():
                    s = _check_series_convert_timestamps_tz_local(series, timezone)
                    if s is not series:
                        if not copied:
                            # Copy once if a series is modified, to prevent the original
                            # pandas DataFrame from being updated.
                            pdf = pdf.copy()
                            copied = True
                        pdf[column] = s

        # Convert the pandas.DataFrame to a list of numpy records.
        np_records = pdf.to_records(index=False)

        # Check if any columns need to be fixed for Spark to infer them properly.
        if len(np_records) > 0:
            record_dtype = self._get_numpy_record_dtype(np_records[0])
            if record_dtype is not None:
                return [r.astype(record_dtype).tolist() for r in np_records]

        # Convert the list of numpy records to plain Python lists.
        return [r.tolist() for r in np_records]
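
    # For illustration (hedged): to_records(index=False) yields one numpy record per row,
    # and tolist() turns each record into a plain Python tuple that Spark can infer from.
    #
    #   import pandas as pd
    #   pd.DataFrame({"a": [1], "b": ["x"]}).to_records(index=False)[0].tolist()  # (1, 'x')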

    def _get_numpy_record_dtype(self, rec):
        """
        Used when converting a pandas.DataFrame to Spark using to_records(); this corrects
        the dtypes of fields in a record so they can be properly loaded into Spark.

        :param rec: a numpy record to check field dtypes
        :return: corrected dtype for a numpy.record, or None if no correction is needed
        """
        import numpy as np
        cur_dtypes = rec.dtype
        col_names = cur_dtypes.names
        record_type_list = []
        has_rec_fix = False
        for i in xrange(len(cur_dtypes)):
            curr_type = cur_dtypes[i]
            # If the type is a datetime64 timestamp, convert it to microseconds.
            # NOTE: if the dtype is datetime64[ns], np.record.tolist() outputs the values as
            # longs, while conversion from [us] or lower yields datetime objects, see SPARK-22417.
            if curr_type == np.dtype('datetime64[ns]'):
                curr_type = 'datetime64[us]'
                has_rec_fix = True
            record_type_list.append((str(col_names[i]), curr_type))
        return np.dtype(record_type_list) if has_rec_fix else None
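
    # For illustration (hedged, hypothetical names): a record with a datetime64[ns] field
    # gets a corrected dtype, so tolist() later yields datetime objects instead of raw
    # nanosecond longs (SPARK-22417).
    #
    #   rec.dtype  == np.dtype([('ts', 'datetime64[ns]')])  # before
    #   corrected  == np.dtype([('ts', 'datetime64[us]')])  # returned by this helper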

    def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
        """
        Create a DataFrame from a given pandas.DataFrame by slicing it into partitions,
        converting to Arrow data, then sending it to the JVM to parallelize. If a schema
        is passed in, its data types will be used to coerce the data in the pandas-to-Arrow
        conversion.
        """
        from pyspark.sql import SparkSession
        from pyspark.sql.dataframe import DataFrame

        assert isinstance(self, SparkSession)

        from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
        from pyspark.sql.types import TimestampType
        from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type
        from pyspark.sql.pandas.utils import require_minimum_pandas_version, \
            require_minimum_pyarrow_version

        require_minimum_pandas_version()
        require_minimum_pyarrow_version()

        from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
        import pyarrow as pa

        # Create the Spark schema from the list of names passed in, using Arrow types.
        if isinstance(schema, (list, tuple)):
            arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
            struct = StructType()
            for name, field in zip(schema, arrow_schema):
                struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
            schema = struct

        # Determine the Arrow types to coerce the data to when creating batches.
        if isinstance(schema, StructType):
            arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
        elif isinstance(schema, DataType):
            raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
        else:
            # Any timestamps must be coerced to be compatible with Spark.
            arrow_types = [to_arrow_type(TimestampType())
                           if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
                           for t in pdf.dtypes]

        # Slice the DataFrame to be batched: a ceiling division, so there are at most
        # `defaultParallelism` slices (see the note after this method).
        step = -(-len(pdf) // self.sparkContext.defaultParallelism)
        pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))

        # Create a list of Arrow (column, type) pairs for the serializer's dump_stream.
        arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]
                      for pdf_slice in pdf_slices]

        jsqlContext = self._wrapped._jsqlContext

        safecheck = self._wrapped._conf.arrowSafeTypeConversion()
        col_by_name = True  # column-by-name only applies to StructType columns; none here
        ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)

        def reader_func(temp_filename):
            return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)

        def create_RDD_server():
            return self._jvm.ArrowRDDServer(jsqlContext)

        # Create a Spark DataFrame from the Arrow stream file, using one batch per partition.
        jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
        jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
        df = DataFrame(jdf, self._wrapped)
        df._schema = schema
        return df
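
    # For illustration (hedged): `-(-n // p)` is a ceiling division, so every row lands in
    # one of at most `defaultParallelism` slices, e.g. len(pdf) == 7 with parallelism 2:
    #
    #   step = -(-7 // 2)    # == 4
    #   [(0, 4), (4, 7)]     # resulting slice bounds from xrange(0, 7, 4)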


def _test():
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.pandas.conversion
    globs = pyspark.sql.pandas.conversion.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("sql.pandas.conversion tests")\
        .getOrCreate()
    globs['spark'] = spark
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.pandas.conversion, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()