Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 import sys
0018 import warnings
# Python 2/3 compatibility aliases.
# Compare the numeric major version instead of the free-form ``sys.version`` string:
# string comparison is lexicographic and would misorder a major version >= 10.
if sys.version_info[0] >= 3:
    # Python 3 has a single text type and a lazy ``range``.
    basestring = unicode = str
    xrange = range
else:
    # Python 2: make ``zip`` lazy so it behaves like the Python 3 builtin.
    from itertools import izip as zip
0024 from collections import Counter
0025 
0026 from pyspark import since
0027 from pyspark.rdd import _load_from_socket
0028 from pyspark.sql.pandas.serializers import ArrowCollectSerializer
0029 from pyspark.sql.types import IntegralType
0030 from pyspark.sql.types import *
0031 from pyspark.traceback_utils import SCCallSiteSync
0032 from pyspark.util import _exception_message
0033 
0034 
class PandasConversionMixin(object):
    """
    Mix-in for the conversion from Spark to pandas. Currently, only :class:`DataFrame`
    can use this class.
    """

    @since(1.3)
    def toPandas(self):
        """
        Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.

        This is only available if Pandas is installed and available.

        When 'spark.sql.execution.arrow.pyspark.enabled' is set, the Arrow based collection
        is attempted first. If the schema or the PyArrow installation is not usable, the
        method either falls back to the plain collection path or re-raises, depending on
        'spark.sql.execution.arrow.pyspark.fallback.enabled'. Errors raised once the Arrow
        collection itself has started are always re-raised.

        .. note:: This method should only be used if the resulting Pandas's :class:`DataFrame` is
            expected to be small, as all the data is loaded into the driver's memory.

        .. note:: Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental.

        >>> df.toPandas()  # doctest: +SKIP
           age   name
        0    2  Alice
        1    5    Bob
        """
        from pyspark.sql.dataframe import DataFrame

        assert isinstance(self, DataFrame)

        from pyspark.sql.pandas.utils import require_minimum_pandas_version
        require_minimum_pandas_version()

        import numpy as np
        import pandas as pd

        timezone = self.sql_ctx._conf.sessionLocalTimeZone()

        if self.sql_ctx._conf.arrowPySparkEnabled():
            use_arrow = True
            try:
                from pyspark.sql.pandas.types import to_arrow_schema
                from pyspark.sql.pandas.utils import require_minimum_pyarrow_version

                # Both calls are used purely as validation: they raise when PyArrow is
                # missing/too old or when the schema cannot be expressed in Arrow.
                require_minimum_pyarrow_version()
                to_arrow_schema(self.schema)
            except Exception as e:

                if self.sql_ctx._conf.arrowPySparkFallbackEnabled():
                    msg = (
                        "toPandas attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
                        "failed by the reason below:\n  %s\n"
                        "Attempting non-optimization as "
                        "'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
                        "true." % _exception_message(e))
                    warnings.warn(msg)
                    use_arrow = False
                else:
                    msg = (
                        "toPandas attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
                        "reached the error below and will not continue because automatic fallback "
                        "with 'spark.sql.execution.arrow.pyspark.fallback.enabled' has been set to "
                        "false.\n  %s" % _exception_message(e))
                    warnings.warn(msg)
                    raise

            # Try to use Arrow optimization when the schema is supported and the required version
            # of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled.
            if use_arrow:
                try:
                    from pyspark.sql.pandas.types import _check_series_localize_timestamps
                    import pyarrow
                    # Rename columns to avoid duplicated column names.
                    tmp_column_names = ['col_{}'.format(i) for i in range(len(self.columns))]
                    batches = self.toDF(*tmp_column_names)._collect_as_arrow()
                    if len(batches) > 0:
                        table = pyarrow.Table.from_batches(batches)
                        # Pandas DataFrame created from PyArrow uses datetime64[ns] for date type
                        # values, but we should use datetime.date to match the behavior with when
                        # Arrow optimization is disabled.
                        pdf = table.to_pandas(date_as_object=True)
                        # Rename back to the original column names.
                        pdf.columns = self.columns
                        for field in self.schema:
                            if isinstance(field.dataType, TimestampType):
                                pdf[field.name] = \
                                    _check_series_localize_timestamps(pdf[field.name], timezone)
                        return pdf
                    else:
                        # No batches were collected: return an empty frame that still
                        # carries the column names.
                        return pd.DataFrame.from_records([], columns=self.columns)
                except Exception as e:
                    # We might have to allow fallback here as well but multiple Spark jobs can
                    # be executed. So, simply fail in this case for now.
                    msg = (
                        "toPandas attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
                        "reached the error below and can not continue. Note that "
                        "'spark.sql.execution.arrow.pyspark.fallback.enabled' does not have an "
                        "effect on failures in the middle of "
                        "computation.\n  %s" % _exception_message(e))
                    warnings.warn(msg)
                    raise

        # Below is toPandas without Arrow optimization.
        pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
        column_counter = Counter(self.columns)

        # First pass: decide, per column position, which dtype (if any) to cast to.
        dtype = [None] * len(self.schema)
        for fieldIdx, field in enumerate(self.schema):
            # For duplicate column name, we use `iloc` to access it.
            if column_counter[field.name] > 1:
                pandas_col = pdf.iloc[:, fieldIdx]
            else:
                pandas_col = pdf[field.name]

            spark_type = field.dataType
            is_integral = isinstance(spark_type, IntegralType)
            pandas_type = PandasConversionMixin._to_corrected_pandas_type(spark_type)
            # SPARK-21766: if an integer field is nullable and has null values, it can be
            # inferred by pandas as float column. Once we convert the column with NaN back
            # to integer type e.g., np.int16, we will hit exception. So we use the inferred
            # float type, not the corrected type from the schema in this case.
            if pandas_type is not None and \
                not (is_integral and field.nullable and
                     pandas_col.isnull().any()):
                dtype[fieldIdx] = pandas_type
            # Ensure we fall back to nullable numpy types, even when whole column is null:
            if is_integral and pandas_col.isnull().any():
                dtype[fieldIdx] = np.float64
            if isinstance(spark_type, BooleanType) and pandas_col.isnull().any():
                # Use the builtin ``object``: the ``np.object`` alias was deprecated in
                # NumPy 1.20 and removed in 1.24, and it only ever pointed at ``object``.
                dtype[fieldIdx] = object

        # Second pass: build the result frame column by column, applying the casts.
        df = pd.DataFrame()
        for index, t in enumerate(dtype):
            column_name = self.schema[index].name

            # For duplicate column name, we use `iloc` to access it.
            if column_counter[column_name] > 1:
                series = pdf.iloc[:, index]
            else:
                series = pdf[column_name]

            if t is not None:
                series = series.astype(t, copy=False)

            # `insert` API makes copy of data, we only do it for Series of duplicate column names.
            # `pdf.iloc[:, index] = pdf.iloc[:, index]...` doesn't always work because `iloc` could
            # return a view or a copy depending by context.
            if column_counter[column_name] > 1:
                df.insert(index, column_name, series, allow_duplicates=True)
            else:
                df[column_name] = series

        pdf = df

        if timezone is None:
            return pdf
        else:
            from pyspark.sql.pandas.types import _check_series_convert_timestamps_local_tz
            for field in self.schema:
                # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
                if isinstance(field.dataType, TimestampType):
                    pdf[field.name] = \
                        _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
            return pdf

    @staticmethod
    def _to_corrected_pandas_type(dt):
        """
        When converting Spark SQL records to Pandas :class:`DataFrame`, the inferred data type
        may be wrong. This method gets the corrected data type for Pandas if that type may be
        inferred incorrectly.

        :param dt: a Spark SQL data type instance
        :return: the dtype to cast the pandas column to, or None when the exact type of
            ``dt`` has no entry in the correction table
        """
        import numpy as np
        # Keyed on the exact class (not ``isinstance``), so subclasses of these types
        # are deliberately not matched.
        corrected_types = {
            ByteType: np.int8,
            ShortType: np.int16,
            IntegerType: np.int32,
            LongType: np.int64,
            FloatType: np.float32,
            DoubleType: np.float64,
            # Builtin ``bool`` instead of ``np.bool``: that alias was deprecated in
            # NumPy 1.20 and removed in 1.24, and it only ever pointed at ``bool``.
            BooleanType: bool,
            TimestampType: np.datetime64,
        }
        return corrected_types.get(type(dt))

    def _collect_as_arrow(self):
        """
        Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
        and available on driver and worker Python environments.

        The batches are read from the socket in arbitrary order, followed by one final
        element listing the correct order; the returned list is re-ordered accordingly.

        .. note:: Experimental.
        """
        from pyspark.sql.dataframe import DataFrame

        assert isinstance(self, DataFrame)

        with SCCallSiteSync(self._sc):
            port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()

        # Collect list of un-ordered batches where last element is a list of correct order indices
        try:
            results = list(_load_from_socket((port, auth_secret), ArrowCollectSerializer()))
        finally:
            # Join serving thread and raise any exceptions from collectAsArrowToPython
            jsocket_auth_server.getResult()

        # Separate RecordBatches from batch order indices in results
        batches = results[:-1]
        batch_order = results[-1]

        # Re-order the batch list using the correct order
        return [batches[i] for i in batch_order]
