0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import sys
0019 import random
0020
0021 if sys.version >= '3':
0022 basestring = unicode = str
0023 long = int
0024 from functools import reduce
0025 from html import escape as html_escape
0026 else:
0027 from itertools import imap as map
0028 from cgi import escape as html_escape
0029
0030 import warnings
0031
0032 from pyspark import copy_func, since, _NoValue
0033 from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket, \
0034 ignore_unicode_prefix
0035 from pyspark.serializers import BatchedSerializer, PickleSerializer, \
0036 UTF8Deserializer
0037 from pyspark.storagelevel import StorageLevel
0038 from pyspark.traceback_utils import SCCallSiteSync
0039 from pyspark.sql.types import _parse_datatype_json_string
0040 from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
0041 from pyspark.sql.readwriter import DataFrameWriter
0042 from pyspark.sql.streaming import DataStreamWriter
0043 from pyspark.sql.types import *
0044 from pyspark.sql.pandas.conversion import PandasConversionMixin
0045 from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
0046
0047 __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"]
0048
0049
0050 class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
0051 """A distributed collection of data grouped into named columns.
0052
0053 A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
0054 and can be created using various functions in :class:`SparkSession`::
0055
0056 people = spark.read.parquet("...")
0057
0058 Once created, it can be manipulated using the various domain-specific-language
0059 (DSL) functions defined in: :class:`DataFrame`, :class:`Column`.
0060
0061 To select a column from the :class:`DataFrame`, use the apply method::
0062
0063 ageCol = people.age
0064
0065 A more concrete example::
0066
0067 # To create DataFrame using SparkSession
0068 people = spark.read.parquet("...")
0069 department = spark.read.parquet("...")
0070
0071 people.filter(people.age > 30).join(department, people.deptId == department.id) \\
0072 .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
0073
0074 .. versionadded:: 1.3
0075 """
0076
0077 def __init__(self, jdf, sql_ctx):
0078 self._jdf = jdf
0079 self.sql_ctx = sql_ctx
0080 self._sc = sql_ctx and sql_ctx._sc
0081 self.is_cached = False
0082 self._schema = None
0083 self._lazy_rdd = None
0084 # Track whether _repr_html_ is supported; this avoids calling _jdf twice
0085 # when both __repr__ and _repr_html_ run under eager evaluation.
0086 self._support_repr_html = False
0087
0088 @property
0089 @since(1.3)
0090 def rdd(self):
0091 """Returns the content as an :class:`pyspark.RDD` of :class:`Row`.
0092 """
0093 if self._lazy_rdd is None:
0094 jrdd = self._jdf.javaToPython()
0095 self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
0096 return self._lazy_rdd
0097
0098 @property
0099 @since("1.3.1")
0100 def na(self):
0101 """Returns a :class:`DataFrameNaFunctions` for handling missing values.
0102 """
0103 return DataFrameNaFunctions(self)
0104
0105 @property
0106 @since(1.4)
0107 def stat(self):
0108 """Returns a :class:`DataFrameStatFunctions` for statistic functions.
0109 """
0110 return DataFrameStatFunctions(self)
0111
0112 @ignore_unicode_prefix
0113 @since(1.3)
0114 def toJSON(self, use_unicode=True):
0115 """Converts a :class:`DataFrame` into an :class:`RDD` of strings.
0116
0117 Each row is turned into a JSON document as one element in the returned RDD.
0118
0119 >>> df.toJSON().first()
0120 u'{"age":2,"name":"Alice"}'
0121 """
0122 rdd = self._jdf.toJSON()
0123 return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
0124
0125 @since(1.3)
0126 def registerTempTable(self, name):
0127 """Registers this DataFrame as a temporary table using the given name.
0128
0129 The lifetime of this temporary table is tied to the :class:`SparkSession`
0130 that was used to create this :class:`DataFrame`.
0131
0132 >>> df.registerTempTable("people")
0133 >>> df2 = spark.sql("select * from people")
0134 >>> sorted(df.collect()) == sorted(df2.collect())
0135 True
0136 >>> spark.catalog.dropTempView("people")
0137
0138 .. note:: Deprecated in 2.0, use createOrReplaceTempView instead.
0139 """
0140 warnings.warn(
0141 "Deprecated in 2.0, use createOrReplaceTempView instead.", DeprecationWarning)
0142 self._jdf.createOrReplaceTempView(name)
0143
0144 @since(2.0)
0145 def createTempView(self, name):
0146 """Creates a local temporary view with this :class:`DataFrame`.
0147
0148 The lifetime of this temporary view is tied to the :class:`SparkSession`
0149 that was used to create this :class:`DataFrame`.
0150 Throws :class:`TempTableAlreadyExistsException` if the view name already exists in the
0151 catalog.
0152
0153 >>> df.createTempView("people")
0154 >>> df2 = spark.sql("select * from people")
0155 >>> sorted(df.collect()) == sorted(df2.collect())
0156 True
0157 >>> df.createTempView("people") # doctest: +IGNORE_EXCEPTION_DETAIL
0158 Traceback (most recent call last):
0159 ...
0160 AnalysisException: u"Temporary table 'people' already exists;"
0161 >>> spark.catalog.dropTempView("people")
0162
0163 """
0164 self._jdf.createTempView(name)
0165
0166 @since(2.0)
0167 def createOrReplaceTempView(self, name):
0168 """Creates or replaces a local temporary view with this :class:`DataFrame`.
0169
0170 The lifetime of this temporary view is tied to the :class:`SparkSession`
0171 that was used to create this :class:`DataFrame`.
0172
0173 >>> df.createOrReplaceTempView("people")
0174 >>> df2 = df.filter(df.age > 3)
0175 >>> df2.createOrReplaceTempView("people")
0176 >>> df3 = spark.sql("select * from people")
0177 >>> sorted(df3.collect()) == sorted(df2.collect())
0178 True
0179 >>> spark.catalog.dropTempView("people")
0180
0181 """
0182 self._jdf.createOrReplaceTempView(name)
0183
0184 @since(2.1)
0185 def createGlobalTempView(self, name):
0186 """Creates a global temporary view with this :class:`DataFrame`.
0187
0188 The lifetime of this temporary view is tied to this Spark application.
0189 Throws :class:`TempTableAlreadyExistsException` if the view name already exists in the
0190 catalog.
0191
0192 >>> df.createGlobalTempView("people")
0193 >>> df2 = spark.sql("select * from global_temp.people")
0194 >>> sorted(df.collect()) == sorted(df2.collect())
0195 True
0196 >>> df.createGlobalTempView("people") # doctest: +IGNORE_EXCEPTION_DETAIL
0197 Traceback (most recent call last):
0198 ...
0199 AnalysisException: u"Temporary table 'people' already exists;"
0200 >>> spark.catalog.dropGlobalTempView("people")
0201
0202 """
0203 self._jdf.createGlobalTempView(name)
0204
0205 @since(2.2)
0206 def createOrReplaceGlobalTempView(self, name):
0207 """Creates or replaces a global temporary view using the given name.
0208
0209 The lifetime of this temporary view is tied to this Spark application.
0210
0211 >>> df.createOrReplaceGlobalTempView("people")
0212 >>> df2 = df.filter(df.age > 3)
0213 >>> df2.createOrReplaceGlobalTempView("people")
0214 >>> df3 = spark.sql("select * from global_temp.people")
0215 >>> sorted(df3.collect()) == sorted(df2.collect())
0216 True
0217 >>> spark.catalog.dropGlobalTempView("people")
0218
0219 """
0220 self._jdf.createOrReplaceGlobalTempView(name)
0221
0222 @property
0223 @since(1.4)
0224 def write(self):
0225 """
0226 Interface for saving the content of the non-streaming :class:`DataFrame` out into external
0227 storage.
0228
0229 :return: :class:`DataFrameWriter`
0230 """
0231 return DataFrameWriter(self)
0232
0233 @property
0234 @since(2.0)
0235 def writeStream(self):
0236 """
0237 Interface for saving the content of the streaming :class:`DataFrame` out into external
0238 storage.
0239
0240 .. note:: Evolving.
0241
0242 :return: :class:`DataStreamWriter`
0243 """
0244 return DataStreamWriter(self)
0245
0246 @property
0247 @since(1.3)
0248 def schema(self):
0249 """Returns the schema of this :class:`DataFrame` as a :class:`pyspark.sql.types.StructType`.
0250
0251 >>> df.schema
0252 StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
0253 """
0254 if self._schema is None:
0255 try:
0256 self._schema = _parse_datatype_json_string(self._jdf.schema().json())
0257 except AttributeError as e:
0258 raise Exception(
0259 "Unable to parse datatype from schema. %s" % e)
0260 return self._schema
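
# A small illustrative sketch (not part of the original source) of inspecting the
# parsed schema returned above; `spark` is assumed to be an active SparkSession:
#
#     df = spark.createDataFrame([(2, "Alice")], ["age", "name"])
#     df.schema.fieldNames()                     # ['age', 'name']
#     [f.dataType for f in df.schema.fields]     # [LongType, StringType]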
0261
0262 @since(1.3)
0263 def printSchema(self):
0264 """Prints out the schema in the tree format.
0265
0266 >>> df.printSchema()
0267 root
0268 |-- age: integer (nullable = true)
0269 |-- name: string (nullable = true)
0270 <BLANKLINE>
0271 """
0272 print(self._jdf.schema().treeString())
0273
0274 @since(1.3)
0275 def explain(self, extended=None, mode=None):
0276 """Prints the (logical and physical) plans to the console for debugging purpose.
0277
0278 :param extended: boolean, default ``False``. If ``False``, prints only the physical plan.
0279 When this is a string and ``mode`` is not specified, the string is treated
0280 as the value of ``mode``.
0281 :param mode: specifies the expected output format of plans.
0282
0283 * ``simple``: Print only a physical plan.
0284 * ``extended``: Print both logical and physical plans.
0285 * ``codegen``: Print a physical plan and generated codes if they are available.
0286 * ``cost``: Print a logical plan and statistics if they are available.
0287 * ``formatted``: Split explain output into two sections: a physical plan outline \
0288 and node details.
0289
0290 >>> df.explain()
0291 == Physical Plan ==
0292 *(1) Scan ExistingRDD[age#0,name#1]
0293
0294 >>> df.explain(True)
0295 == Parsed Logical Plan ==
0296 ...
0297 == Analyzed Logical Plan ==
0298 ...
0299 == Optimized Logical Plan ==
0300 ...
0301 == Physical Plan ==
0302 ...
0303
0304 >>> df.explain(mode="formatted")
0305 == Physical Plan ==
0306 * Scan ExistingRDD (1)
0307 (1) Scan ExistingRDD [codegen id : 1]
0308 Output [2]: [age#0, name#1]
0309 ...
0310
0311 >>> df.explain("cost")
0312 == Optimized Logical Plan ==
0313 ...Statistics...
0314 ...
0315
0316 .. versionchanged:: 3.0.0
0317 Added optional argument `mode` to specify the expected output format of plans.
0318 """
0319
0320 if extended is not None and mode is not None:
0321 raise Exception("extended and mode should not be set together.")
0322
0323 # For the no-argument case: df.explain()
0324 is_no_argument = extended is None and mode is None
0325
0326 # For the case when extended is given as a bool:
0327 #   df.explain(True)
0328 #   df.explain(extended=False)
0329 is_extended_case = isinstance(extended, bool) and mode is None
0330
0331 # For the case when extended is used as the mode string:
0332 #   df.explain("formatted")
0333 is_extended_as_mode = isinstance(extended, basestring) and mode is None
0334
0335 # For the case when mode is specified explicitly:
0336 #   df.explain(mode="formatted")
0337 is_mode_case = extended is None and isinstance(mode, basestring)
0338
0339 if not (is_no_argument or is_extended_case or is_extended_as_mode or is_mode_case):
0340 argtypes = [
0341 str(type(arg)) for arg in [extended, mode] if arg is not None]
0342 raise TypeError(
0343 "extended (optional) and mode (optional) should be a bool "
0344 "and string; however, got [%s]." % ", ".join(argtypes))
0345
0346 # Pick the effective explain mode from the argument combination detected above.
0347 if is_no_argument:
0348 explain_mode = "simple"
0349 elif is_extended_case:
0350 explain_mode = "extended" if extended else "simple"
0351 elif is_mode_case:
0352 explain_mode = mode
0353 elif is_extended_as_mode:
0354 explain_mode = extended
0355
0356 print(self._sc._jvm.PythonSQLUtils.explainString(self._jdf.queryExecution(), explain_mode))
0357
0358 @since(2.4)
0359 def exceptAll(self, other):
0360 """Return a new :class:`DataFrame` containing rows in this :class:`DataFrame` but
0361 not in another :class:`DataFrame` while preserving duplicates.
0362
0363 This is equivalent to `EXCEPT ALL` in SQL.
0364
0365 >>> df1 = spark.createDataFrame(
0366 ... [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"])
0367 >>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"])
0368
0369 >>> df1.exceptAll(df2).show()
0370 +---+---+
0371 | C1| C2|
0372 +---+---+
0373 | a| 1|
0374 | a| 1|
0375 | a| 2|
0376 | c| 4|
0377 +---+---+
0378
0379 Also as standard in SQL, this function resolves columns by position (not by name).
0380 """
0381 return DataFrame(self._jdf.exceptAll(other._jdf), self.sql_ctx)
0382
0383 @since(1.3)
0384 def isLocal(self):
0385 """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally
0386 (without any Spark executors).
0387 """
0388 return self._jdf.isLocal()
0389
0390 @property
0391 @since(2.0)
0392 def isStreaming(self):
0393 """Returns ``True`` if this :class:`Dataset` contains one or more sources that continuously
0394 return data as it arrives. A :class:`Dataset` that reads data from a streaming source
0395 must be executed as a :class:`StreamingQuery` using the :func:`start` method in
0396 :class:`DataStreamWriter`. Methods that return a single answer (e.g., :func:`count` or
0397 :func:`collect`) will throw an :class:`AnalysisException` when there is a streaming
0398 source present.
0399
0400 .. note:: Evolving
0401 """
0402 return self._jdf.isStreaming()
0403
0404 @since(1.3)
0405 def show(self, n=20, truncate=True, vertical=False):
0406 """Prints the first ``n`` rows to the console.
0407
0408 :param n: Number of rows to show.
0409 :param truncate: If set to ``True``, truncate strings longer than 20 chars by default.
0410 If set to a number greater than one, truncates long strings to length ``truncate``
0411 and aligns cells right.
0412 :param vertical: If set to ``True``, print output rows vertically (one line
0413 per column value).
0414
0415 >>> df
0416 DataFrame[age: int, name: string]
0417 >>> df.show()
0418 +---+-----+
0419 |age| name|
0420 +---+-----+
0421 | 2|Alice|
0422 | 5| Bob|
0423 +---+-----+
0424 >>> df.show(truncate=3)
0425 +---+----+
0426 |age|name|
0427 +---+----+
0428 | 2| Ali|
0429 | 5| Bob|
0430 +---+----+
0431 >>> df.show(vertical=True)
0432 -RECORD 0-----
0433 age | 2
0434 name | Alice
0435 -RECORD 1-----
0436 age | 5
0437 name | Bob
0438 """
0439 if isinstance(truncate, bool) and truncate:
0440 print(self._jdf.showString(n, 20, vertical))
0441 else:
0442 print(self._jdf.showString(n, int(truncate), vertical))
0443
0444 def __repr__(self):
0445 if not self._support_repr_html and self.sql_ctx._conf.isReplEagerEvalEnabled():
0446 vertical = False
0447 return self._jdf.showString(
0448 self.sql_ctx._conf.replEagerEvalMaxNumRows(),
0449 self.sql_ctx._conf.replEagerEvalTruncate(), vertical)
0450 else:
0451 return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
0452
0453 def _repr_html_(self):
0454 """Returns the contents of this :class:`DataFrame` as HTML when eager evaluation
0455 is enabled via 'spark.sql.repl.eagerEval.enabled'. This is only called by REPLs
0456 that support eager evaluation with HTML rendering.
0457 """
0458 if not self._support_repr_html:
0459 self._support_repr_html = True
0460 if self.sql_ctx._conf.isReplEagerEvalEnabled():
0461 max_num_rows = max(self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0)
0462 sock_info = self._jdf.getRowsToPython(
0463 max_num_rows, self.sql_ctx._conf.replEagerEvalTruncate())
0464 rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
0465 head = rows[0]
0466 row_data = rows[1:]
0467 has_more_data = len(row_data) > max_num_rows
0468 row_data = row_data[:max_num_rows]
0469
0470 html = "<table border='1'>\n"
0471
0472 html += "<tr><th>%s</th></tr>\n" % "</th><th>".join(map(lambda x: html_escape(x), head))
0473
0474 for row in row_data:
0475 html += "<tr><td>%s</td></tr>\n" % "</td><td>".join(
0476 map(lambda x: html_escape(x), row))
0477 html += "</table>\n"
0478 if has_more_data:
0479 html += "only showing top %d %s\n" % (
0480 max_num_rows, "row" if max_num_rows == 1 else "rows")
0481 return html
0482 else:
0483 return None
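
# A hedged sketch of how the eager-evaluation path above is typically exercised
# from a notebook/REPL (the configuration keys are real, the rest is illustrative
# and not part of the original source):
#
#     spark.conf.set("spark.sql.repl.eagerEval.enabled", True)
#     spark.conf.set("spark.sql.repl.eagerEval.maxNumRows", 20)
#     df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])
#     df   # a front end that supports _repr_html_ renders an HTML table here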
0484
0485 @since(2.1)
0486 def checkpoint(self, eager=True):
0487 """Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the
0488 logical plan of this :class:`DataFrame`, which is especially useful in iterative algorithms
0489 where the plan may grow exponentially. It will be saved to files inside the checkpoint
0490 directory set with :meth:`SparkContext.setCheckpointDir`.
0491
0492 :param eager: Whether to checkpoint this :class:`DataFrame` immediately
0493
0494 .. note:: Experimental
0495 """
0496 jdf = self._jdf.checkpoint(eager)
0497 return DataFrame(jdf, self.sql_ctx)
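
# A hedged usage sketch (not part of the original source): a checkpoint directory
# must be set on the SparkContext first; the path below is illustrative only.
#
#     spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")
#     df = spark.range(10)
#     checkpointed = df.checkpoint()   # eager=True materializes the data now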
0498
0499 @since(2.3)
0500 def localCheckpoint(self, eager=True):
0501 """Returns a locally checkpointed version of this Dataset. Checkpointing can be used to
0502 truncate the logical plan of this :class:`DataFrame`, which is especially useful in
0503 iterative algorithms where the plan may grow exponentially. Local checkpoints are
0504 stored in the executors using the caching subsystem and therefore they are not reliable.
0505
0506 :param eager: Whether to checkpoint this :class:`DataFrame` immediately
0507
0508 .. note:: Experimental
0509 """
0510 jdf = self._jdf.localCheckpoint(eager)
0511 return DataFrame(jdf, self.sql_ctx)
0512
0513 @since(2.1)
0514 def withWatermark(self, eventTime, delayThreshold):
0515 """Defines an event time watermark for this :class:`DataFrame`. A watermark tracks a point
0516 in time before which we assume no more late data is going to arrive.
0517
0518 Spark will use this watermark for several purposes:
0519 - To know when a given time window aggregation can be finalized and thus can be emitted
0520 when using output modes that do not allow updates.
0521
0522 - To minimize the amount of state that we need to keep for on-going aggregations.
0523
0524 The current watermark is computed by looking at the `MAX(eventTime)` seen across
0525 all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost
0526 of coordinating this value across partitions, the actual watermark used is only guaranteed
0527 to be at least `delayThreshold` behind the actual event time. In some cases we may still
0528 process records that arrive more than `delayThreshold` late.
0529
0530 :param eventTime: the name of the column that contains the event time of the row.
0531 :param delayThreshold: the minimum delay to wait for late data to arrive, relative to the
0532 latest record that has been processed in the form of an interval
0533 (e.g. "1 minute" or "5 hours").
0534
0535 .. note:: Evolving
0536
0537 >>> sdf.select('name', sdf.time.cast('timestamp')).withWatermark('time', '10 minutes')
0538 DataFrame[name: string, time: timestamp]
0539 """
0540 if not eventTime or type(eventTime) is not str:
0541 raise TypeError("eventTime should be provided as a string")
0542 if not delayThreshold or type(delayThreshold) is not str:
0543 raise TypeError("delayThreshold should be provided as a string interval")
0544 jdf = self._jdf.withWatermark(eventTime, delayThreshold)
0545 return DataFrame(jdf, self.sql_ctx)
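
# A hedged streaming sketch of withWatermark (assumes the built-in "rate" source
# and pyspark.sql.functions.window; not part of the original source):
#
#     from pyspark.sql.functions import window
#     stream = spark.readStream.format("rate").load()   # columns: timestamp, value
#     counts = (stream
#               .withWatermark("timestamp", "10 minutes")
#               .groupBy(window("timestamp", "5 minutes"))
#               .count())
#     # with the watermark in place, the query can run in "append" output mode,
#     # because Spark knows when each window can be finalized.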
0546
0547 @since(2.2)
0548 def hint(self, name, *parameters):
0549 """Specifies some hint on the current :class:`DataFrame`.
0550
0551 :param name: A name of the hint.
0552 :param parameters: Optional parameters.
0553 :return: :class:`DataFrame`
0554
0555 >>> df.join(df2.hint("broadcast"), "name").show()
0556 +----+---+------+
0557 |name|age|height|
0558 +----+---+------+
0559 | Bob| 5| 85|
0560 +----+---+------+
0561 """
0562 if len(parameters) == 1 and isinstance(parameters[0], list):
0563 parameters = parameters[0]
0564
0565 if not isinstance(name, str):
0566 raise TypeError("name should be provided as str, got {0}".format(type(name)))
0567
0568 allowed_types = (basestring, list, float, int)
0569 for p in parameters:
0570 if not isinstance(p, allowed_types):
0571 raise TypeError(
0572 "all parameters should be in {0}, got {1} of type {2}".format(
0573 allowed_types, p, type(p)))
0574
0575 jdf = self._jdf.hint(name, self._jseq(parameters))
0576 return DataFrame(jdf, self.sql_ctx)
0577
0578 @since(1.3)
0579 def count(self):
0580 """Returns the number of rows in this :class:`DataFrame`.
0581
0582 >>> df.count()
0583 2
0584 """
0585 return int(self._jdf.count())
0586
0587 @ignore_unicode_prefix
0588 @since(1.3)
0589 def collect(self):
0590 """Returns all the records as a list of :class:`Row`.
0591
0592 >>> df.collect()
0593 [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
0594 """
0595 with SCCallSiteSync(self._sc) as css:
0596 sock_info = self._jdf.collectToPython()
0597 return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
0598
0599 @ignore_unicode_prefix
0600 @since(2.0)
0601 def toLocalIterator(self, prefetchPartitions=False):
0602 """
0603 Returns an iterator that contains all of the rows in this :class:`DataFrame`.
0604 The iterator will consume as much memory as the largest partition in this
0605 :class:`DataFrame`. With prefetch it may consume up to the memory of the 2 largest
0606 partitions.
0607
0608 :param prefetchPartitions: If Spark should pre-fetch the next partition
0609 before it is needed.
0610
0611 >>> list(df.toLocalIterator())
0612 [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
0613 """
0614 with SCCallSiteSync(self._sc) as css:
0615 sock_info = self._jdf.toPythonIterator(prefetchPartitions)
0616 return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer()))
0617
0618 @ignore_unicode_prefix
0619 @since(1.3)
0620 def limit(self, num):
0621 """Limits the result count to the number specified.
0622
0623 >>> df.limit(1).collect()
0624 [Row(age=2, name=u'Alice')]
0625 >>> df.limit(0).collect()
0626 []
0627 """
0628 jdf = self._jdf.limit(num)
0629 return DataFrame(jdf, self.sql_ctx)
0630
0631 @ignore_unicode_prefix
0632 @since(1.3)
0633 def take(self, num):
0634 """Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
0635
0636 >>> df.take(2)
0637 [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
0638 """
0639 return self.limit(num).collect()
0640
0641 @ignore_unicode_prefix
0642 @since(3.0)
0643 def tail(self, num):
0644 """
0645 Returns the last ``num`` rows as a :class:`list` of :class:`Row`.
0646
0647 Running tail requires moving data into the application's driver process, and doing so with
0648 a very large ``num`` can crash the driver process with OutOfMemoryError.
0649
0650 >>> df.tail(1)
0651 [Row(age=5, name=u'Bob')]
0652 """
0653 with SCCallSiteSync(self._sc):
0654 sock_info = self._jdf.tailToPython(num)
0655 return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
0656
0657 @since(1.3)
0658 def foreach(self, f):
0659 """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`.
0660
0661 This is a shorthand for ``df.rdd.foreach()``.
0662
0663 >>> def f(person):
0664 ... print(person.name)
0665 >>> df.foreach(f)
0666 """
0667 self.rdd.foreach(f)
0668
0669 @since(1.3)
0670 def foreachPartition(self, f):
0671 """Applies the ``f`` function to each partition of this :class:`DataFrame`.
0672
0673 This is a shorthand for ``df.rdd.foreachPartition()``.
0674
0675 >>> def f(people):
0676 ... for person in people:
0677 ... print(person.name)
0678 >>> df.foreachPartition(f)
0679 """
0680 self.rdd.foreachPartition(f)
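
# A hedged sketch of the per-partition pattern this enables, e.g. opening one
# connection per partition instead of per row (the connection helper below is
# hypothetical; not part of the original source):
#
#     def send_partition(rows):
#         conn = open_connection()          # hypothetical helper
#         for row in rows:
#             conn.send(row.asDict())
#         conn.close()
#     df.foreachPartition(send_partition)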
0681
0682 @since(1.3)
0683 def cache(self):
0684 """Persists the :class:`DataFrame` with the default storage level (`MEMORY_AND_DISK`).
0685
0686 .. note:: The default storage level has changed to `MEMORY_AND_DISK` to match Scala in 2.0.
0687 """
0688 self.is_cached = True
0689 self._jdf.cache()
0690 return self
0691
0692 @since(1.3)
0693 def persist(self, storageLevel=StorageLevel.MEMORY_AND_DISK):
0694 """Sets the storage level to persist the contents of the :class:`DataFrame` across
0695 operations after the first time it is computed. This can only be used to assign
0696 a new storage level if the :class:`DataFrame` does not have a storage level set yet.
0697 If no storage level is specified, it defaults to (`MEMORY_AND_DISK`).
0698
0699 .. note:: The default storage level has changed to `MEMORY_AND_DISK` to match Scala in 2.0.
0700 """
0701 self.is_cached = True
0702 javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
0703 self._jdf.persist(javaStorageLevel)
0704 return self
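
# A short sketch of persisting with an explicit storage level (illustrative only,
# not part of the original source):
#
#     from pyspark.storagelevel import StorageLevel
#     df = spark.range(100)
#     df.persist(StorageLevel.DISK_ONLY)   # pick a non-default level
#     df.count()                           # first action materializes the cache
#     df.unpersist()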
0705
0706 @property
0707 @since(2.1)
0708 def storageLevel(self):
0709 """Get the :class:`DataFrame`'s current storage level.
0710
0711 >>> df.storageLevel
0712 StorageLevel(False, False, False, False, 1)
0713 >>> df.cache().storageLevel
0714 StorageLevel(True, True, False, True, 1)
0715 >>> df2.persist(StorageLevel.DISK_ONLY_2).storageLevel
0716 StorageLevel(True, False, False, False, 2)
0717 """
0718 java_storage_level = self._jdf.storageLevel()
0719 storage_level = StorageLevel(java_storage_level.useDisk(),
0720 java_storage_level.useMemory(),
0721 java_storage_level.useOffHeap(),
0722 java_storage_level.deserialized(),
0723 java_storage_level.replication())
0724 return storage_level
0725
0726 @since(1.3)
0727 def unpersist(self, blocking=False):
0728 """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from
0729 memory and disk.
0730
0731 .. note:: `blocking` default has changed to ``False`` to match Scala in 2.0.
0732 """
0733 self.is_cached = False
0734 self._jdf.unpersist(blocking)
0735 return self
0736
0737 @since(1.4)
0738 def coalesce(self, numPartitions):
0739 """
0740 Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions.
0741
0742 :param numPartitions: int, to specify the target number of partitions
0743
0744 Similar to coalesce defined on an :class:`RDD`, this operation results in a
0745 narrow dependency, e.g. if you go from 1000 partitions to 100 partitions,
0746 there will not be a shuffle, instead each of the 100 new partitions will
0747 claim 10 of the current partitions. If a larger number of partitions is requested,
0748 it will stay at the current number of partitions.
0749
0750 However, if you're doing a drastic coalesce, e.g. to numPartitions = 1,
0751 this may result in your computation taking place on fewer nodes than
0752 you like (e.g. one node in the case of numPartitions = 1). To avoid this,
0753 you can call repartition(). This will add a shuffle step, but means the
0754 current upstream partitions will be executed in parallel (per whatever
0755 the current partitioning is).
0756
0757 >>> df.coalesce(1).rdd.getNumPartitions()
0758 1
0759 """
0760 return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx)
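
# A hedged sketch contrasting coalesce with repartition, per the notes above
# (partition counts are illustrative; not part of the original source):
#
#     df = spark.range(1000).repartition(100)
#     df.coalesce(10).rdd.getNumPartitions()      # 10, no shuffle
#     df.coalesce(200).rdd.getNumPartitions()     # stays at 100; coalesce cannot grow
#     df.repartition(200).rdd.getNumPartitions()  # 200, performs a full shuffle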
0761
0762 @since(1.3)
0763 def repartition(self, numPartitions, *cols):
0764 """
0765 Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The
0766 resulting :class:`DataFrame` is hash partitioned.
0767
0768 :param numPartitions:
0769 can be an int to specify the target number of partitions or a Column.
0770 If it is a Column, it will be used as the first partitioning column. If not specified,
0771 the default number of partitions is used.
0772
0773 .. versionchanged:: 1.6
0774 Added optional arguments to specify the partitioning columns. Also made numPartitions
0775 optional if partitioning columns are specified.
0776
0777 >>> df.repartition(10).rdd.getNumPartitions()
0778 10
0779 >>> data = df.union(df).repartition("age")
0780 >>> data.show()
0781 +---+-----+
0782 |age| name|
0783 +---+-----+
0784 | 5| Bob|
0785 | 5| Bob|
0786 | 2|Alice|
0787 | 2|Alice|
0788 +---+-----+
0789 >>> data = data.repartition(7, "age")
0790 >>> data.show()
0791 +---+-----+
0792 |age| name|
0793 +---+-----+
0794 | 2|Alice|
0795 | 5| Bob|
0796 | 2|Alice|
0797 | 5| Bob|
0798 +---+-----+
0799 >>> data.rdd.getNumPartitions()
0800 7
0801 >>> data = data.repartition("name", "age")
0802 >>> data.show()
0803 +---+-----+
0804 |age| name|
0805 +---+-----+
0806 | 5| Bob|
0807 | 5| Bob|
0808 | 2|Alice|
0809 | 2|Alice|
0810 +---+-----+
0811 """
0812 if isinstance(numPartitions, int):
0813 if len(cols) == 0:
0814 return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)
0815 else:
0816 return DataFrame(
0817 self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sql_ctx)
0818 elif isinstance(numPartitions, (basestring, Column)):
0819 cols = (numPartitions, ) + cols
0820 return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sql_ctx)
0821 else:
0822 raise TypeError("numPartitions should be an int or Column")
0823
0824 @since("2.4.0")
0825 def repartitionByRange(self, numPartitions, *cols):
0826 """
0827 Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The
0828 resulting :class:`DataFrame` is range partitioned.
0829
0830 :param numPartitions:
0831 can be an int to specify the target number of partitions or a Column.
0832 If it is a Column, it will be used as the first partitioning column. If not specified,
0833 the default number of partitions is used.
0834
0835 At least one partition-by expression must be specified.
0836 When no explicit sort order is specified, "ascending nulls first" is assumed.
0837
0838 Note that due to performance reasons this method uses sampling to estimate the ranges.
0839 Hence, the output may not be consistent, since sampling can return different values.
0840 The sample size can be controlled by the config
0841 `spark.sql.execution.rangeExchange.sampleSizePerPartition`.
0842
0843 >>> df.repartitionByRange(2, "age").rdd.getNumPartitions()
0844 2
0845 >>> df.show()
0846 +---+-----+
0847 |age| name|
0848 +---+-----+
0849 | 2|Alice|
0850 | 5| Bob|
0851 +---+-----+
0852 >>> df.repartitionByRange(1, "age").rdd.getNumPartitions()
0853 1
0854 >>> data = df.repartitionByRange("age")
0855 >>> df.show()
0856 +---+-----+
0857 |age| name|
0858 +---+-----+
0859 | 2|Alice|
0860 | 5| Bob|
0861 +---+-----+
0862 """
0863 if isinstance(numPartitions, int):
0864 if len(cols) == 0:
0865 raise ValueError("At least one partition-by expression must be specified.")
0866 else:
0867 return DataFrame(
0868 self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sql_ctx)
0869 elif isinstance(numPartitions, (basestring, Column)):
0870 cols = (numPartitions,) + cols
0871 return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sql_ctx)
0872 else:
0873 raise TypeError("numPartitions should be an int, string or Column")
0874
0875 @since(1.3)
0876 def distinct(self):
0877 """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
0878
0879 >>> df.distinct().count()
0880 2
0881 """
0882 return DataFrame(self._jdf.distinct(), self.sql_ctx)
0883
0884 @since(1.3)
0885 def sample(self, withReplacement=None, fraction=None, seed=None):
0886 """Returns a sampled subset of this :class:`DataFrame`.
0887
0888 :param withReplacement: Sample with replacement or not (default ``False``).
0889 :param fraction: Fraction of rows to generate, range [0.0, 1.0].
0890 :param seed: Seed for sampling (default a random seed).
0891
0892 .. note:: This is not guaranteed to provide exactly the fraction specified of the total
0893 count of the given :class:`DataFrame`.
0894
0895 .. note:: `fraction` is required and `withReplacement` and `seed` are optional.
0896
0897 >>> df = spark.range(10)
0898 >>> df.sample(0.5, 3).count()
0899 7
0900 >>> df.sample(fraction=0.5, seed=3).count()
0901 7
0902 >>> df.sample(withReplacement=True, fraction=0.5, seed=3).count()
0903 1
0904 >>> df.sample(1.0).count()
0905 10
0906 >>> df.sample(fraction=1.0).count()
0907 10
0908 >>> df.sample(False, fraction=1.0).count()
0909 10
0910 """
0911
0912 # For the cases below:
0913 #   sample(True, 0.5 [, seed])
0914 #   sample(True, fraction=0.5 [, seed])
0915 #   sample(withReplacement=False, fraction=0.5 [, seed])
0916 is_withReplacement_set = \
0917 type(withReplacement) == bool and isinstance(fraction, float)
0918
0919 # For the case below:
0920 #   sample(fraction=0.5 [, seed])
0921 is_withReplacement_omitted_kwargs = \
0922 withReplacement is None and isinstance(fraction, float)
0923
0924 # For the case below:
0925 #   sample(0.5 [, seed])
0926 is_withReplacement_omitted_args = isinstance(withReplacement, float)
0927
0928 if not (is_withReplacement_set
0929 or is_withReplacement_omitted_kwargs
0930 or is_withReplacement_omitted_args):
0931 argtypes = [
0932 str(type(arg)) for arg in [withReplacement, fraction, seed] if arg is not None]
0933 raise TypeError(
0934 "withReplacement (optional), fraction (required) and seed (optional)"
0935 " should be a bool, float and number; however, "
0936 "got [%s]." % ", ".join(argtypes))
0937
0938 if is_withReplacement_omitted_args:
0939 if fraction is not None:
0940 seed = fraction
0941 fraction = withReplacement
0942 withReplacement = None
0943
0944 seed = long(seed) if seed is not None else None
0945 args = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
0946 jdf = self._jdf.sample(*args)
0947 return DataFrame(jdf, self.sql_ctx)
0948
0949 @since(1.5)
0950 def sampleBy(self, col, fractions, seed=None):
0951 """
0952 Returns a stratified sample without replacement based on the
0953 fraction given on each stratum.
0954
0955 :param col: column that defines strata
0956 :param fractions:
0957 sampling fraction for each stratum. If a stratum is not
0958 specified, we treat its fraction as zero.
0959 :param seed: random seed
0960 :return: a new :class:`DataFrame` that represents the stratified sample
0961
0962 >>> from pyspark.sql.functions import col
0963 >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
0964 >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
0965 >>> sampled.groupBy("key").count().orderBy("key").show()
0966 +---+-----+
0967 |key|count|
0968 +---+-----+
0969 | 0| 3|
0970 | 1| 6|
0971 +---+-----+
0972 >>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count()
0973 33
0974
0975 .. versionchanged:: 3.0
0976 Added sampling by a column of :class:`Column`
0977 """
0978 if isinstance(col, basestring):
0979 col = Column(col)
0980 elif not isinstance(col, Column):
0981 raise ValueError("col must be a string or a column, but got %r" % type(col))
0982 if not isinstance(fractions, dict):
0983 raise ValueError("fractions must be a dict but got %r" % type(fractions))
0984 for k, v in fractions.items():
0985 if not isinstance(k, (float, int, long, basestring)):
0986 raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
0987 fractions[k] = float(v)
0988 col = col._jc
0989 seed = seed if seed is not None else random.randint(0, sys.maxsize)
0990 return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
0991
0992 @since(1.4)
0993 def randomSplit(self, weights, seed=None):
0994 """Randomly splits this :class:`DataFrame` with the provided weights.
0995
0996 :param weights: list of doubles as weights with which to split the :class:`DataFrame`.
0997 Weights will be normalized if they don't sum up to 1.0.
0998 :param seed: The seed for sampling.
0999
1000 >>> splits = df4.randomSplit([1.0, 2.0], 24)
1001 >>> splits[0].count()
1002 2
1003
1004 >>> splits[1].count()
1005 2
1006 """
1007 for w in weights:
1008 if w < 0.0:
1009 raise ValueError("Weights must be positive. Found weight value: %s" % w)
1010 seed = seed if seed is not None else random.randint(0, sys.maxsize)
1011 rdd_array = self._jdf.randomSplit(_to_list(self.sql_ctx._sc, weights), long(seed))
1012 return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
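
# A typical train/test split sketch (weights and seed are illustrative; not part
# of the original source):
#
#     train, test = spark.range(100).randomSplit([0.8, 0.2], seed=42)
#     train.count() + test.count()   # == 100; rows are partitioned, not duplicated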
1013
1014 @property
1015 @since(1.3)
1016 def dtypes(self):
1017 """Returns all column names and their data types as a list.
1018
1019 >>> df.dtypes
1020 [('age', 'int'), ('name', 'string')]
1021 """
1022 return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
1023
1024 @property
1025 @since(1.3)
1026 def columns(self):
1027 """Returns all column names as a list.
1028
1029 >>> df.columns
1030 ['age', 'name']
1031 """
1032 return [f.name for f in self.schema.fields]
1033
1034 @since(2.3)
1035 def colRegex(self, colName):
1036 """
1037 Selects column based on the column name specified as a regex and returns it
1038 as :class:`Column`.
1039
1040 :param colName: string, column name specified as a regex.
1041
1042 >>> df = spark.createDataFrame([("a", 1), ("b", 2), ("c", 3)], ["Col1", "Col2"])
1043 >>> df.select(df.colRegex("`(Col1)?+.+`")).show()
1044 +----+
1045 |Col2|
1046 +----+
1047 | 1|
1048 | 2|
1049 | 3|
1050 +----+
1051 """
1052 if not isinstance(colName, basestring):
1053 raise ValueError("colName should be provided as string")
1054 jc = self._jdf.colRegex(colName)
1055 return Column(jc)
1056
1057 @ignore_unicode_prefix
1058 @since(1.3)
1059 def alias(self, alias):
1060 """Returns a new :class:`DataFrame` with an alias set.
1061
1062 :param alias: string, an alias name to be set for the :class:`DataFrame`.
1063
1064 >>> from pyspark.sql.functions import *
1065 >>> df_as1 = df.alias("df_as1")
1066 >>> df_as2 = df.alias("df_as2")
1067 >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner')
1068 >>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age") \
1069 .sort(desc("df_as1.name")).collect()
1070 [Row(name=u'Bob', name=u'Bob', age=5), Row(name=u'Alice', name=u'Alice', age=2)]
1071 """
1072 assert isinstance(alias, basestring), "alias should be a string"
1073 return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx)
1074
1075 @ignore_unicode_prefix
1076 @since(2.1)
1077 def crossJoin(self, other):
1078 """Returns the cartesian product with another :class:`DataFrame`.
1079
1080 :param other: Right side of the cartesian product.
1081
1082 >>> df.select("age", "name").collect()
1083 [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
1084 >>> df2.select("name", "height").collect()
1085 [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85)]
1086 >>> df.crossJoin(df2.select("height")).select("age", "name", "height").collect()
1087 [Row(age=2, name=u'Alice', height=80), Row(age=2, name=u'Alice', height=85),
1088 Row(age=5, name=u'Bob', height=80), Row(age=5, name=u'Bob', height=85)]
1089 """
1090
1091 jdf = self._jdf.crossJoin(other._jdf)
1092 return DataFrame(jdf, self.sql_ctx)
1093
1094 @ignore_unicode_prefix
1095 @since(1.3)
1096 def join(self, other, on=None, how=None):
1097 """Joins with another :class:`DataFrame`, using the given join expression.
1098
1099 :param other: Right side of the join
1100 :param on: a string for the join column name, a list of column names,
1101 a join expression (Column), or a list of Columns.
1102 If `on` is a string or a list of strings indicating the name of the join column(s),
1103 the column(s) must exist on both sides, and this performs an equi-join.
1104 :param how: str, default ``inner``. Must be one of: ``inner``, ``cross``, ``outer``,
1105 ``full``, ``fullouter``, ``full_outer``, ``left``, ``leftouter``, ``left_outer``,
1106 ``right``, ``rightouter``, ``right_outer``, ``semi``, ``leftsemi``, ``left_semi``,
1107 ``anti``, ``leftanti`` and ``left_anti``.
1108
1109 The following performs a full outer join between ``df1`` and ``df2``.
1110 >>> from pyspark.sql.functions import desc
1111 >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height) \
1112 .sort(desc("name")).collect()
1113 [Row(name=u'Bob', height=85), Row(name=u'Alice', height=None), Row(name=None, height=80)]
1114
1115 >>> df.join(df2, 'name', 'outer').select('name', 'height').sort(desc("name")).collect()
1116 [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
1117
1118 >>> cond = [df.name == df3.name, df.age == df3.age]
1119 >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect()
1120 [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
1121
1122 >>> df.join(df2, 'name').select(df.name, df2.height).collect()
1123 [Row(name=u'Bob', height=85)]
1124
1125 >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect()
1126 [Row(name=u'Bob', age=5)]
1127 """
1128
1129 if on is not None and not isinstance(on, list):
1130 on = [on]
1131
1132 if on is not None:
1133 if isinstance(on[0], basestring):
1134 on = self._jseq(on)
1135 else:
1136 assert isinstance(on[0], Column), "on should be Column or list of Column"
1137 on = reduce(lambda x, y: x.__and__(y), on)
1138 on = on._jc
1139
1140 if on is None and how is None:
1141 jdf = self._jdf.join(other._jdf)
1142 else:
1143 if how is None:
1144 how = "inner"
1145 if on is None:
1146 on = self._jseq([])
1147 assert isinstance(how, basestring), "how should be basestring"
1148 jdf = self._jdf.join(other._jdf, on, how)
1149 return DataFrame(jdf, self.sql_ctx)
1150
1151 @since(1.6)
1152 def sortWithinPartitions(self, *cols, **kwargs):
1153 """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s).
1154
1155 :param cols: list of :class:`Column` or column names to sort by.
1156 :param ascending: boolean or list of boolean (default ``True``).
1157 Sort ascending vs. descending. Specify list for multiple sort orders.
1158 If a list is specified, length of the list must equal length of the `cols`.
1159
1160 >>> df.sortWithinPartitions("age", ascending=False).show()
1161 +---+-----+
1162 |age| name|
1163 +---+-----+
1164 | 2|Alice|
1165 | 5| Bob|
1166 +---+-----+
1167 """
1168 jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
1169 return DataFrame(jdf, self.sql_ctx)
1170
1171 @ignore_unicode_prefix
1172 @since(1.3)
1173 def sort(self, *cols, **kwargs):
1174 """Returns a new :class:`DataFrame` sorted by the specified column(s).
1175
1176 :param cols: list of :class:`Column` or column names to sort by.
1177 :param ascending: boolean or list of boolean (default ``True``).
1178 Sort ascending vs. descending. Specify list for multiple sort orders.
1179 If a list is specified, length of the list must equal length of the `cols`.
1180
1181 >>> df.sort(df.age.desc()).collect()
1182 [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
1183 >>> df.sort("age", ascending=False).collect()
1184 [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
1185 >>> df.orderBy(df.age.desc()).collect()
1186 [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
1187 >>> from pyspark.sql.functions import *
1188 >>> df.sort(asc("age")).collect()
1189 [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
1190 >>> df.orderBy(desc("age"), "name").collect()
1191 [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
1192 >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
1193 [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
1194 """
1195 jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
1196 return DataFrame(jdf, self.sql_ctx)
1197
1198 orderBy = sort
1199
1200 def _jseq(self, cols, converter=None):
1201 """Return a JVM Seq of Columns from a list of Column or names"""
1202 return _to_seq(self.sql_ctx._sc, cols, converter)
1203
1204 def _jmap(self, jm):
1205 """Return a JVM Scala Map from a dict"""
1206 return _to_scala_map(self.sql_ctx._sc, jm)
1207
1208 def _jcols(self, *cols):
1209 """Return a JVM Seq of Columns from a list of Column or column names
1210
1211 If `cols` has only one list in it, cols[0] will be used as the list.
1212 """
1213 if len(cols) == 1 and isinstance(cols[0], list):
1214 cols = cols[0]
1215 return self._jseq(cols, _to_java_column)
1216
1217 def _sort_cols(self, cols, kwargs):
1218 """ Return a JVM Seq of Columns that describes the sort order
1219 """
1220 if not cols:
1221 raise ValueError("should sort by at least one column")
1222 if len(cols) == 1 and isinstance(cols[0], list):
1223 cols = cols[0]
1224 jcols = [_to_java_column(c) for c in cols]
1225 ascending = kwargs.get('ascending', True)
1226 if isinstance(ascending, (bool, int)):
1227 if not ascending:
1228 jcols = [jc.desc() for jc in jcols]
1229 elif isinstance(ascending, list):
1230 jcols = [jc if asc else jc.desc()
1231 for asc, jc in zip(ascending, jcols)]
1232 else:
1233 raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending))
1234 return self._jseq(jcols)
1235
1236 @since("1.3.1")
1237 def describe(self, *cols):
1238 """Computes basic statistics for numeric and string columns.
1239
1240 This includes count, mean, stddev, min, and max. If no columns are
1241 given, this function computes statistics for all numerical or string columns.
1242
1243 .. note:: This function is meant for exploratory data analysis, as we make no
1244 guarantee about the backward compatibility of the schema of the resulting
1245 :class:`DataFrame`.
1246
1247 >>> df.describe(['age']).show()
1248 +-------+------------------+
1249 |summary| age|
1250 +-------+------------------+
1251 | count| 2|
1252 | mean| 3.5|
1253 | stddev|2.1213203435596424|
1254 | min| 2|
1255 | max| 5|
1256 +-------+------------------+
1257 >>> df.describe().show()
1258 +-------+------------------+-----+
1259 |summary| age| name|
1260 +-------+------------------+-----+
1261 | count| 2| 2|
1262 | mean| 3.5| null|
1263 | stddev|2.1213203435596424| null|
1264 | min| 2|Alice|
1265 | max| 5| Bob|
1266 +-------+------------------+-----+
1267
1268 Use summary for expanded statistics and control over which statistics to compute.
1269 """
1270 if len(cols) == 1 and isinstance(cols[0], list):
1271 cols = cols[0]
1272 jdf = self._jdf.describe(self._jseq(cols))
1273 return DataFrame(jdf, self.sql_ctx)
1274
1275 @since("2.3.0")
1276 def summary(self, *statistics):
1277 """Computes specified statistics for numeric and string columns. Available statistics are:
1278 - count
1279 - mean
1280 - stddev
1281 - min
1282 - max
1283 - arbitrary approximate percentiles specified as a percentage (e.g., 75%)
1284
1285 If no statistics are given, this function computes count, mean, stddev, min,
1286 approximate quartiles (percentiles at 25%, 50%, and 75%), and max.
1287
1288 .. note:: This function is meant for exploratory data analysis, as we make no
1289 guarantee about the backward compatibility of the schema of the resulting
1290 :class:`DataFrame`.
1291
1292 >>> df.summary().show()
1293 +-------+------------------+-----+
1294 |summary| age| name|
1295 +-------+------------------+-----+
1296 | count| 2| 2|
1297 | mean| 3.5| null|
1298 | stddev|2.1213203435596424| null|
1299 | min| 2|Alice|
1300 | 25%| 2| null|
1301 | 50%| 2| null|
1302 | 75%| 5| null|
1303 | max| 5| Bob|
1304 +-------+------------------+-----+
1305
1306 >>> df.summary("count", "min", "25%", "75%", "max").show()
1307 +-------+---+-----+
1308 |summary|age| name|
1309 +-------+---+-----+
1310 | count| 2| 2|
1311 | min| 2|Alice|
1312 | 25%| 2| null|
1313 | 75%| 5| null|
1314 | max| 5| Bob|
1315 +-------+---+-----+
1316
1317 To do a summary for specific columns, first select them:
1318
1319 >>> df.select("age", "name").summary("count").show()
1320 +-------+---+----+
1321 |summary|age|name|
1322 +-------+---+----+
1323 | count| 2| 2|
1324 +-------+---+----+
1325
1326 See also describe for basic statistics.
1327 """
1328 if len(statistics) == 1 and isinstance(statistics[0], list):
1329 statistics = statistics[0]
1330 jdf = self._jdf.summary(self._jseq(statistics))
1331 return DataFrame(jdf, self.sql_ctx)
1332
1333 @ignore_unicode_prefix
1334 @since(1.3)
1335 def head(self, n=None):
1336 """Returns the first ``n`` rows.
1337
1338 .. note:: This method should only be used if the resulting array is expected
1339 to be small, as all the data is loaded into the driver's memory.
1340
1341 :param n: int, default 1. Number of rows to return.
1342 :return: If n is greater than 1, return a list of :class:`Row`.
1343 If n is 1, return a single Row.
1344
1345 >>> df.head()
1346 Row(age=2, name=u'Alice')
1347 >>> df.head(1)
1348 [Row(age=2, name=u'Alice')]
1349 """
1350 if n is None:
1351 rs = self.head(1)
1352 return rs[0] if rs else None
1353 return self.take(n)
1354
1355 @ignore_unicode_prefix
1356 @since(1.3)
1357 def first(self):
1358 """Returns the first row as a :class:`Row`.
1359
1360 >>> df.first()
1361 Row(age=2, name=u'Alice')
1362 """
1363 return self.head()
1364
1365 @ignore_unicode_prefix
1366 @since(1.3)
1367 def __getitem__(self, item):
1368 """Returns the column as a :class:`Column`.
1369
1370 >>> df.select(df['age']).collect()
1371 [Row(age=2), Row(age=5)]
1372 >>> df[ ["name", "age"]].collect()
1373 [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
1374 >>> df[ df.age > 3 ].collect()
1375 [Row(age=5, name=u'Bob')]
1376 >>> df[df[0] > 3].collect()
1377 [Row(age=5, name=u'Bob')]
1378 """
1379 if isinstance(item, basestring):
1380 jc = self._jdf.apply(item)
1381 return Column(jc)
1382 elif isinstance(item, Column):
1383 return self.filter(item)
1384 elif isinstance(item, (list, tuple)):
1385 return self.select(*item)
1386 elif isinstance(item, int):
1387 jc = self._jdf.apply(self.columns[item])
1388 return Column(jc)
1389 else:
1390 raise TypeError("unexpected item type: %s" % type(item))
1391
1392 @since(1.3)
1393 def __getattr__(self, name):
1394 """Returns the :class:`Column` denoted by ``name``.
1395
1396 >>> df.select(df.age).collect()
1397 [Row(age=2), Row(age=5)]
1398 """
1399 if name not in self.columns:
1400 raise AttributeError(
1401 "'%s' object has no attribute '%s'" % (self.__class__.__name__, name))
1402 jc = self._jdf.apply(name)
1403 return Column(jc)
1404
1405 @ignore_unicode_prefix
1406 @since(1.3)
1407 def select(self, *cols):
1408 """Projects a set of expressions and returns a new :class:`DataFrame`.
1409
1410 :param cols: list of column names (string) or expressions (:class:`Column`).
1411 If one of the column names is '*', that column is expanded to include all columns
1412 in the current :class:`DataFrame`.
1413
1414 >>> df.select('*').collect()
1415 [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
1416 >>> df.select('name', 'age').collect()
1417 [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
1418 >>> df.select(df.name, (df.age + 10).alias('age')).collect()
1419 [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
1420 """
1421 jdf = self._jdf.select(self._jcols(*cols))
1422 return DataFrame(jdf, self.sql_ctx)
1423
1424 @since(1.3)
1425 def selectExpr(self, *expr):
1426 """Projects a set of SQL expressions and returns a new :class:`DataFrame`.
1427
1428 This is a variant of :func:`select` that accepts SQL expressions.
1429
1430 >>> df.selectExpr("age * 2", "abs(age)").collect()
1431 [Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)]
1432 """
1433 if len(expr) == 1 and isinstance(expr[0], list):
1434 expr = expr[0]
1435 jdf = self._jdf.selectExpr(self._jseq(expr))
1436 return DataFrame(jdf, self.sql_ctx)
1437
1438 @ignore_unicode_prefix
1439 @since(1.3)
1440 def filter(self, condition):
1441 """Filters rows using the given condition.
1442
1443 :func:`where` is an alias for :func:`filter`.
1444
1445 :param condition: a :class:`Column` of :class:`types.BooleanType`
1446 or a string of SQL expression.
1447
1448 >>> df.filter(df.age > 3).collect()
1449 [Row(age=5, name=u'Bob')]
1450 >>> df.where(df.age == 2).collect()
1451 [Row(age=2, name=u'Alice')]
1452
1453 >>> df.filter("age > 3").collect()
1454 [Row(age=5, name=u'Bob')]
1455 >>> df.where("age = 2").collect()
1456 [Row(age=2, name=u'Alice')]
1457 """
1458 if isinstance(condition, basestring):
1459 jdf = self._jdf.filter(condition)
1460 elif isinstance(condition, Column):
1461 jdf = self._jdf.filter(condition._jc)
1462 else:
1463 raise TypeError("condition should be string or Column")
1464 return DataFrame(jdf, self.sql_ctx)
1465
1466 @ignore_unicode_prefix
1467 @since(1.3)
1468 def groupBy(self, *cols):
1469 """Groups the :class:`DataFrame` using the specified columns,
1470 so we can run aggregation on them. See :class:`GroupedData`
1471 for all the available aggregate functions.
1472
1473 :func:`groupby` is an alias for :func:`groupBy`.
1474
1475 :param cols: list of columns to group by.
1476 Each element should be a column name (string) or an expression (:class:`Column`).
1477
1478 >>> df.groupBy().avg().collect()
1479 [Row(avg(age)=3.5)]
1480 >>> sorted(df.groupBy('name').agg({'age': 'mean'}).collect())
1481 [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
1482 >>> sorted(df.groupBy(df.name).avg().collect())
1483 [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
1484 >>> sorted(df.groupBy(['name', df.age]).count().collect())
1485 [Row(name=u'Alice', age=2, count=1), Row(name=u'Bob', age=5, count=1)]
1486 """
1487 jgd = self._jdf.groupBy(self._jcols(*cols))
1488 from pyspark.sql.group import GroupedData
1489 return GroupedData(jgd, self)
1490
1491 @since(1.4)
1492 def rollup(self, *cols):
1493 """
1494 Create a multi-dimensional rollup for the current :class:`DataFrame` using
1495 the specified columns, so we can run aggregation on them.
1496
1497 >>> df.rollup("name", df.age).count().orderBy("name", "age").show()
1498 +-----+----+-----+
1499 | name| age|count|
1500 +-----+----+-----+
1501 | null|null| 2|
1502 |Alice|null| 1|
1503 |Alice| 2| 1|
1504 | Bob|null| 1|
1505 | Bob| 5| 1|
1506 +-----+----+-----+
1507 """
1508 jgd = self._jdf.rollup(self._jcols(*cols))
1509 from pyspark.sql.group import GroupedData
1510 return GroupedData(jgd, self)
1511
1512 @since(1.4)
1513 def cube(self, *cols):
1514 """
1515 Create a multi-dimensional cube for the current :class:`DataFrame` using
1516 the specified columns, so we can run aggregations on them.
1517
1518 >>> df.cube("name", df.age).count().orderBy("name", "age").show()
1519 +-----+----+-----+
1520 | name| age|count|
1521 +-----+----+-----+
1522 | null|null| 2|
1523 | null| 2| 1|
1524 | null| 5| 1|
1525 |Alice|null| 1|
1526 |Alice| 2| 1|
1527 | Bob|null| 1|
1528 | Bob| 5| 1|
1529 +-----+----+-----+
1530 """
1531 jgd = self._jdf.cube(self._jcols(*cols))
1532 from pyspark.sql.group import GroupedData
1533 return GroupedData(jgd, self)
1534
1535 @since(1.3)
1536 def agg(self, *exprs):
1537 """ Aggregate on the entire :class:`DataFrame` without groups
1538 (shorthand for ``df.groupBy.agg()``).
1539
1540 >>> df.agg({"age": "max"}).collect()
1541 [Row(max(age)=5)]
1542 >>> from pyspark.sql import functions as F
1543 >>> df.agg(F.min(df.age)).collect()
1544 [Row(min(age)=2)]
1545 """
1546 return self.groupBy().agg(*exprs)
1547
1548 @since(2.0)
1549 def union(self, other):
1550 """ Return a new :class:`DataFrame` containing union of rows in this and another
1551 :class:`DataFrame`.
1552
1553 This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union
1554 (that does deduplication of elements), use this function followed by :func:`distinct`.
1555
1556 Also as standard in SQL, this function resolves columns by position (not by name).
1557 """
1558 return DataFrame(self._jdf.union(other._jdf), self.sql_ctx)
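
# A small sketch of the position-based resolution described above (column names
# are illustrative; not part of the original source):
#
#     df1 = spark.createDataFrame([(1, "a")], ["id", "value"])
#     df2 = spark.createDataFrame([("b", 2)], ["value", "id"])
#     df1.union(df2).show()         # columns are matched by position, not by name
#     df1.unionByName(df2).show()   # use unionByName to match columns by name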
1559
1560 @since(1.3)
1561 def unionAll(self, other):
1562 """ Return a new :class:`DataFrame` containing union of rows in this and another
1563 :class:`DataFrame`.
1564
1565 This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union
1566 (that does deduplication of elements), use this function followed by :func:`distinct`.
1567
1568 Also as standard in SQL, this function resolves columns by position (not by name).
1569 """
1570 return self.union(other)
1571
1572 @since(2.3)
1573 def unionByName(self, other):
1574 """ Returns a new :class:`DataFrame` containing union of rows in this and another
1575 :class:`DataFrame`.
1576
1577 This is different from both `UNION ALL` and `UNION DISTINCT` in SQL. To do a SQL-style set
1578 union (that does deduplication of elements), use this function followed by :func:`distinct`.
1579
1580 The difference between this function and :func:`union` is that this function
1581 resolves columns by name (not by position):
1582
1583 >>> df1 = spark.createDataFrame([[1, 2, 3]], ["col0", "col1", "col2"])
1584 >>> df2 = spark.createDataFrame([[4, 5, 6]], ["col1", "col2", "col0"])
1585 >>> df1.unionByName(df2).show()
1586 +----+----+----+
1587 |col0|col1|col2|
1588 +----+----+----+
1589 | 1| 2| 3|
1590 | 6| 4| 5|
1591 +----+----+----+
1592 """
1593 return DataFrame(self._jdf.unionByName(other._jdf), self.sql_ctx)
1594
1595 @since(1.3)
1596 def intersect(self, other):
1597 """ Return a new :class:`DataFrame` containing rows only in
1598 both this :class:`DataFrame` and another :class:`DataFrame`.
1599
1600 This is equivalent to `INTERSECT` in SQL.
1601 """
1602 return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
1603
1604 @since(2.4)
1605 def intersectAll(self, other):
1606 """ Return a new :class:`DataFrame` containing rows in both this :class:`DataFrame`
1607 and another :class:`DataFrame` while preserving duplicates.
1608
1609 This is equivalent to `INTERSECT ALL` in SQL.
1610 >>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"])
1611 >>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"])
1612
1613 >>> df1.intersectAll(df2).sort("C1", "C2").show()
1614 +---+---+
1615 | C1| C2|
1616 +---+---+
1617 | a| 1|
1618 | a| 1|
1619 | b| 3|
1620 +---+---+
1621
1622 Also as standard in SQL, this function resolves columns by position (not by name).
1623 """
1624 return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx)
1625
1626 @since(1.3)
1627 def subtract(self, other):
1628 """ Return a new :class:`DataFrame` containing rows in this :class:`DataFrame`
1629 but not in another :class:`DataFrame`.
1630
1631 This is equivalent to `EXCEPT DISTINCT` in SQL.
1632
1633 """
1634 return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
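
# A brief sketch of the set-style operators intersect and subtract (data is
# illustrative; not part of the original source):
#
#     df1 = spark.createDataFrame([(1,), (2,), (2,), (3,)], ["id"])
#     df2 = spark.createDataFrame([(2,), (3,), (4,)], ["id"])
#     df1.intersect(df2).show()   # rows present in both, duplicates removed
#     df1.subtract(df2).show()    # rows only in df1, duplicates removed (EXCEPT DISTINCT)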
1635
1636 @since(1.4)
1637 def dropDuplicates(self, subset=None):
1638 """Return a new :class:`DataFrame` with duplicate rows removed,
1639 optionally only considering certain columns.
1640
For a static batch :class:`DataFrame`, it just drops duplicate rows. For a streaming
:class:`DataFrame`, it will keep all data across triggers as intermediate state to drop
duplicate rows. You can use :func:`withWatermark` to limit how late the duplicate data can
be and the system will accordingly limit the state. In addition, data older than the
watermark will be dropped to avoid any possibility of duplicates.
1646
1647 :func:`drop_duplicates` is an alias for :func:`dropDuplicates`.
1648
1649 >>> from pyspark.sql import Row
1650 >>> df = sc.parallelize([ \\
1651 ... Row(name='Alice', age=5, height=80), \\
1652 ... Row(name='Alice', age=5, height=80), \\
1653 ... Row(name='Alice', age=10, height=80)]).toDF()
1654 >>> df.dropDuplicates().show()
1655 +---+------+-----+
1656 |age|height| name|
1657 +---+------+-----+
1658 | 5| 80|Alice|
1659 | 10| 80|Alice|
1660 +---+------+-----+
1661
1662 >>> df.dropDuplicates(['name', 'height']).show()
1663 +---+------+-----+
1664 |age|height| name|
1665 +---+------+-----+
1666 | 5| 80|Alice|
1667 +---+------+-----+
1668 """
1669 if subset is None:
1670 jdf = self._jdf.dropDuplicates()
1671 else:
1672 jdf = self._jdf.dropDuplicates(self._jseq(subset))
1673 return DataFrame(jdf, self.sql_ctx)
1674
1675 @since("1.3.1")
1676 def dropna(self, how='any', thresh=None, subset=None):
1677 """Returns a new :class:`DataFrame` omitting rows with null values.
1678 :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other.
1679
1680 :param how: 'any' or 'all'.
1681 If 'any', drop a row if it contains any nulls.
1682 If 'all', drop a row only if all its values are null.
1683 :param thresh: int, default None
If specified, drop rows that have fewer than `thresh` non-null values.
This overrides the `how` parameter.
1686 :param subset: optional list of column names to consider.
1687
1688 >>> df4.na.drop().show()
1689 +---+------+-----+
1690 |age|height| name|
1691 +---+------+-----+
1692 | 10| 80|Alice|
1693 +---+------+-----+
1694 """
1695 if how is not None and how not in ['any', 'all']:
1696 raise ValueError("how ('" + how + "') should be 'any' or 'all'")
1697
1698 if subset is None:
1699 subset = self.columns
1700 elif isinstance(subset, basestring):
1701 subset = [subset]
1702 elif not isinstance(subset, (list, tuple)):
1703 raise ValueError("subset should be a list or tuple of column names")
1704
1705 if thresh is None:
1706 thresh = len(subset) if how == 'any' else 1
1707
1708 return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx)
1709
1710 @since("1.3.1")
1711 def fillna(self, value, subset=None):
1712 """Replace null values, alias for ``na.fill()``.
1713 :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other.
1714
1715 :param value: int, long, float, string, bool or dict.
1716 Value to replace null values with.
1717 If the value is a dict, then `subset` is ignored and `value` must be a mapping
1718 from column name (string) to replacement value. The replacement value must be
1719 an int, long, float, boolean, or string.
1720 :param subset: optional list of column names to consider.
1721 Columns specified in subset that do not have matching data type are ignored.
1722 For example, if `value` is a string, and subset contains a non-string column,
1723 then the non-string column is simply ignored.
1724
1725 >>> df4.na.fill(50).show()
1726 +---+------+-----+
1727 |age|height| name|
1728 +---+------+-----+
1729 | 10| 80|Alice|
1730 | 5| 50| Bob|
1731 | 50| 50| Tom|
1732 | 50| 50| null|
1733 +---+------+-----+
1734
1735 >>> df5.na.fill(False).show()
1736 +----+-------+-----+
1737 | age| name| spy|
1738 +----+-------+-----+
1739 | 10| Alice|false|
1740 | 5| Bob|false|
1741 |null|Mallory| true|
1742 +----+-------+-----+
1743
1744 >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
1745 +---+------+-------+
1746 |age|height| name|
1747 +---+------+-------+
1748 | 10| 80| Alice|
1749 | 5| null| Bob|
1750 | 50| null| Tom|
1751 | 50| null|unknown|
1752 +---+------+-------+
1753 """
1754 if not isinstance(value, (float, int, long, basestring, bool, dict)):
1755 raise ValueError("value should be a float, int, long, string, bool or dict")
1756
# bool is a subclass of int in Python, so exclude booleans explicitly before
# widening integral fill values to float for the JVM API.
if not isinstance(value, bool) and isinstance(value, (int, long)):
1761 value = float(value)
1762
1763 if isinstance(value, dict):
1764 return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
1765 elif subset is None:
1766 return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
1767 else:
1768 if isinstance(subset, basestring):
1769 subset = [subset]
1770 elif not isinstance(subset, (list, tuple)):
1771 raise ValueError("subset should be a list or tuple of column names")
1772
1773 return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
1774
1775 @since(1.4)
1776 def replace(self, to_replace, value=_NoValue, subset=None):
1777 """Returns a new :class:`DataFrame` replacing a value with another value.
1778 :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
1779 aliases of each other.
Values to_replace and value must have the same type and can only be numerics, booleans,
or strings. Value can be None. When replacing, the new value will be cast
to the type of the existing column.
For numeric replacements all values to be replaced should have unique
floating point representation. In case of conflicts (for example with `{42: -1, 42.0: 1}`)
an arbitrary replacement will be used.
1786
1787 :param to_replace: bool, int, long, float, string, list or dict.
1788 Value to be replaced.
1789 If the value is a dict, then `value` is ignored or can be omitted, and `to_replace`
1790 must be a mapping between a value and a replacement.
1791 :param value: bool, int, long, float, string, list or None.
1792 The replacement value must be a bool, int, long, float, string or None. If `value` is a
1793 list, `value` should be of the same length and type as `to_replace`.
1794 If `value` is a scalar and `to_replace` is a sequence, then `value` is
1795 used as a replacement for each item in `to_replace`.
1796 :param subset: optional list of column names to consider.
1797 Columns specified in subset that do not have matching data type are ignored.
1798 For example, if `value` is a string, and subset contains a non-string column,
1799 then the non-string column is simply ignored.
1800
1801 >>> df4.na.replace(10, 20).show()
1802 +----+------+-----+
1803 | age|height| name|
1804 +----+------+-----+
1805 | 20| 80|Alice|
1806 | 5| null| Bob|
1807 |null| null| Tom|
1808 |null| null| null|
1809 +----+------+-----+
1810
1811 >>> df4.na.replace('Alice', None).show()
1812 +----+------+----+
1813 | age|height|name|
1814 +----+------+----+
1815 | 10| 80|null|
1816 | 5| null| Bob|
1817 |null| null| Tom|
1818 |null| null|null|
1819 +----+------+----+
1820
1821 >>> df4.na.replace({'Alice': None}).show()
1822 +----+------+----+
1823 | age|height|name|
1824 +----+------+----+
1825 | 10| 80|null|
1826 | 5| null| Bob|
1827 |null| null| Tom|
1828 |null| null|null|
1829 +----+------+----+
1830
1831 >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
1832 +----+------+----+
1833 | age|height|name|
1834 +----+------+----+
1835 | 10| 80| A|
1836 | 5| null| B|
1837 |null| null| Tom|
1838 |null| null|null|
1839 +----+------+----+
1840 """
1841 if value is _NoValue:
1842 if isinstance(to_replace, dict):
1843 value = None
1844 else:
1845 raise TypeError("value argument is required when to_replace is not a dictionary.")
1846
1847
1848 def all_of(types):
1849 """Given a type or tuple of types and a sequence of xs
1850 check if each x is instance of type(s)
1851
1852 >>> all_of(bool)([True, False])
1853 True
1854 >>> all_of(basestring)(["a", 1])
1855 False
1856 """
1857 def all_of_(xs):
1858 return all(isinstance(x, types) for x in xs)
1859 return all_of_
1860
1861 all_of_bool = all_of(bool)
1862 all_of_str = all_of(basestring)
1863 all_of_numeric = all_of((float, int, long))

# Validate the argument types.
1866 valid_types = (bool, float, int, long, basestring, list, tuple)
1867 if not isinstance(to_replace, valid_types + (dict, )):
1868 raise ValueError(
1869 "to_replace should be a bool, float, int, long, string, list, tuple, or dict. "
1870 "Got {0}".format(type(to_replace)))
1871
1872 if not isinstance(value, valid_types) and value is not None \
1873 and not isinstance(to_replace, dict):
1874 raise ValueError("If to_replace is not a dict, value should be "
1875 "a bool, float, int, long, string, list, tuple or None. "
1876 "Got {0}".format(type(value)))
1877
1878 if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
1879 if len(to_replace) != len(value):
1880 raise ValueError("to_replace and value lists should be of the same length. "
1881 "Got {0} and {1}".format(len(to_replace), len(value)))
1882
1883 if not (subset is None or isinstance(subset, (list, tuple, basestring))):
1884 raise ValueError("subset should be a list or tuple of column names, "
1885 "column name or None. Got {0}".format(type(subset)))

# Normalize a scalar to_replace into a list so it can be zipped with value below.
1888 if isinstance(to_replace, (float, int, long, basestring)):
1889 to_replace = [to_replace]
1890
1891 if isinstance(to_replace, dict):
1892 rep_dict = to_replace
1893 if value is not None:
1894 warnings.warn("to_replace is a dict and value is not None. value will be ignored.")
1895 else:
1896 if isinstance(value, (float, int, long, basestring)) or value is None:
1897 value = [value for _ in range(len(to_replace))]
1898 rep_dict = dict(zip(to_replace, value))
1899
1900 if isinstance(subset, basestring):
1901 subset = [subset]

# Reject replacement mappings that mix booleans, strings, and numerics.
1904 if not any(all_of_type(rep_dict.keys())
1905 and all_of_type(x for x in rep_dict.values() if x is not None)
1906 for all_of_type in [all_of_bool, all_of_str, all_of_numeric]):
1907 raise ValueError("Mixed type replacements are not supported")
1908
1909 if subset is None:
1910 return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
1911 else:
1912 return DataFrame(
1913 self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)
1914
1915 @since(2.0)
1916 def approxQuantile(self, col, probabilities, relativeError):
1917 """
1918 Calculates the approximate quantiles of numerical columns of a
1919 :class:`DataFrame`.
1920
1921 The result of this algorithm has the following deterministic bound:
1922 If the :class:`DataFrame` has N elements and if we request the quantile at
1923 probability `p` up to error `err`, then the algorithm will return
1924 a sample `x` from the :class:`DataFrame` so that the *exact* rank of `x` is
1925 close to (p * N). More precisely,
1926
1927 floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).
1928
This method implements a variation of the Greenwald-Khanna
algorithm (with some speed optimizations). The algorithm was first
presented in `Space-efficient Online Computation of Quantile Summaries
<https://doi.org/10.1145/375663.375670>`_ by Greenwald and Khanna.
1934
1935 Note that null values will be ignored in numerical columns before calculation.
1936 For columns only containing null values, an empty list is returned.
1937
1938 :param col: str, list.
1939 Can be a single column name, or a list of names for multiple columns.
1940 :param probabilities: a list of quantile probabilities
1941 Each number must belong to [0, 1].
1942 For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
1943 :param relativeError: The relative target precision to achieve
1944 (>= 0). If set to zero, the exact quantiles are computed, which
1945 could be very expensive. Note that values greater than 1 are
1946 accepted but give the same result as 1.
1947 :return: the approximate quantiles at the given probabilities. If
1948 the input `col` is a string, the output is a list of floats. If the
1949 input `col` is a list or tuple of strings, the output is also a
1950 list, but each element in it is a list of floats, i.e., the output
1951 is a list of list of floats.
1952
1953 .. versionchanged:: 2.2
1954 Added support for multiple columns.
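
A sketch of typical usage (the data and error bound below are illustrative;
the values returned are approximate within ``relativeError``)::

    df = spark.createDataFrame([(float(i),) for i in range(100)], ["value"])
    quantiles = df.approxQuantile("value", [0.25, 0.5, 0.75], 0.05)
    # quantiles is a list of three floats, e.g. roughly [25.0, 50.0, 75.0]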
1955 """
1956
1957 if not isinstance(col, (basestring, list, tuple)):
1958 raise ValueError("col should be a string, list or tuple, but got %r" % type(col))
1959
1960 isStr = isinstance(col, basestring)
1961
1962 if isinstance(col, tuple):
1963 col = list(col)
1964 elif isStr:
1965 col = [col]
1966
1967 for c in col:
1968 if not isinstance(c, basestring):
1969 raise ValueError("columns should be strings, but got %r" % type(c))
1970 col = _to_list(self._sc, col)
1971
1972 if not isinstance(probabilities, (list, tuple)):
1973 raise ValueError("probabilities should be a list or tuple")
1974 if isinstance(probabilities, tuple):
1975 probabilities = list(probabilities)
1976 for p in probabilities:
1977 if not isinstance(p, (float, int, long)) or p < 0 or p > 1:
1978 raise ValueError("probabilities should be numerical (float, int, long) in [0,1].")
1979 probabilities = _to_list(self._sc, probabilities)
1980
1981 if not isinstance(relativeError, (float, int, long)) or relativeError < 0:
1982 raise ValueError("relativeError should be numerical (float, int, long) >= 0.")
1983 relativeError = float(relativeError)
1984
1985 jaq = self._jdf.stat().approxQuantile(col, probabilities, relativeError)
1986 jaq_list = [list(j) for j in jaq]
1987 return jaq_list[0] if isStr else jaq_list
1988
1989 @since(1.4)
1990 def corr(self, col1, col2, method=None):
1991 """
1992 Calculates the correlation of two columns of a :class:`DataFrame` as a double value.
1993 Currently only supports the Pearson Correlation Coefficient.
1994 :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases of each other.
1995
1996 :param col1: The name of the first column
1997 :param col2: The name of the second column
1998 :param method: The correlation method. Currently only supports "pearson"
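
For example, two exactly linearly related columns have a Pearson correlation
of 1.0 (the data and column names below are arbitrary):

>>> df = spark.createDataFrame([(1.0, 2.0), (2.0, 4.0)], ["a", "b"])
>>> df.corr("a", "b")
1.0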
1999 """
2000 if not isinstance(col1, basestring):
2001 raise ValueError("col1 should be a string.")
2002 if not isinstance(col2, basestring):
2003 raise ValueError("col2 should be a string.")
2004 if not method:
2005 method = "pearson"
if method != "pearson":
2007 raise ValueError("Currently only the calculation of the Pearson Correlation " +
2008 "coefficient is supported.")
2009 return self._jdf.stat().corr(col1, col2, method)
2010
2011 @since(1.4)
2012 def cov(self, col1, col2):
2013 """
2014 Calculate the sample covariance for the given columns, specified by their names, as a
2015 double value. :func:`DataFrame.cov` and :func:`DataFrameStatFunctions.cov` are aliases.
2016
2017 :param col1: The name of the first column
2018 :param col2: The name of the second column
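
For example (the data and column names below are arbitrary):

>>> df = spark.createDataFrame([(1.0, 2.0), (2.0, 4.0)], ["a", "b"])
>>> df.cov("a", "b")
1.0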
2019 """
2020 if not isinstance(col1, basestring):
2021 raise ValueError("col1 should be a string.")
2022 if not isinstance(col2, basestring):
2023 raise ValueError("col2 should be a string.")
2024 return self._jdf.stat().cov(col1, col2)
2025
2026 @since(1.4)
2027 def crosstab(self, col1, col2):
2028 """
2029 Computes a pair-wise frequency table of the given columns. Also known as a contingency
2030 table. The number of distinct values for each column should be less than 1e4. At most 1e6
2031 non-zero pair frequencies will be returned.
2032 The first column of each row will be the distinct values of `col1` and the column names
2033 will be the distinct values of `col2`. The name of the first column will be `$col1_$col2`.
2034 Pairs that have no occurrences will have zero as their counts.
2035 :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.
2036
2037 :param col1: The name of the first column. Distinct items will make the first item of
2038 each row.
2039 :param col2: The name of the second column. Distinct items will make the column names
2040 of the :class:`DataFrame`.
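
A sketch of typical usage (the data below is illustrative)::

    df = spark.createDataFrame([(1, "a"), (1, "b"), (2, "a")], ["key", "value"])
    ct = df.crosstab("key", "value")
    # ct has a first column named "key_value" holding each distinct value of
    # "key", plus one count column per distinct "value" ("a" and "b" here).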
2041 """
2042 if not isinstance(col1, basestring):
2043 raise ValueError("col1 should be a string.")
2044 if not isinstance(col2, basestring):
2045 raise ValueError("col2 should be a string.")
2046 return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
2047
2048 @since(1.4)
2049 def freqItems(self, cols, support=None):
2050 """
Finds frequent items for columns, possibly with false positives, using the
frequent element count algorithm described in
https://doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou.
2054 :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases.
2055
2056 .. note:: This function is meant for exploratory data analysis, as we make no
2057 guarantee about the backward compatibility of the schema of the resulting
2058 :class:`DataFrame`.
2059
2060 :param cols: Names of the columns to calculate frequent items for as a list or tuple of
2061 strings.
2062 :param support: The frequency with which to consider an item 'frequent'. Default is 1%.
2063 The support must be greater than 1e-4.
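
A sketch of typical usage (the data and support threshold below are
illustrative)::

    df = spark.createDataFrame([(1,), (1,), (1,), (2,)], ["a"])
    row = df.freqItems(["a"], support=0.5).collect()[0]
    # row holds, for each requested column, an array of items whose frequency
    # is at least the support; false positives are possible.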
2064 """
2065 if isinstance(cols, tuple):
2066 cols = list(cols)
2067 if not isinstance(cols, list):
2068 raise ValueError("cols must be a list or tuple of column names as strings.")
2069 if not support:
2070 support = 0.01
2071 return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sql_ctx)
2072
2073 @ignore_unicode_prefix
2074 @since(1.3)
2075 def withColumn(self, colName, col):
2076 """
2077 Returns a new :class:`DataFrame` by adding a column or replacing the
2078 existing column that has the same name.
2079
2080 The column expression must be an expression over this :class:`DataFrame`; attempting to add
2081 a column from some other :class:`DataFrame` will raise an error.
2082
2083 :param colName: string, name of the new column.
2084 :param col: a :class:`Column` expression for the new column.
2085
2086 .. note:: This method introduces a projection internally. Therefore, calling it multiple
2087 times, for instance, via loops in order to add multiple columns can generate big
2088 plans which can cause performance issues and even `StackOverflowException`.
2089 To avoid this, use :func:`select` with the multiple columns at once.
2090
2091 >>> df.withColumn('age2', df.age + 2).collect()
2092 [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
2093
2094 """
2095 assert isinstance(col, Column), "col should be Column"
2096 return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx)
2097
2098 @ignore_unicode_prefix
2099 @since(1.3)
2100 def withColumnRenamed(self, existing, new):
2101 """Returns a new :class:`DataFrame` by renaming an existing column.
This is a no-op if the schema doesn't contain the given column name.
2103
2104 :param existing: string, name of the existing column to rename.
2105 :param new: string, new name of the column.
2106
2107 >>> df.withColumnRenamed('age', 'age2').collect()
2108 [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
2109 """
2110 return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx)
2111
2112 @since(1.4)
2113 @ignore_unicode_prefix
2114 def drop(self, *cols):
2115 """Returns a new :class:`DataFrame` that drops the specified column.
This is a no-op if the schema doesn't contain the given column name(s).
2117
:param cols: a string name of the column to drop, or a
:class:`Column` to drop, or a list of string names of the columns to drop.
2120
2121 >>> df.drop('age').collect()
2122 [Row(name=u'Alice'), Row(name=u'Bob')]
2123
2124 >>> df.drop(df.age).collect()
2125 [Row(name=u'Alice'), Row(name=u'Bob')]
2126
2127 >>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect()
2128 [Row(age=5, height=85, name=u'Bob')]
2129
2130 >>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect()
2131 [Row(age=5, name=u'Bob', height=85)]
2132
2133 >>> df.join(df2, 'name', 'inner').drop('age', 'height').collect()
2134 [Row(name=u'Bob')]
2135 """
2136 if len(cols) == 1:
2137 col = cols[0]
2138 if isinstance(col, basestring):
2139 jdf = self._jdf.drop(col)
2140 elif isinstance(col, Column):
2141 jdf = self._jdf.drop(col._jc)
2142 else:
2143 raise TypeError("col should be a string or a Column")
2144 else:
2145 for col in cols:
2146 if not isinstance(col, basestring):
2147 raise TypeError("each col in the param list should be a string")
2148 jdf = self._jdf.drop(self._jseq(cols))
2149
2150 return DataFrame(jdf, self.sql_ctx)
2151
2152 @ignore_unicode_prefix
2153 def toDF(self, *cols):
2154 """Returns a new :class:`DataFrame` that with new specified column names
2155
2156 :param cols: list of new column names (string)
2157
2158 >>> df.toDF('f1', 'f2').collect()
2159 [Row(f1=2, f2=u'Alice'), Row(f1=5, f2=u'Bob')]
2160 """
2161 jdf = self._jdf.toDF(self._jseq(cols))
2162 return DataFrame(jdf, self.sql_ctx)
2163
2164 @since(3.0)
2165 def transform(self, func):
2166 """Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations.
2167
2168 :param func: a function that takes and returns a :class:`DataFrame`.
2169
2170 >>> from pyspark.sql.functions import col
2171 >>> df = spark.createDataFrame([(1, 1.0), (2, 2.0)], ["int", "float"])
2172 >>> def cast_all_to_int(input_df):
2173 ... return input_df.select([col(col_name).cast("int") for col_name in input_df.columns])
2174 >>> def sort_columns_asc(input_df):
2175 ... return input_df.select(*sorted(input_df.columns))
2176 >>> df.transform(cast_all_to_int).transform(sort_columns_asc).show()
2177 +-----+---+
2178 |float|int|
2179 +-----+---+
2180 | 1| 1|
2181 | 2| 2|
2182 +-----+---+
2183 """
2184 result = func(self)
2185 assert isinstance(result, DataFrame), "Func returned an instance of type [%s], " \
2186 "should have been DataFrame." % type(result)
2187 return result
2188
2189 where = copy_func(
2190 filter,
2191 sinceversion=1.3,
2192 doc=":func:`where` is an alias for :func:`filter`.")
2193

# `groupby` and `drop_duplicates` are kept as aliases of `groupBy` and
# `dropDuplicates` for backward compatibility.
2199 groupby = copy_func(
2200 groupBy,
2201 sinceversion=1.4,
2202 doc=":func:`groupby` is an alias for :func:`groupBy`.")
2203
2204 drop_duplicates = copy_func(
2205 dropDuplicates,
2206 sinceversion=1.4,
2207 doc=":func:`drop_duplicates` is an alias for :func:`dropDuplicates`.")
2208
2209
2210 def _to_scala_map(sc, jm):
2211 """
2212 Convert a dict into a JVM Map.
2213 """
2214 return sc._jvm.PythonUtils.toScalaMap(jm)
2215
2216
2217 class DataFrameNaFunctions(object):
2218 """Functionality for working with missing data in :class:`DataFrame`.
2219
2220 .. versionadded:: 1.4
2221 """
2222
2223 def __init__(self, df):
2224 self.df = df
2225
2226 def drop(self, how='any', thresh=None, subset=None):
2227 return self.df.dropna(how=how, thresh=thresh, subset=subset)
2228
2229 drop.__doc__ = DataFrame.dropna.__doc__
2230
2231 def fill(self, value, subset=None):
2232 return self.df.fillna(value=value, subset=subset)
2233
2234 fill.__doc__ = DataFrame.fillna.__doc__
2235
2236 def replace(self, to_replace, value=_NoValue, subset=None):
2237 return self.df.replace(to_replace, value, subset)
2238
2239 replace.__doc__ = DataFrame.replace.__doc__
2240
2241
2242 class DataFrameStatFunctions(object):
2243 """Functionality for statistic functions with :class:`DataFrame`.
2244
2245 .. versionadded:: 1.4
2246 """
2247
2248 def __init__(self, df):
2249 self.df = df
2250
2251 def approxQuantile(self, col, probabilities, relativeError):
2252 return self.df.approxQuantile(col, probabilities, relativeError)
2253
2254 approxQuantile.__doc__ = DataFrame.approxQuantile.__doc__
2255
2256 def corr(self, col1, col2, method=None):
2257 return self.df.corr(col1, col2, method)
2258
2259 corr.__doc__ = DataFrame.corr.__doc__
2260
2261 def cov(self, col1, col2):
2262 return self.df.cov(col1, col2)
2263
2264 cov.__doc__ = DataFrame.cov.__doc__
2265
2266 def crosstab(self, col1, col2):
2267 return self.df.crosstab(col1, col2)
2268
2269 crosstab.__doc__ = DataFrame.crosstab.__doc__
2270
2271 def freqItems(self, cols, support=None):
2272 return self.df.freqItems(cols, support)
2273
2274 freqItems.__doc__ = DataFrame.freqItems.__doc__
2275
2276 def sampleBy(self, col, fractions, seed=None):
2277 return self.df.sampleBy(col, fractions, seed)
2278
2279 sampleBy.__doc__ = DataFrame.sampleBy.__doc__
2280
2281
2282 def _test():
2283 import doctest
2284 from pyspark.context import SparkContext
2285 from pyspark.sql import Row, SQLContext, SparkSession
2286 import pyspark.sql.dataframe
2287 from pyspark.sql.functions import from_unixtime
2288 globs = pyspark.sql.dataframe.__dict__.copy()
2289 sc = SparkContext('local[4]', 'PythonTest')
2290 globs['sc'] = sc
2291 globs['sqlContext'] = SQLContext(sc)
2292 globs['spark'] = SparkSession(sc)
2293 globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\
2294 .toDF(StructType([StructField('age', IntegerType()),
2295 StructField('name', StringType())]))
2296 globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
2297 globs['df3'] = sc.parallelize([Row(name='Alice', age=2),
2298 Row(name='Bob', age=5)]).toDF()
2299 globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80),
2300 Row(name='Bob', age=5, height=None),
2301 Row(name='Tom', age=None, height=None),
2302 Row(name=None, age=None, height=None)]).toDF()
2303 globs['df5'] = sc.parallelize([Row(name='Alice', spy=False, age=10),
2304 Row(name='Bob', spy=None, age=5),
2305 Row(name='Mallory', spy=True, age=None)]).toDF()
2306 globs['sdf'] = sc.parallelize([Row(name='Tom', time=1479441846),
2307 Row(name='Bob', time=1479442946)]).toDF()
2308
2309 (failure_count, test_count) = doctest.testmod(
2310 pyspark.sql.dataframe, globs=globs,
2311 optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
2312 globs['sc'].stop()
2313 if failure_count:
2314 sys.exit(-1)
2315
2316
2317 if __name__ == "__main__":
2318 _test()