0252 
0253 
class SparkConversionMixin(object):
    """
    Mix-in for the conversion from pandas to Spark. Currently, only :class:`SparkSession`
    can use this class.
    """
    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
        """
        Create a Spark DataFrame from a ``pandas.DataFrame``.

        When 'spark.sql.execution.arrow.pyspark.enabled' is set and ``data`` is not empty,
        the Arrow based conversion is attempted first. If it fails, the method either
        falls back to the plain conversion or re-raises, depending on
        'spark.sql.execution.arrow.pyspark.fallback.enabled'.

        :param data: the pandas DataFrame to convert
        :param schema: schema for the result; when None, the column names of ``data``
            (converted to strings) are used
        :param samplingRatio: forwarded unchanged to ``_create_dataframe``
        :param verifySchema: forwarded unchanged to ``_create_dataframe``
        :return: the result of the Arrow conversion, or of ``_create_dataframe``
        """
        from pyspark.sql import SparkSession

        assert isinstance(self, SparkSession)

        from pyspark.sql.pandas.utils import require_minimum_pandas_version
        require_minimum_pandas_version()

        timezone = self._wrapped._conf.sessionLocalTimeZone()

        # If no schema supplied by user then get the names of columns only
        if schema is None:
            schema = [str(x) if not isinstance(x, basestring) else
                      (x.encode('utf-8') if not isinstance(x, str) else x)
                      for x in data.columns]

        if self._wrapped._conf.arrowPySparkEnabled() and len(data) > 0:
            try:
                return self._create_from_pandas_with_arrow(data, schema, timezone)
            except Exception as e:
                from pyspark.util import _exception_message

                if self._wrapped._conf.arrowPySparkFallbackEnabled():
                    msg = (
                        "createDataFrame attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
                        "failed by the reason below:\n  %s\n"
                        "Attempting non-optimization as "
                        "'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
                        "true." % _exception_message(e))
                    warnings.warn(msg)
                else:
                    msg = (
                        "createDataFrame attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
                        "reached the error below and will not continue because automatic "
                        "fallback with 'spark.sql.execution.arrow.pyspark.fallback.enabled' "
                        "has been set to false.\n  %s" % _exception_message(e))
                    warnings.warn(msg)
                    raise
        # Non-Arrow path, also reached after a tolerated Arrow failure above.
        data = self._convert_from_pandas(data, schema, timezone)
        return self._create_dataframe(data, schema, samplingRatio, verifySchema)

    def _convert_from_pandas(self, pdf, schema, timezone):
        """
         Convert a pandas.DataFrame to list of records that can be used to make a DataFrame

         :param pdf: the pandas DataFrame to convert; it is copied before any column is
             replaced, so the caller's frame is not modified
         :param schema: a StructType (only its timestamp fields are converted) or any
             other value (every column is passed through the timestamp conversion)
         :param timezone: session time zone; when None no timestamp conversion is done
         :return list of records
        """
        from pyspark.sql import SparkSession

        assert isinstance(self, SparkSession)

        if timezone is not None:
            from pyspark.sql.pandas.types import _check_series_convert_timestamps_tz_local
            copied = False
            if isinstance(schema, StructType):
                for field in schema:
                    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
                    if isinstance(field.dataType, TimestampType):
                        s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone)
                        if s is not pdf[field.name]:
                            if not copied:
                                # Copy once if the series is modified to prevent the original
                                # Pandas DataFrame from being updated
                                pdf = pdf.copy()
                                copied = True
                            pdf[field.name] = s
            else:
                # `items()` rather than `iteritems()`: the latter was removed in pandas 2.0.
                for column, series in pdf.items():
                    s = _check_series_convert_timestamps_tz_local(series, timezone)
                    if s is not series:
                        if not copied:
                            # Copy once if the series is modified to prevent the original
                            # Pandas DataFrame from being updated
                            pdf = pdf.copy()
                            copied = True
                        pdf[column] = s

        # Convert pandas.DataFrame to list of numpy records
        np_records = pdf.to_records(index=False)

        # Check if any columns need to be fixed for Spark to infer properly
        if len(np_records) > 0:
            record_dtype = self._get_numpy_record_dtype(np_records[0])
            if record_dtype is not None:
                return [r.astype(record_dtype).tolist() for r in np_records]

        # Convert list of numpy records to python lists
        return [r.tolist() for r in np_records]

    def _get_numpy_record_dtype(self, rec):
        """
        Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
        the dtypes of fields in a record so they can be properly loaded into Spark.
        :param rec: a numpy record to check field dtypes
        :return corrected dtype for a numpy.record or None if no correction needed
        """
        import numpy as np
        cur_dtypes = rec.dtype
        nanosecond_timestamp = np.dtype('datetime64[ns]')
        record_type_list = []
        has_rec_fix = False
        # Fields are looked up by position, paired with their names.
        for i, col_name in enumerate(cur_dtypes.names):
            curr_type = cur_dtypes[i]
            # If type is a datetime64 timestamp, convert to microseconds
            # NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs,
            # conversion from [us] or lower will lead to py datetime objects, see SPARK-22417
            if curr_type == nanosecond_timestamp:
                curr_type = 'datetime64[us]'
                has_rec_fix = True
            record_type_list.append((str(col_name), curr_type))
        return np.dtype(record_type_list) if has_rec_fix else None

    def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
        """
        Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
        to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
        data types will be used to coerce the data in Pandas to Arrow conversion.

        :param pdf: the pandas DataFrame to convert; the only caller in this class passes
            a non-empty frame
        :param schema: a list/tuple of column names, a StructType, or another value
        :param timezone: session time zone handed to the Arrow serializer
        :return: a Spark DataFrame backed by the Arrow data
        :raises ValueError: if ``schema`` is a single DataType that is not a StructType
        """
        from pyspark.sql import SparkSession
        from pyspark.sql.dataframe import DataFrame

        assert isinstance(self, SparkSession)

        from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
        from pyspark.sql.types import TimestampType
        from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type
        from pyspark.sql.pandas.utils import require_minimum_pandas_version, \
            require_minimum_pyarrow_version

        require_minimum_pandas_version()
        require_minimum_pyarrow_version()

        from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
        import pyarrow as pa

        # Create the Spark schema from list of names passed in with Arrow types
        if isinstance(schema, (list, tuple)):
            arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
            struct = StructType()
            for name, field in zip(schema, arrow_schema):
                struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
            schema = struct

        # Determine arrow types to coerce data when creating batches
        if isinstance(schema, StructType):
            arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
        elif isinstance(schema, DataType):
            raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
        else:
            # Any timestamps must be coerced to be compatible with Spark
            arrow_types = [to_arrow_type(TimestampType())
                           if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
                           for t in pdf.dtypes]

        # Slice the DataFrame to be batched
        step = -(-len(pdf) // self.sparkContext.defaultParallelism)  # round int up
        pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))

        # Create list of Arrow (columns, type) for serializer dump_stream
        # `items()` rather than `iteritems()`: the latter was removed in pandas 2.0.
        arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.items(), arrow_types)]
                      for pdf_slice in pdf_slices]

        jsqlContext = self._wrapped._jsqlContext

        safecheck = self._wrapped._conf.arrowSafeTypeConversion()
        col_by_name = True  # col by name only applies to StructType columns, can't happen here
        ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)

        def reader_func(temp_filename):
            return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)

        def create_RDD_server():
            return self._jvm.ArrowRDDServer(jsqlContext)

        # Create Spark DataFrame from Arrow stream file, using one batch per partition
        jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
        jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
        df = DataFrame(jdf, self._wrapped)
        # Record the schema already known on the Python side on the new DataFrame.
        df._schema = schema
        return df
0440 
0441 
def _test():
    """
    Run the doctests of ``pyspark.sql.pandas.conversion`` with a local SparkSession
    exposed to them as ``spark``, and exit the process with status -1 if any fail.
    """
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.pandas.conversion

    module_under_test = pyspark.sql.pandas.conversion
    # Work on a copy so the module's own namespace is not modified.
    doctest_globals = module_under_test.__dict__.copy()

    session = (
        SparkSession.builder
        .master("local[4]")
        .appName("sql.pandas.conversion tests")
        .getOrCreate()
    )
    doctest_globals['spark'] = session

    flags = doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF
    outcome = doctest.testmod(module_under_test, globs=doctest_globals, optionflags=flags)
    failed = outcome[0]

    # Stop the session before deciding on the exit status.
    session.stop()
    if failed:
        sys.exit(-1)


if __name__ == "__main__":
    _test()