0018 """
A collection of builtin functions
0020 """
0021 import sys
0022 import functools
0023 import warnings
0024
0025 if sys.version < "3":
0026 from itertools import imap as map
0027
0028 if sys.version >= '3':
0029 basestring = str
0030
0031 from pyspark import since, SparkContext
0032 from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
0033 from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal, \
0034 _create_column_from_name
0035 from pyspark.sql.dataframe import DataFrame
0036 from pyspark.sql.types import StringType, DataType
0037
0038 from pyspark.sql.udf import UserDefinedFunction, _create_udf
0039
0040 from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType
0041 from pyspark.sql.utils import to_str
0042
0048
0049
0050 def _create_function(name, doc=""):
0051 """Create a PySpark function by its name"""
0052 def _(col):
0053 sc = SparkContext._active_spark_context
0054 jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
0055 return Column(jc)
0056 _.__name__ = name
0057 _.__doc__ = doc
0058 return _
0059
0060
0061 def _create_function_over_column(name, doc=""):
    """Similar to `_create_function`, but creates a PySpark function that also accepts a
    column name (as a string) in place of a :class:`Column`. This is mainly so PySpark
    functions can take strings as column names.
0065 """
0066 def _(col):
0067 sc = SparkContext._active_spark_context
0068 jc = getattr(sc._jvm.functions, name)(_to_java_column(col))
0069 return Column(jc)
0070 _.__name__ = name
0071 _.__doc__ = doc
0072 return _
0073
0074
0075 def _wrap_deprecated_function(func, message):
0076 """ Wrap the deprecated function to print out deprecation warnings"""
0077 def _(col):
0078 warnings.warn(message, DeprecationWarning)
0079 return func(col)
0080 return functools.wraps(func)(_)
0081
0082
0083 def _create_binary_mathfunction(name, doc=""):
    """Create a binary math function by its name"""
0085 def _(col1, col2):
0086 sc = SparkContext._active_spark_context
0087
0088
0089 if isinstance(col1, Column):
0090 arg1 = col1._jc
0091 elif isinstance(col1, basestring):
0092 arg1 = _create_column_from_name(col1)
0093 else:
0094 arg1 = float(col1)
0095
0096 if isinstance(col2, Column):
0097 arg2 = col2._jc
0098 elif isinstance(col2, basestring):
0099 arg2 = _create_column_from_name(col2)
0100 else:
0101 arg2 = float(col2)
0102
0103 jc = getattr(sc._jvm.functions, name)(arg1, arg2)
0104 return Column(jc)
0105 _.__name__ = name
0106 _.__doc__ = doc
0107 return _
0108
0109
0110 def _create_window_function(name, doc=''):
0111 """ Create a window function by name """
0112 def _():
0113 sc = SparkContext._active_spark_context
0114 jc = getattr(sc._jvm.functions, name)()
0115 return Column(jc)
0116 _.__name__ = name
0117 _.__doc__ = 'Window function: ' + doc
0118 return _
0119
0120
0121 def _options_to_str(options):
0122 return {key: to_str(value) for (key, value) in options.items()}
0123
0124 _lit_doc = """
0125 Creates a :class:`Column` of literal value.
0126
0127 >>> df.select(lit(5).alias('height')).withColumn('spark_user', lit(True)).take(1)
0128 [Row(height=5, spark_user=True)]
0129 """
0130 _functions = {
0131 'lit': _lit_doc,
0132 'col': 'Returns a :class:`Column` based on the given column name.',
0133 'column': 'Returns a :class:`Column` based on the given column name.',
0134 'asc': 'Returns a sort expression based on the ascending order of the given column name.',
0135 'desc': 'Returns a sort expression based on the descending order of the given column name.',
0136 }
0137
0138 _functions_over_column = {
0139 'sqrt': 'Computes the square root of the specified float value.',
0140 'abs': 'Computes the absolute value.',
0141
0142 'max': 'Aggregate function: returns the maximum value of the expression in a group.',
0143 'min': 'Aggregate function: returns the minimum value of the expression in a group.',
0144 'count': 'Aggregate function: returns the number of items in a group.',
0145 'sum': 'Aggregate function: returns the sum of all values in the expression.',
0146 'avg': 'Aggregate function: returns the average of the values in a group.',
0147 'mean': 'Aggregate function: returns the average of the values in a group.',
0148 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
0149 }
0150
0151 _functions_1_4_over_column = {
0152
0153 'acos': ':return: inverse cosine of `col`, as if computed by `java.lang.Math.acos()`',
0154 'asin': ':return: inverse sine of `col`, as if computed by `java.lang.Math.asin()`',
0155 'atan': ':return: inverse tangent of `col`, as if computed by `java.lang.Math.atan()`',
0156 'cbrt': 'Computes the cube-root of the given value.',
0157 'ceil': 'Computes the ceiling of the given value.',
0158 'cos': """:param col: angle in radians
0159 :return: cosine of the angle, as if computed by `java.lang.Math.cos()`.""",
0160 'cosh': """:param col: hyperbolic angle
0161 :return: hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh()`""",
0162 'exp': 'Computes the exponential of the given value.',
0163 'expm1': 'Computes the exponential of the given value minus one.',
0164 'floor': 'Computes the floor of the given value.',
0165 'log': 'Computes the natural logarithm of the given value.',
0166 'log10': 'Computes the logarithm of the given value in Base 10.',
0167 'log1p': 'Computes the natural logarithm of the given value plus one.',
0168 'rint': 'Returns the double value that is closest in value to the argument and' +
0169 ' is equal to a mathematical integer.',
0170 'signum': 'Computes the signum of the given value.',
0171 'sin': """:param col: angle in radians
0172 :return: sine of the angle, as if computed by `java.lang.Math.sin()`""",
0173 'sinh': """:param col: hyperbolic angle
0174 :return: hyperbolic sine of the given value,
0175 as if computed by `java.lang.Math.sinh()`""",
0176 'tan': """:param col: angle in radians
0177 :return: tangent of the given value, as if computed by `java.lang.Math.tan()`""",
0178 'tanh': """:param col: hyperbolic angle
0179 :return: hyperbolic tangent of the given value,
0180 as if computed by `java.lang.Math.tanh()`""",
0181 'toDegrees': '.. note:: Deprecated in 2.1, use :func:`degrees` instead.',
0182 'toRadians': '.. note:: Deprecated in 2.1, use :func:`radians` instead.',
0183 'bitwiseNOT': 'Computes bitwise not.',
0184 }
0185
0186 _functions_2_4 = {
    'asc_nulls_first': 'Returns a sort expression based on the ascending order of the given' +
                       ' column name, and null values appear before non-null values.',
    'asc_nulls_last': 'Returns a sort expression based on the ascending order of the given' +
                      ' column name, and null values appear after non-null values.',
    'desc_nulls_first': 'Returns a sort expression based on the descending order of the given' +
                        ' column name, and null values appear before non-null values.',
    'desc_nulls_last': 'Returns a sort expression based on the descending order of the given' +
                       ' column name, and null values appear after non-null values.',
0195 }
0196
0197 _collect_list_doc = """
0198 Aggregate function: returns a list of objects with duplicates.
0199
0200 .. note:: The function is non-deterministic because the order of collected results depends
0201 on the order of the rows which may be non-deterministic after a shuffle.
0202
0203 >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',))
0204 >>> df2.agg(collect_list('age')).collect()
0205 [Row(collect_list(age)=[2, 5, 5])]
0206 """
0207 _collect_set_doc = """
0208 Aggregate function: returns a set of objects with duplicate elements eliminated.
0209
0210 .. note:: The function is non-deterministic because the order of collected results depends
0211 on the order of the rows which may be non-deterministic after a shuffle.
0212
0213 >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',))
0214 >>> df2.agg(collect_set('age')).collect()
0215 [Row(collect_set(age)=[5, 2])]
0216 """
0217 _functions_1_6_over_column = {
0218
0219 'stddev': 'Aggregate function: alias for stddev_samp.',
0220 'stddev_samp': 'Aggregate function: returns the unbiased sample standard deviation of' +
0221 ' the expression in a group.',
0222 'stddev_pop': 'Aggregate function: returns population standard deviation of' +
0223 ' the expression in a group.',
0224 'variance': 'Aggregate function: alias for var_samp.',
0225 'var_samp': 'Aggregate function: returns the unbiased sample variance of' +
0226 ' the values in a group.',
0227 'var_pop': 'Aggregate function: returns the population variance of the values in a group.',
0228 'skewness': 'Aggregate function: returns the skewness of the values in a group.',
0229 'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.',
0230 'collect_list': _collect_list_doc,
0231 'collect_set': _collect_set_doc
0232 }
0233
0234 _functions_2_1_over_column = {
0235
0236 'degrees': """
0237 Converts an angle measured in radians to an approximately equivalent angle
0238 measured in degrees.
0239
0240 :param col: angle in radians
0241 :return: angle in degrees, as if computed by `java.lang.Math.toDegrees()`
0242 """,
0243 'radians': """
0244 Converts an angle measured in degrees to an approximately equivalent angle
0245 measured in radians.
0246
0247 :param col: angle in degrees
0248 :return: angle in radians, as if computed by `java.lang.Math.toRadians()`
0249 """,
0250 }
0251
0252
0253 _binary_mathfunctions = {
0254 'atan2': """
0255 :param col1: coordinate on y-axis
0256 :param col2: coordinate on x-axis
0257 :return: the `theta` component of the point
0258 (`r`, `theta`)
0259 in polar coordinates that corresponds to the point
0260 (`x`, `y`) in Cartesian coordinates,
0261 as if computed by `java.lang.Math.atan2()`
0262 """,
0263 'hypot': 'Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow.',
0264 'pow': 'Returns the value of the first argument raised to the power of the second argument.',
0265 }
0266
0267 _window_functions = {
0268 'row_number':
0269 """returns a sequential number starting at 1 within a window partition.""",
0270 'dense_rank':
0271 """returns the rank of rows within a window partition, without any gaps.
0272
0273 The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking
0274 sequence when there are ties. That is, if you were ranking a competition using dense_rank
0275 and had three people tie for second place, you would say that all three were in second
    place and that the next person came in third. Rank would give you sequential numbers, so
    the person who came in third place (after the ties) would register as coming in fifth.
0278
0279 This is equivalent to the DENSE_RANK function in SQL.""",
0280 'rank':
0281 """returns the rank of rows within a window partition.
0282
0283 The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking
0284 sequence when there are ties. That is, if you were ranking a competition using dense_rank
0285 and had three people tie for second place, you would say that all three were in second
    place and that the next person came in third. Rank would give you sequential numbers, so
    the person who came in third place (after the ties) would register as coming in fifth.
0288
0289 This is equivalent to the RANK function in SQL.""",
0290 'cume_dist':
0291 """returns the cumulative distribution of values within a window partition,
0292 i.e. the fraction of rows that are below the current row.""",
0293 'percent_rank':
0294 """returns the relative rank (i.e. percentile) of rows within a window partition.""",
0295 }
0296
0297
0298 _functions_deprecated = {
0299 'toDegrees': 'Deprecated in 2.1, use degrees instead.',
0300 'toRadians': 'Deprecated in 2.1, use radians instead.',
0301 }
0302
0303 for _name, _doc in _functions.items():
0304 globals()[_name] = since(1.3)(_create_function(_name, _doc))
0305 for _name, _doc in _functions_over_column.items():
0306 globals()[_name] = since(1.3)(_create_function_over_column(_name, _doc))
0307 for _name, _doc in _functions_1_4_over_column.items():
0308 globals()[_name] = since(1.4)(_create_function_over_column(_name, _doc))
0309 for _name, _doc in _binary_mathfunctions.items():
0310 globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc))
0311 for _name, _doc in _window_functions.items():
0312 globals()[_name] = since(1.6)(_create_window_function(_name, _doc))
0313 for _name, _doc in _functions_1_6_over_column.items():
0314 globals()[_name] = since(1.6)(_create_function_over_column(_name, _doc))
0315 for _name, _doc in _functions_2_1_over_column.items():
0316 globals()[_name] = since(2.1)(_create_function_over_column(_name, _doc))
0317 for _name, _message in _functions_deprecated.items():
0318 globals()[_name] = _wrap_deprecated_function(globals()[_name], _message)
0319 for _name, _doc in _functions_2_4.items():
0320 globals()[_name] = since(2.4)(_create_function(_name, _doc))
0321 del _name, _doc
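
# The loops above register the generated functions (for example ``sqrt``, ``abs``,
# ``stddev``) as module-level globals. As an illustrative sketch (not part of the
# generated documentation), a function such as ``sqrt`` accepts either a column name
# or a Column:
#
#     df = spark.createDataFrame([(4.0,)], ['a'])
#     df.select(sqrt('a'), sqrt(df.a)).collect()  # both result columns equal 2.0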
0322
0323
0324 @since(1.3)
0325 def approxCountDistinct(col, rsd=None):
0326 """
0327 .. note:: Deprecated in 2.1, use :func:`approx_count_distinct` instead.
0328 """
0329 warnings.warn("Deprecated in 2.1, use approx_count_distinct instead.", DeprecationWarning)
0330 return approx_count_distinct(col, rsd)
0331
0332
0333 @since(2.1)
0334 def approx_count_distinct(col, rsd=None):
0335 """Aggregate function: returns a new :class:`Column` for approximate distinct count of
0336 column `col`.
0337
0338 :param rsd: maximum estimation error allowed (default = 0.05). For rsd < 0.01, it is more
0339 efficient to use :func:`countDistinct`
0340
0341 >>> df.agg(approx_count_distinct(df.age).alias('distinct_ages')).collect()
0342 [Row(distinct_ages=2)]
0343 """
0344 sc = SparkContext._active_spark_context
0345 if rsd is None:
0346 jc = sc._jvm.functions.approx_count_distinct(_to_java_column(col))
0347 else:
0348 jc = sc._jvm.functions.approx_count_distinct(_to_java_column(col), rsd)
0349 return Column(jc)
0350
0351
0352 @since(1.6)
0353 def broadcast(df):
    """Marks a DataFrame as small enough for use in broadcast joins.
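
    A small illustrative example, joining against the module-level doctest DataFrame
    ``df`` (which has ages 2 and 5):

    >>> df_small = spark.createDataFrame([(2, 1)], ['age', 'flag'])
    >>> df.join(broadcast(df_small), 'age').select('age', 'flag').collect()
    [Row(age=2, flag=1)]
    """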
0355
0356 sc = SparkContext._active_spark_context
0357 return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sql_ctx)
0358
0359
0360 @since(1.4)
0361 def coalesce(*cols):
0362 """Returns the first column that is not null.
0363
0364 >>> cDf = spark.createDataFrame([(None, None), (1, None), (None, 2)], ("a", "b"))
0365 >>> cDf.show()
0366 +----+----+
0367 | a| b|
0368 +----+----+
0369 |null|null|
0370 | 1|null|
0371 |null| 2|
0372 +----+----+
0373
0374 >>> cDf.select(coalesce(cDf["a"], cDf["b"])).show()
0375 +--------------+
0376 |coalesce(a, b)|
0377 +--------------+
0378 | null|
0379 | 1|
0380 | 2|
0381 +--------------+
0382
0383 >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show()
0384 +----+----+----------------+
0385 | a| b|coalesce(a, 0.0)|
0386 +----+----+----------------+
0387 |null|null| 0.0|
0388 | 1|null| 1.0|
0389 |null| 2| 0.0|
0390 +----+----+----------------+
0391 """
0392 sc = SparkContext._active_spark_context
0393 jc = sc._jvm.functions.coalesce(_to_seq(sc, cols, _to_java_column))
0394 return Column(jc)
0395
0396
0397 @since(1.6)
0398 def corr(col1, col2):
0399 """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1``
0400 and ``col2``.
0401
0402 >>> a = range(20)
0403 >>> b = [2 * x for x in range(20)]
0404 >>> df = spark.createDataFrame(zip(a, b), ["a", "b"])
0405 >>> df.agg(corr("a", "b").alias('c')).collect()
0406 [Row(c=1.0)]
0407 """
0408 sc = SparkContext._active_spark_context
0409 return Column(sc._jvm.functions.corr(_to_java_column(col1), _to_java_column(col2)))
0410
0411
0412 @since(2.0)
0413 def covar_pop(col1, col2):
0414 """Returns a new :class:`Column` for the population covariance of ``col1`` and ``col2``.
0415
0416 >>> a = [1] * 10
0417 >>> b = [1] * 10
0418 >>> df = spark.createDataFrame(zip(a, b), ["a", "b"])
0419 >>> df.agg(covar_pop("a", "b").alias('c')).collect()
0420 [Row(c=0.0)]
0421 """
0422 sc = SparkContext._active_spark_context
0423 return Column(sc._jvm.functions.covar_pop(_to_java_column(col1), _to_java_column(col2)))
0424
0425
0426 @since(2.0)
0427 def covar_samp(col1, col2):
0428 """Returns a new :class:`Column` for the sample covariance of ``col1`` and ``col2``.
0429
0430 >>> a = [1] * 10
0431 >>> b = [1] * 10
0432 >>> df = spark.createDataFrame(zip(a, b), ["a", "b"])
0433 >>> df.agg(covar_samp("a", "b").alias('c')).collect()
0434 [Row(c=0.0)]
0435 """
0436 sc = SparkContext._active_spark_context
0437 return Column(sc._jvm.functions.covar_samp(_to_java_column(col1), _to_java_column(col2)))
0438
0439
0440 @since(1.3)
0441 def countDistinct(col, *cols):
0442 """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``.
0443
0444 >>> df.agg(countDistinct(df.age, df.name).alias('c')).collect()
0445 [Row(c=2)]
0446
0447 >>> df.agg(countDistinct("age", "name").alias('c')).collect()
0448 [Row(c=2)]
0449 """
0450 sc = SparkContext._active_spark_context
0451 jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column))
0452 return Column(jc)
0453
0454
0455 @since(1.3)
0456 def first(col, ignorenulls=False):
0457 """Aggregate function: returns the first value in a group.
0458
    The function by default returns the first value it sees. It will return the first non-null
    value it sees when ignoreNulls is set to true. If all values are null, then null is returned.

    .. note:: The function is non-deterministic because its result depends on the order of the
        rows, which may be non-deterministic after a shuffle.
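
    A small illustrative example (``df2`` is an ad-hoc frame; like the other aggregate
    examples, it relies on the input order of a local list):

    >>> df2 = spark.createDataFrame([(None,), (1,), (2,)], ('age',))
    >>> df2.agg(first('age').alias('a'), first('age', ignorenulls=True).alias('b')).collect()
    [Row(a=None, b=1)]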
0464 """
0465 sc = SparkContext._active_spark_context
0466 jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls)
0467 return Column(jc)
0468
0469
0470 @since(2.0)
0471 def grouping(col):
0472 """
0473 Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
0474 or not, returns 1 for aggregated or 0 for not aggregated in the result set.
0475
0476 >>> df.cube("name").agg(grouping("name"), sum("age")).orderBy("name").show()
0477 +-----+--------------+--------+
0478 | name|grouping(name)|sum(age)|
0479 +-----+--------------+--------+
0480 | null| 1| 7|
0481 |Alice| 0| 2|
0482 | Bob| 0| 5|
0483 +-----+--------------+--------+
0484 """
0485 sc = SparkContext._active_spark_context
0486 jc = sc._jvm.functions.grouping(_to_java_column(col))
0487 return Column(jc)
0488
0489
0490 @since(2.0)
0491 def grouping_id(*cols):
0492 """
0493 Aggregate function: returns the level of grouping, equals to
0494
0495 (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
0496
    .. note:: The list of columns should match the grouping columns exactly, or be empty
        (which means all the grouping columns).
0499
0500 >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show()
0501 +-----+-------------+--------+
0502 | name|grouping_id()|sum(age)|
0503 +-----+-------------+--------+
0504 | null| 1| 7|
0505 |Alice| 0| 2|
0506 | Bob| 0| 5|
0507 +-----+-------------+--------+
0508 """
0509 sc = SparkContext._active_spark_context
0510 jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column))
0511 return Column(jc)
0512
0513
0514 @since(1.6)
0515 def input_file_name():
0516 """Creates a string column for the file name of the current Spark task.
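
    The value depends on the underlying input source, so no doctest output is shown here.
    An illustrative sketch (the path is hypothetical)::

        spark.read.text("/tmp/data.txt").select(input_file_name()).show(truncate=False)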
0517 """
0518 sc = SparkContext._active_spark_context
0519 return Column(sc._jvm.functions.input_file_name())
0520
0521
0522 @since(1.6)
0523 def isnan(col):
0524 """An expression that returns true iff the column is NaN.
0525
0526 >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b"))
0527 >>> df.select(isnan("a").alias("r1"), isnan(df.a).alias("r2")).collect()
0528 [Row(r1=False, r2=False), Row(r1=True, r2=True)]
0529 """
0530 sc = SparkContext._active_spark_context
0531 return Column(sc._jvm.functions.isnan(_to_java_column(col)))
0532
0533
0534 @since(1.6)
0535 def isnull(col):
0536 """An expression that returns true iff the column is null.
0537
0538 >>> df = spark.createDataFrame([(1, None), (None, 2)], ("a", "b"))
0539 >>> df.select(isnull("a").alias("r1"), isnull(df.a).alias("r2")).collect()
0540 [Row(r1=False, r2=False), Row(r1=True, r2=True)]
0541 """
0542 sc = SparkContext._active_spark_context
0543 return Column(sc._jvm.functions.isnull(_to_java_column(col)))
0544
0545
0546 @since(1.3)
0547 def last(col, ignorenulls=False):
0548 """Aggregate function: returns the last value in a group.
0549
    The function by default returns the last value it sees. It will return the last non-null
    value it sees when ignoreNulls is set to true. If all values are null, then null is returned.

    .. note:: The function is non-deterministic because its result depends on the order of the
        rows, which may be non-deterministic after a shuffle.
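
    A small illustrative example (``df2`` is an ad-hoc frame; like the other aggregate
    examples, it relies on the input order of a local list):

    >>> df2 = spark.createDataFrame([(1,), (2,), (None,)], ('age',))
    >>> df2.agg(last('age').alias('a'), last('age', ignorenulls=True).alias('b')).collect()
    [Row(a=None, b=2)]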
0555 """
0556 sc = SparkContext._active_spark_context
0557 jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls)
0558 return Column(jc)
0559
0560
0561 @since(1.6)
0562 def monotonically_increasing_id():
0563 """A column that generates monotonically increasing 64-bit integers.
0564
0565 The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive.
0566 The current implementation puts the partition ID in the upper 31 bits, and the record number
0567 within each partition in the lower 33 bits. The assumption is that the data frame has
0568 less than 1 billion partitions, and each partition has less than 8 billion records.
0569
0570 .. note:: The function is non-deterministic because its result depends on partition IDs.
0571
0572 As an example, consider a :class:`DataFrame` with two partitions, each with 3 records.
0573 This expression would return the following IDs:
0574 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594.
0575
0576 >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1'])
0577 >>> df0.select(monotonically_increasing_id().alias('id')).collect()
0578 [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)]
0579 """
0580 sc = SparkContext._active_spark_context
0581 return Column(sc._jvm.functions.monotonically_increasing_id())
0582
0583
0584 @since(1.6)
0585 def nanvl(col1, col2):
0586 """Returns col1 if it is not NaN, or col2 if col1 is NaN.
0587
0588 Both inputs should be floating point columns (:class:`DoubleType` or :class:`FloatType`).
0589
0590 >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b"))
0591 >>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect()
0592 [Row(r1=1.0, r2=1.0), Row(r1=2.0, r2=2.0)]
0593 """
0594 sc = SparkContext._active_spark_context
0595 return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2)))
0596
0597
0598 @ignore_unicode_prefix
0599 @since(1.4)
0600 def rand(seed=None):
0601 """Generates a random column with independent and identically distributed (i.i.d.) samples
0602 uniformly distributed in [0.0, 1.0).
0603
    .. note:: The function is non-deterministic in the general case.
0605
0606 >>> df.withColumn('rand', rand(seed=42) * 3).collect()
0607 [Row(age=2, name=u'Alice', rand=2.4052597283576684),
0608 Row(age=5, name=u'Bob', rand=2.3913904055683974)]
0609 """
0610 sc = SparkContext._active_spark_context
0611 if seed is not None:
0612 jc = sc._jvm.functions.rand(seed)
0613 else:
0614 jc = sc._jvm.functions.rand()
0615 return Column(jc)
0616
0617
0618 @ignore_unicode_prefix
0619 @since(1.4)
0620 def randn(seed=None):
0621 """Generates a column with independent and identically distributed (i.i.d.) samples from
0622 the standard normal distribution.
0623
    .. note:: The function is non-deterministic in the general case.
0625
0626 >>> df.withColumn('randn', randn(seed=42)).collect()
0627 [Row(age=2, name=u'Alice', randn=1.1027054481455365),
0628 Row(age=5, name=u'Bob', randn=0.7400395449950132)]
0629 """
0630 sc = SparkContext._active_spark_context
0631 if seed is not None:
0632 jc = sc._jvm.functions.randn(seed)
0633 else:
0634 jc = sc._jvm.functions.randn()
0635 return Column(jc)
0636
0637
0638 @since(1.5)
0639 def round(col, scale=0):
0640 """
0641 Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0
0642 or at integral part when `scale` < 0.
0643
0644 >>> spark.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect()
0645 [Row(r=3.0)]
0646 """
0647 sc = SparkContext._active_spark_context
0648 return Column(sc._jvm.functions.round(_to_java_column(col), scale))
0649
0650
0651 @since(2.0)
0652 def bround(col, scale=0):
0653 """
0654 Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0
0655 or at integral part when `scale` < 0.
0656
0657 >>> spark.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect()
0658 [Row(r=2.0)]
0659 """
0660 sc = SparkContext._active_spark_context
0661 return Column(sc._jvm.functions.bround(_to_java_column(col), scale))
0662
0663
0664 @since(1.5)
0665 def shiftLeft(col, numBits):
0666 """Shift the given value numBits left.
0667
0668 >>> spark.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect()
0669 [Row(r=42)]
0670 """
0671 sc = SparkContext._active_spark_context
0672 return Column(sc._jvm.functions.shiftLeft(_to_java_column(col), numBits))
0673
0674
0675 @since(1.5)
0676 def shiftRight(col, numBits):
0677 """(Signed) shift the given value numBits right.
0678
0679 >>> spark.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect()
0680 [Row(r=21)]
0681 """
0682 sc = SparkContext._active_spark_context
0683 jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits)
0684 return Column(jc)
0685
0686
0687 @since(1.5)
0688 def shiftRightUnsigned(col, numBits):
0689 """Unsigned shift the given value numBits right.
0690
0691 >>> df = spark.createDataFrame([(-42,)], ['a'])
0692 >>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect()
0693 [Row(r=9223372036854775787)]
0694 """
0695 sc = SparkContext._active_spark_context
0696 jc = sc._jvm.functions.shiftRightUnsigned(_to_java_column(col), numBits)
0697 return Column(jc)
0698
0699
0700 @since(1.6)
0701 def spark_partition_id():
0702 """A column for partition ID.
0703
    .. note:: This is non-deterministic because it depends on data partitioning and task
        scheduling.
0705
0706 >>> df.repartition(1).select(spark_partition_id().alias("pid")).collect()
0707 [Row(pid=0), Row(pid=0)]
0708 """
0709 sc = SparkContext._active_spark_context
0710 return Column(sc._jvm.functions.spark_partition_id())
0711
0712
0713 @since(1.5)
0714 def expr(str):
0715 """Parses the expression string into the column that it represents
0716
0717 >>> df.select(expr("length(name)")).collect()
0718 [Row(length(name)=5), Row(length(name)=3)]
0719 """
0720 sc = SparkContext._active_spark_context
0721 return Column(sc._jvm.functions.expr(str))
0722
0723
0724 @ignore_unicode_prefix
0725 @since(1.4)
0726 def struct(*cols):
0727 """Creates a new struct column.
0728
0729 :param cols: list of column names (string) or list of :class:`Column` expressions
0730
0731 >>> df.select(struct('age', 'name').alias("struct")).collect()
0732 [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))]
0733 >>> df.select(struct([df.age, df.name]).alias("struct")).collect()
0734 [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))]
0735 """
0736 sc = SparkContext._active_spark_context
0737 if len(cols) == 1 and isinstance(cols[0], (list, set)):
0738 cols = cols[0]
0739 jc = sc._jvm.functions.struct(_to_seq(sc, cols, _to_java_column))
0740 return Column(jc)
0741
0742
0743 @since(1.5)
0744 def greatest(*cols):
0745 """
0746 Returns the greatest value of the list of column names, skipping null values.
0747 This function takes at least 2 parameters. It will return null iff all parameters are null.
0748
0749 >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
0750 >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect()
0751 [Row(greatest=4)]
0752 """
0753 if len(cols) < 2:
0754 raise ValueError("greatest should take at least two columns")
0755 sc = SparkContext._active_spark_context
0756 return Column(sc._jvm.functions.greatest(_to_seq(sc, cols, _to_java_column)))
0757
0758
0759 @since(1.5)
0760 def least(*cols):
0761 """
0762 Returns the least value of the list of column names, skipping null values.
0763 This function takes at least 2 parameters. It will return null iff all parameters are null.
0764
0765 >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
0766 >>> df.select(least(df.a, df.b, df.c).alias("least")).collect()
0767 [Row(least=1)]
0768 """
0769 if len(cols) < 2:
0770 raise ValueError("least should take at least two columns")
0771 sc = SparkContext._active_spark_context
0772 return Column(sc._jvm.functions.least(_to_seq(sc, cols, _to_java_column)))
0773
0774
0775 @since(1.4)
0776 def when(condition, value):
0777 """Evaluates a list of conditions and returns one of multiple possible result expressions.
0778 If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
0779
0780 :param condition: a boolean :class:`Column` expression.
0781 :param value: a literal value, or a :class:`Column` expression.
0782
0783 >>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
0784 [Row(age=3), Row(age=4)]
0785
0786 >>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect()
0787 [Row(age=3), Row(age=None)]
0788 """
0789 sc = SparkContext._active_spark_context
0790 if not isinstance(condition, Column):
0791 raise TypeError("condition should be a Column")
0792 v = value._jc if isinstance(value, Column) else value
0793 jc = sc._jvm.functions.when(condition._jc, v)
0794 return Column(jc)
0795
0796
0797 @since(1.5)
0798 def log(arg1, arg2=None):
0799 """Returns the first argument-based logarithm of the second argument.
0800
0801 If there is only one argument, then this takes the natural logarithm of the argument.
0802
0803 >>> df.select(log(10.0, df.age).alias('ten')).rdd.map(lambda l: str(l.ten)[:7]).collect()
0804 ['0.30102', '0.69897']
0805
0806 >>> df.select(log(df.age).alias('e')).rdd.map(lambda l: str(l.e)[:7]).collect()
0807 ['0.69314', '1.60943']
0808 """
0809 sc = SparkContext._active_spark_context
0810 if arg2 is None:
0811 jc = sc._jvm.functions.log(_to_java_column(arg1))
0812 else:
0813 jc = sc._jvm.functions.log(arg1, _to_java_column(arg2))
0814 return Column(jc)
0815
0816
0817 @since(1.5)
0818 def log2(col):
0819 """Returns the base-2 logarithm of the argument.
0820
0821 >>> spark.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect()
0822 [Row(log2=2.0)]
0823 """
0824 sc = SparkContext._active_spark_context
0825 return Column(sc._jvm.functions.log2(_to_java_column(col)))
0826
0827
0828 @since(1.5)
0829 @ignore_unicode_prefix
0830 def conv(col, fromBase, toBase):
0831 """
0832 Convert a number in a string column from one base to another.
0833
0834 >>> df = spark.createDataFrame([("010101",)], ['n'])
0835 >>> df.select(conv(df.n, 2, 16).alias('hex')).collect()
0836 [Row(hex=u'15')]
0837 """
0838 sc = SparkContext._active_spark_context
0839 return Column(sc._jvm.functions.conv(_to_java_column(col), fromBase, toBase))
0840
0841
0842 @since(1.5)
0843 def factorial(col):
0844 """
0845 Computes the factorial of the given value.
0846
0847 >>> df = spark.createDataFrame([(5,)], ['n'])
0848 >>> df.select(factorial(df.n).alias('f')).collect()
0849 [Row(f=120)]
0850 """
0851 sc = SparkContext._active_spark_context
0852 return Column(sc._jvm.functions.factorial(_to_java_column(col)))
0853
0854
0855
0856
0857 @since(1.4)
0858 def lag(col, offset=1, default=None):
0859 """
    Window function: returns the value that is `offset` rows before the current row, and
    `default` if there are fewer than `offset` rows before the current row. For example,
    an `offset` of one will return the previous row at any given point in the window partition.

    This is equivalent to the LAG function in SQL.

    :param col: name of column or expression
    :param offset: number of rows to look back in the window partition
    :param default: default value to return when the offset row does not exist
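
    A small illustrative example (``Window`` comes from :mod:`pyspark.sql.window`):

    >>> from pyspark.sql.window import Window
    >>> df2 = spark.createDataFrame([(1, 10), (1, 20), (2, 30)], ['grp', 'val'])
    >>> w = Window.partitionBy('grp').orderBy('val')
    >>> df2.withColumn('prev', lag('val').over(w)).orderBy('grp', 'val').collect()
    [Row(grp=1, val=10, prev=None), Row(grp=1, val=20, prev=10), Row(grp=2, val=30, prev=None)]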
0869 """
0870 sc = SparkContext._active_spark_context
0871 return Column(sc._jvm.functions.lag(_to_java_column(col), offset, default))
0872
0873
0874 @since(1.4)
0875 def lead(col, offset=1, default=None):
0876 """
    Window function: returns the value that is `offset` rows after the current row, and
    `default` if there are fewer than `offset` rows after the current row. For example,
    an `offset` of one will return the next row at any given point in the window partition.

    This is equivalent to the LEAD function in SQL.

    :param col: name of column or expression
    :param offset: number of rows to look ahead in the window partition
    :param default: default value to return when the offset row does not exist
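
    A small illustrative example (``Window`` comes from :mod:`pyspark.sql.window`):

    >>> from pyspark.sql.window import Window
    >>> df2 = spark.createDataFrame([(1, 10), (1, 20), (2, 30)], ['grp', 'val'])
    >>> w = Window.partitionBy('grp').orderBy('val')
    >>> df2.withColumn('next', lead('val').over(w)).orderBy('grp', 'val').collect()
    [Row(grp=1, val=10, next=20), Row(grp=1, val=20, next=None), Row(grp=2, val=30, next=None)]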
0886 """
0887 sc = SparkContext._active_spark_context
0888 return Column(sc._jvm.functions.lead(_to_java_column(col), offset, default))
0889
0890
0891 @since(1.4)
0892 def ntile(n):
0893 """
0894 Window function: returns the ntile group id (from 1 to `n` inclusive)
0895 in an ordered window partition. For example, if `n` is 4, the first
0896 quarter of the rows will get value 1, the second quarter will get 2,
0897 the third quarter will get 3, and the last quarter will get 4.
0898
0899 This is equivalent to the NTILE function in SQL.
0900
0901 :param n: an integer
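
    A small illustrative example (``Window`` comes from :mod:`pyspark.sql.window`; the
    single unpartitioned window is fine for a tiny frame):

    >>> from pyspark.sql.window import Window
    >>> df2 = spark.createDataFrame([(1,), (2,), (3,), (4,)], ['v'])
    >>> df2.withColumn('half', ntile(2).over(Window.orderBy('v'))).orderBy('v').collect()
    [Row(v=1, half=1), Row(v=2, half=1), Row(v=3, half=2), Row(v=4, half=2)]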
0902 """
0903 sc = SparkContext._active_spark_context
0904 return Column(sc._jvm.functions.ntile(int(n)))
0905
0906
0907
0908
0909 @since(1.5)
0910 def current_date():
0911 """
0912 Returns the current date as a :class:`DateType` column.
0913 """
0914 sc = SparkContext._active_spark_context
0915 return Column(sc._jvm.functions.current_date())
0916
0917
@since(1.5)
def current_timestamp():
0919 """
0920 Returns the current timestamp as a :class:`TimestampType` column.
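
    The value is not fixed, so no doctest output is shown; a typical (illustrative) use is
    stamping rows at processing time::

        df.withColumn('processed_at', current_timestamp())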
0921 """
0922 sc = SparkContext._active_spark_context
0923 return Column(sc._jvm.functions.current_timestamp())
0924
0925
0926 @ignore_unicode_prefix
0927 @since(1.5)
0928 def date_format(date, format):
0929 """
0930 Converts a date/timestamp/string to a value of string in the format specified by the date
0931 format given by the second argument.
0932
    A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All
    pattern letters of the `datetime pattern`_ can be used.
0935
0936 .. _datetime pattern: https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html
    .. note:: Whenever possible, use specialized functions like `year`; they benefit from a
        specialized implementation.
0939
0940 >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
0941 >>> df.select(date_format('dt', 'MM/dd/yyy').alias('date')).collect()
0942 [Row(date=u'04/08/2015')]
0943 """
0944 sc = SparkContext._active_spark_context
0945 return Column(sc._jvm.functions.date_format(_to_java_column(date), format))
0946
0947
0948 @since(1.5)
0949 def year(col):
0950 """
0951 Extract the year of a given date as integer.
0952
0953 >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
0954 >>> df.select(year('dt').alias('year')).collect()
0955 [Row(year=2015)]
0956 """
0957 sc = SparkContext._active_spark_context
0958 return Column(sc._jvm.functions.year(_to_java_column(col)))
0959
0960
0961 @since(1.5)
0962 def quarter(col):
0963 """
0964 Extract the quarter of a given date as integer.
0965
0966 >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
0967 >>> df.select(quarter('dt').alias('quarter')).collect()
0968 [Row(quarter=2)]
0969 """
0970 sc = SparkContext._active_spark_context
0971 return Column(sc._jvm.functions.quarter(_to_java_column(col)))
0972
0973
0974 @since(1.5)
0975 def month(col):
0976 """
0977 Extract the month of a given date as integer.
0978
0979 >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
0980 >>> df.select(month('dt').alias('month')).collect()
0981 [Row(month=4)]
0982 """
0983 sc = SparkContext._active_spark_context
0984 return Column(sc._jvm.functions.month(_to_java_column(col)))
0985
0986
0987 @since(2.3)
0988 def dayofweek(col):
0989 """
0990 Extract the day of the week of a given date as integer.
0991
0992 >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
0993 >>> df.select(dayofweek('dt').alias('day')).collect()
0994 [Row(day=4)]
0995 """
0996 sc = SparkContext._active_spark_context
0997 return Column(sc._jvm.functions.dayofweek(_to_java_column(col)))
0998
0999
1000 @since(1.5)
1001 def dayofmonth(col):
1002 """
1003 Extract the day of the month of a given date as integer.
1004
1005 >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
1006 >>> df.select(dayofmonth('dt').alias('day')).collect()
1007 [Row(day=8)]
1008 """
1009 sc = SparkContext._active_spark_context
1010 return Column(sc._jvm.functions.dayofmonth(_to_java_column(col)))
1011
1012
1013 @since(1.5)
1014 def dayofyear(col):
1015 """
1016 Extract the day of the year of a given date as integer.
1017
1018 >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
1019 >>> df.select(dayofyear('dt').alias('day')).collect()
1020 [Row(day=98)]
1021 """
1022 sc = SparkContext._active_spark_context
1023 return Column(sc._jvm.functions.dayofyear(_to_java_column(col)))
1024
1025
1026 @since(1.5)
1027 def hour(col):
1028 """
1029 Extract the hours of a given date as integer.
1030
1031 >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts'])
1032 >>> df.select(hour('ts').alias('hour')).collect()
1033 [Row(hour=13)]
1034 """
1035 sc = SparkContext._active_spark_context
1036 return Column(sc._jvm.functions.hour(_to_java_column(col)))
1037
1038
1039 @since(1.5)
1040 def minute(col):
1041 """
1042 Extract the minutes of a given date as integer.
1043
1044 >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts'])
1045 >>> df.select(minute('ts').alias('minute')).collect()
1046 [Row(minute=8)]
1047 """
1048 sc = SparkContext._active_spark_context
1049 return Column(sc._jvm.functions.minute(_to_java_column(col)))
1050
1051
1052 @since(1.5)
1053 def second(col):
1054 """
1055 Extract the seconds of a given date as integer.
1056
1057 >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts'])
1058 >>> df.select(second('ts').alias('second')).collect()
1059 [Row(second=15)]
1060 """
1061 sc = SparkContext._active_spark_context
1062 return Column(sc._jvm.functions.second(_to_java_column(col)))
1063
1064
1065 @since(1.5)
1066 def weekofyear(col):
1067 """
1068 Extract the week number of a given date as integer.
1069
1070 >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
1071 >>> df.select(weekofyear(df.dt).alias('week')).collect()
1072 [Row(week=15)]
1073 """
1074 sc = SparkContext._active_spark_context
1075 return Column(sc._jvm.functions.weekofyear(_to_java_column(col)))
1076
1077
1078 @since(1.5)
1079 def date_add(start, days):
1080 """
1081 Returns the date that is `days` days after `start`
1082
1083 >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
1084 >>> df.select(date_add(df.dt, 1).alias('next_date')).collect()
1085 [Row(next_date=datetime.date(2015, 4, 9))]
1086 """
1087 sc = SparkContext._active_spark_context
1088 return Column(sc._jvm.functions.date_add(_to_java_column(start), days))
1089
1090
1091 @since(1.5)
1092 def date_sub(start, days):
1093 """
1094 Returns the date that is `days` days before `start`
1095
1096 >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
1097 >>> df.select(date_sub(df.dt, 1).alias('prev_date')).collect()
1098 [Row(prev_date=datetime.date(2015, 4, 7))]
1099 """
1100 sc = SparkContext._active_spark_context
1101 return Column(sc._jvm.functions.date_sub(_to_java_column(start), days))
1102
1103
1104 @since(1.5)
1105 def datediff(end, start):
1106 """
1107 Returns the number of days from `start` to `end`.
1108
1109 >>> df = spark.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2'])
1110 >>> df.select(datediff(df.d2, df.d1).alias('diff')).collect()
1111 [Row(diff=32)]
1112 """
1113 sc = SparkContext._active_spark_context
1114 return Column(sc._jvm.functions.datediff(_to_java_column(end), _to_java_column(start)))
1115
1116
1117 @since(1.5)
1118 def add_months(start, months):
1119 """
1120 Returns the date that is `months` months after `start`
1121
1122 >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
1123 >>> df.select(add_months(df.dt, 1).alias('next_month')).collect()
1124 [Row(next_month=datetime.date(2015, 5, 8))]
1125 """
1126 sc = SparkContext._active_spark_context
1127 return Column(sc._jvm.functions.add_months(_to_java_column(start), months))
1128
1129
1130 @since(1.5)
1131 def months_between(date1, date2, roundOff=True):
1132 """
    Returns the number of months between dates `date1` and `date2`.
1134 If date1 is later than date2, then the result is positive.
1135 If date1 and date2 are on the same day of month, or both are the last day of month,
1136 returns an integer (time of day will be ignored).
1137 The result is rounded off to 8 digits unless `roundOff` is set to `False`.
1138
1139 >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2'])
1140 >>> df.select(months_between(df.date1, df.date2).alias('months')).collect()
1141 [Row(months=3.94959677)]
1142 >>> df.select(months_between(df.date1, df.date2, False).alias('months')).collect()
1143 [Row(months=3.9495967741935485)]
1144 """
1145 sc = SparkContext._active_spark_context
1146 return Column(sc._jvm.functions.months_between(
1147 _to_java_column(date1), _to_java_column(date2), roundOff))
1148
1149
1150 @since(2.2)
1151 def to_date(col, format=None):
1152 """Converts a :class:`Column` into :class:`pyspark.sql.types.DateType`
1153 using the optionally specified format. Specify formats according to `datetime pattern`_.
1154 By default, it follows casting rules to :class:`pyspark.sql.types.DateType` if the format
1155 is omitted. Equivalent to ``col.cast("date")``.
1156
1157 >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
1158 >>> df.select(to_date(df.t).alias('date')).collect()
1159 [Row(date=datetime.date(1997, 2, 28))]
1160
1161 >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
1162 >>> df.select(to_date(df.t, 'yyyy-MM-dd HH:mm:ss').alias('date')).collect()
1163 [Row(date=datetime.date(1997, 2, 28))]
1164 """
1165 sc = SparkContext._active_spark_context
1166 if format is None:
1167 jc = sc._jvm.functions.to_date(_to_java_column(col))
1168 else:
1169 jc = sc._jvm.functions.to_date(_to_java_column(col), format)
1170 return Column(jc)
1171
1172
1173 @since(2.2)
1174 def to_timestamp(col, format=None):
1175 """Converts a :class:`Column` into :class:`pyspark.sql.types.TimestampType`
1176 using the optionally specified format. Specify formats according to `datetime pattern`_.
1177 By default, it follows casting rules to :class:`pyspark.sql.types.TimestampType` if the format
1178 is omitted. Equivalent to ``col.cast("timestamp")``.
1179
1180 >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
1181 >>> df.select(to_timestamp(df.t).alias('dt')).collect()
1182 [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
1183
1184 >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
1185 >>> df.select(to_timestamp(df.t, 'yyyy-MM-dd HH:mm:ss').alias('dt')).collect()
1186 [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
1187 """
1188 sc = SparkContext._active_spark_context
1189 if format is None:
1190 jc = sc._jvm.functions.to_timestamp(_to_java_column(col))
1191 else:
1192 jc = sc._jvm.functions.to_timestamp(_to_java_column(col), format)
1193 return Column(jc)
1194
1195
1196 @since(1.5)
1197 def trunc(date, format):
1198 """
1199 Returns date truncated to the unit specified by the format.
1200
1201 :param format: 'year', 'yyyy', 'yy' or 'month', 'mon', 'mm'
1202
1203 >>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
1204 >>> df.select(trunc(df.d, 'year').alias('year')).collect()
1205 [Row(year=datetime.date(1997, 1, 1))]
1206 >>> df.select(trunc(df.d, 'mon').alias('month')).collect()
1207 [Row(month=datetime.date(1997, 2, 1))]
1208 """
1209 sc = SparkContext._active_spark_context
1210 return Column(sc._jvm.functions.trunc(_to_java_column(date), format))
1211
1212
1213 @since(2.3)
1214 def date_trunc(format, timestamp):
1215 """
1216 Returns timestamp truncated to the unit specified by the format.
1217
1218 :param format: 'year', 'yyyy', 'yy', 'month', 'mon', 'mm',
1219 'day', 'dd', 'hour', 'minute', 'second', 'week', 'quarter'
1220
1221 >>> df = spark.createDataFrame([('1997-02-28 05:02:11',)], ['t'])
1222 >>> df.select(date_trunc('year', df.t).alias('year')).collect()
1223 [Row(year=datetime.datetime(1997, 1, 1, 0, 0))]
1224 >>> df.select(date_trunc('mon', df.t).alias('month')).collect()
1225 [Row(month=datetime.datetime(1997, 2, 1, 0, 0))]
1226 """
1227 sc = SparkContext._active_spark_context
1228 return Column(sc._jvm.functions.date_trunc(format, _to_java_column(timestamp)))
1229
1230
1231 @since(1.5)
1232 def next_day(date, dayOfWeek):
1233 """
1234 Returns the first date which is later than the value of the date column.
1235
1236 Day of the week parameter is case insensitive, and accepts:
1237 "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun".
1238
1239 >>> df = spark.createDataFrame([('2015-07-27',)], ['d'])
1240 >>> df.select(next_day(df.d, 'Sun').alias('date')).collect()
1241 [Row(date=datetime.date(2015, 8, 2))]
1242 """
1243 sc = SparkContext._active_spark_context
1244 return Column(sc._jvm.functions.next_day(_to_java_column(date), dayOfWeek))
1245
1246
1247 @since(1.5)
1248 def last_day(date):
1249 """
1250 Returns the last day of the month which the given date belongs to.
1251
1252 >>> df = spark.createDataFrame([('1997-02-10',)], ['d'])
1253 >>> df.select(last_day(df.d).alias('date')).collect()
1254 [Row(date=datetime.date(1997, 2, 28))]
1255 """
1256 sc = SparkContext._active_spark_context
1257 return Column(sc._jvm.functions.last_day(_to_java_column(date)))
1258
1259
1260 @ignore_unicode_prefix
1261 @since(1.5)
1262 def from_unixtime(timestamp, format="yyyy-MM-dd HH:mm:ss"):
1263 """
1264 Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
1265 representing the timestamp of that moment in the current system time zone in the given
1266 format.
1267
1268 >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
1269 >>> time_df = spark.createDataFrame([(1428476400,)], ['unix_time'])
1270 >>> time_df.select(from_unixtime('unix_time').alias('ts')).collect()
1271 [Row(ts=u'2015-04-08 00:00:00')]
1272 >>> spark.conf.unset("spark.sql.session.timeZone")
1273 """
1274 sc = SparkContext._active_spark_context
1275 return Column(sc._jvm.functions.from_unixtime(_to_java_column(timestamp), format))
1276
1277
1278 @since(1.5)
1279 def unix_timestamp(timestamp=None, format='yyyy-MM-dd HH:mm:ss'):
1280 """
    Convert time string with given pattern ('yyyy-MM-dd HH:mm:ss', by default)
    to Unix time stamp (in seconds), using the default timezone and the default
    locale, returning null if it fails.

    If `timestamp` is None, then the current timestamp is returned.
1286
1287 >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
1288 >>> time_df = spark.createDataFrame([('2015-04-08',)], ['dt'])
1289 >>> time_df.select(unix_timestamp('dt', 'yyyy-MM-dd').alias('unix_time')).collect()
1290 [Row(unix_time=1428476400)]
1291 >>> spark.conf.unset("spark.sql.session.timeZone")
1292 """
1293 sc = SparkContext._active_spark_context
1294 if timestamp is None:
1295 return Column(sc._jvm.functions.unix_timestamp())
1296 return Column(sc._jvm.functions.unix_timestamp(_to_java_column(timestamp), format))
1297
1298
1299 @since(1.5)
1300 def from_utc_timestamp(timestamp, tz):
1301 """
1302 This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function
1303 takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in UTC, and
1304 renders that timestamp as a timestamp in the given time zone.
1305
    However, a timestamp in Spark represents a number of microseconds from the Unix epoch,
    which is not timezone-agnostic. So in Spark this function just shifts the timestamp value
    from the UTC time zone to the given time zone.

    This function may return a confusing result if the input is a string with a time zone, e.g.
    '2018-03-13T06:18:23+00:00'. The reason is that Spark first casts the string to a timestamp
    according to the time zone in the string, and finally displays the result by converting the
    timestamp to a string according to the session-local time zone.
1314
1315 :param timestamp: the column that contains timestamps
1316 :param tz: A string detailing the time zone ID that the input should be adjusted to. It should
1317 be in the format of either region-based zone IDs or zone offsets. Region IDs must
1318 have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in
1319 the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are
               supported as aliases of '+00:00'. Other short names are not recommended because
               they can be ambiguous.
1322
1323 .. versionchanged:: 2.4
1324 `tz` can take a :class:`Column` containing timezone ID strings.
1325
1326 >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz'])
1327 >>> df.select(from_utc_timestamp(df.ts, "PST").alias('local_time')).collect()
1328 [Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))]
1329 >>> df.select(from_utc_timestamp(df.ts, df.tz).alias('local_time')).collect()
1330 [Row(local_time=datetime.datetime(1997, 2, 28, 19, 30))]
1331 """
1332 sc = SparkContext._active_spark_context
1333 if isinstance(tz, Column):
1334 tz = _to_java_column(tz)
1335 return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz))
1336
1337
1338 @since(1.5)
1339 def to_utc_timestamp(timestamp, tz):
1340 """
1341 This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function
1342 takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in the given
1343 timezone, and renders that timestamp as a timestamp in UTC.
1344
    However, a timestamp in Spark represents a number of microseconds from the Unix epoch,
    which is not timezone-agnostic. So in Spark this function just shifts the timestamp value
    from the given time zone to the UTC time zone.

    This function may return a confusing result if the input is a string with a time zone, e.g.
    '2018-03-13T06:18:23+00:00'. The reason is that Spark first casts the string to a timestamp
    according to the time zone in the string, and finally displays the result by converting the
    timestamp to a string according to the session-local time zone.
1353
1354 :param timestamp: the column that contains timestamps
1355 :param tz: A string detailing the time zone ID that the input should be adjusted to. It should
1356 be in the format of either region-based zone IDs or zone offsets. Region IDs must
1357 have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in
1358 the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are
               supported as aliases of '+00:00'. Other short names are not recommended because
               they can be ambiguous.
1361
1362 .. versionchanged:: 2.4
1363 `tz` can take a :class:`Column` containing timezone ID strings.
1364
1365 >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz'])
1366 >>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect()
1367 [Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))]
1368 >>> df.select(to_utc_timestamp(df.ts, df.tz).alias('utc_time')).collect()
1369 [Row(utc_time=datetime.datetime(1997, 2, 28, 1, 30))]
1370 """
1371 sc = SparkContext._active_spark_context
1372 if isinstance(tz, Column):
1373 tz = _to_java_column(tz)
1374 return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz))
1375
1376
1377 @since(2.0)
1378 @ignore_unicode_prefix
1379 def window(timeColumn, windowDuration, slideDuration=None, startTime=None):
1380 """Bucketize rows into one or more time windows given a timestamp specifying column. Window
1381 starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window
1382 [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in
1383 the order of months are not supported.
1384
1385 The time column must be of :class:`pyspark.sql.types.TimestampType`.
1386
1387 Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid
1388 interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'.
1389 If the ``slideDuration`` is not provided, the windows will be tumbling windows.
1390
1391 The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start
1392 window intervals. For example, in order to have hourly tumbling windows that start 15 minutes
1393 past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`.
1394
1395 The output column will be a struct called 'window' by default with the nested columns 'start'
1396 and 'end', where 'start' and 'end' will be of :class:`pyspark.sql.types.TimestampType`.
1397
1398 >>> df = spark.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val")
1399 >>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum"))
1400 >>> w.select(w.window.start.cast("string").alias("start"),
1401 ... w.window.end.cast("string").alias("end"), "sum").collect()
1402 [Row(start=u'2016-03-11 09:00:05', end=u'2016-03-11 09:00:10', sum=1)]
1403 """
1404 def check_string_field(field, fieldName):
1405 if not field or type(field) is not str:
1406 raise TypeError("%s should be provided as a string" % fieldName)
1407
1408 sc = SparkContext._active_spark_context
1409 time_col = _to_java_column(timeColumn)
1410 check_string_field(windowDuration, "windowDuration")
1411 if slideDuration and startTime:
1412 check_string_field(slideDuration, "slideDuration")
1413 check_string_field(startTime, "startTime")
1414 res = sc._jvm.functions.window(time_col, windowDuration, slideDuration, startTime)
1415 elif slideDuration:
1416 check_string_field(slideDuration, "slideDuration")
1417 res = sc._jvm.functions.window(time_col, windowDuration, slideDuration)
1418 elif startTime:
1419 check_string_field(startTime, "startTime")
1420 res = sc._jvm.functions.window(time_col, windowDuration, windowDuration, startTime)
1421 else:
1422 res = sc._jvm.functions.window(time_col, windowDuration)
1423 return Column(res)
1424
1425
1426
1427
1428 @since(1.5)
1429 @ignore_unicode_prefix
1430 def crc32(col):
1431 """
1432 Calculates the cyclic redundancy check value (CRC32) of a binary column and
1433 returns the value as a bigint.
1434
1435 >>> spark.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect()
1436 [Row(crc32=2743272264)]
1437 """
1438 sc = SparkContext._active_spark_context
1439 return Column(sc._jvm.functions.crc32(_to_java_column(col)))
1440
1441
1442 @ignore_unicode_prefix
1443 @since(1.5)
1444 def md5(col):
1445 """Calculates the MD5 digest and returns the value as a 32 character hex string.
1446
1447 >>> spark.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect()
1448 [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')]
1449 """
1450 sc = SparkContext._active_spark_context
1451 jc = sc._jvm.functions.md5(_to_java_column(col))
1452 return Column(jc)
1453
1454
1455 @ignore_unicode_prefix
1456 @since(1.5)
1457 def sha1(col):
1458 """Returns the hex string result of SHA-1.
1459
1460 >>> spark.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
1461 [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]
1462 """
1463 sc = SparkContext._active_spark_context
1464 jc = sc._jvm.functions.sha1(_to_java_column(col))
1465 return Column(jc)
1466
1467
1468 @ignore_unicode_prefix
1469 @since(1.5)
1470 def sha2(col, numBits):
1471 """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384,
1472 and SHA-512). The numBits indicates the desired bit length of the result, which must have a
1473 value of 224, 256, 384, 512, or 0 (which is equivalent to 256).
1474
1475 >>> digests = df.select(sha2(df.name, 256).alias('s')).collect()
1476 >>> digests[0]
1477 Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043')
1478 >>> digests[1]
1479 Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961')
1480 """
1481 sc = SparkContext._active_spark_context
1482 jc = sc._jvm.functions.sha2(_to_java_column(col), numBits)
1483 return Column(jc)
1484
1485
1486 @since(2.0)
1487 def hash(*cols):
1488 """Calculates the hash code of given columns, and returns the result as an int column.
1489
1490 >>> spark.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect()
1491 [Row(hash=-757602832)]
1492 """
1493 sc = SparkContext._active_spark_context
1494 jc = sc._jvm.functions.hash(_to_seq(sc, cols, _to_java_column))
1495 return Column(jc)
1496
1497
1498 @since(3.0)
1499 def xxhash64(*cols):
1500 """Calculates the hash code of given columns using the 64-bit variant of the xxHash algorithm,
1501 and returns the result as a long column.
1502
1503 >>> spark.createDataFrame([('ABC',)], ['a']).select(xxhash64('a').alias('hash')).collect()
1504 [Row(hash=4105715581806190027)]
1505 """
1506 sc = SparkContext._active_spark_context
1507 jc = sc._jvm.functions.xxhash64(_to_seq(sc, cols, _to_java_column))
1508 return Column(jc)
1509
1510
1511
1512
1513 _string_functions = {
1514 'upper': 'Converts a string expression to upper case.',
1515 'lower': 'Converts a string expression to lower case.',
1516 'ascii': 'Computes the numeric value of the first character of the string column.',
1517 'base64': 'Computes the BASE64 encoding of a binary column and returns it as a string column.',
1518 'unbase64': 'Decodes a BASE64 encoded string column and returns it as a binary column.',
1519 'ltrim': 'Trim the spaces from the left end of the specified string value.',
1520 'rtrim': 'Trim the spaces from the right end of the specified string value.',
1521 'trim': 'Trim the spaces from both ends of the specified string column.',
1522 }
1523
1524
1525 for _name, _doc in _string_functions.items():
1526 globals()[_name] = since(1.5)(_create_function_over_column(_name, _doc))
1527 del _name, _doc
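# Note: the loop above makes e.g. ``upper``, ``lower`` and ``trim`` available as module-level
# functions. An illustrative call (assuming a DataFrame ``df`` with a string column ``name``)
# would be ``df.select(upper(df.name), trim(df.name))``.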
1528
1529
1530 @since(1.5)
1531 @ignore_unicode_prefix
1532 def concat_ws(sep, *cols):
1533 """
1534 Concatenates multiple input string columns together into a single string column,
1535 using the given separator.
1536
1537 >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
1538 >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect()
1539 [Row(s=u'abcd-123')]
1540 """
1541 sc = SparkContext._active_spark_context
1542 return Column(sc._jvm.functions.concat_ws(sep, _to_seq(sc, cols, _to_java_column)))
1543
1544
1545 @since(1.5)
1546 def decode(col, charset):
1547 """
1548 Decodes the first argument from a binary into a string using the provided character set
1549 (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
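
An illustrative usage sketch (marked ``# doctest: +SKIP`` since the exact output shown here
is not verified):

>>> df = spark.createDataFrame([(bytearray(b'abcd'),)], ['a'])  # doctest: +SKIP
>>> df.select(decode('a', 'UTF-8').alias('s')).collect()  # doctest: +SKIP
[Row(s=u'abcd')]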
1550 """
1551 sc = SparkContext._active_spark_context
1552 return Column(sc._jvm.functions.decode(_to_java_column(col), charset))
1553
1554
1555 @since(1.5)
1556 def encode(col, charset):
1557 """
1558 Encodes the first argument from a string into a binary using the provided character set
1559 (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
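
An illustrative usage sketch (marked ``# doctest: +SKIP`` since the exact output shown here
is not verified):

>>> df = spark.createDataFrame([('abcd',)], ['s'])  # doctest: +SKIP
>>> df.select(encode('s', 'UTF-8').alias('b')).collect()  # doctest: +SKIP
[Row(b=bytearray(b'abcd'))]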
1560 """
1561 sc = SparkContext._active_spark_context
1562 return Column(sc._jvm.functions.encode(_to_java_column(col), charset))
1563
1564
1565 @ignore_unicode_prefix
1566 @since(1.5)
1567 def format_number(col, d):
1568 """
1569 Formats the number X to a format like '#,###,###.##', rounded to d decimal places
1570 with HALF_EVEN round mode, and returns the result as a string.
1571
1572 :param col: the column name of the numeric value to be formatted
1573 :param d: the number of decimal places to round to
1574
1575 >>> spark.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
1576 [Row(v=u'5.0000')]
1577 """
1578 sc = SparkContext._active_spark_context
1579 return Column(sc._jvm.functions.format_number(_to_java_column(col), d))
1580
1581
1582 @ignore_unicode_prefix
1583 @since(1.5)
1584 def format_string(format, *cols):
1585 """
1586 Formats the arguments in printf-style and returns the result as a string column.
1587
1588 :param format: string that can contain embedded format tags and is used as the result column's value
1589 :param cols: list of column names (string) or list of :class:`Column` expressions to
1590 be used in formatting
1591
1592 >>> df = spark.createDataFrame([(5, "hello")], ['a', 'b'])
1593 >>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect()
1594 [Row(v=u'5 hello')]
1595 """
1596 sc = SparkContext._active_spark_context
1597 return Column(sc._jvm.functions.format_string(format, _to_seq(sc, cols, _to_java_column)))
1598
1599
1600 @since(1.5)
1601 def instr(str, substr):
1602 """
1603 Locate the position of the first occurrence of substr in the given string column.
1604 Returns null if either of the arguments is null.
1605
1606 .. note:: The position is not zero based, but 1 based index. Returns 0 if substr
1607 could not be found in str.
1608
1609 >>> df = spark.createDataFrame([('abcd',)], ['s',])
1610 >>> df.select(instr(df.s, 'b').alias('s')).collect()
1611 [Row(s=2)]
1612 """
1613 sc = SparkContext._active_spark_context
1614 return Column(sc._jvm.functions.instr(_to_java_column(str), substr))
1615
1616
1617 @since(3.0)
1618 def overlay(src, replace, pos, len=-1):
1619 """
1620 Overlay the specified portion of `src` with `replace`,
1621 starting from byte position `pos` of `src` and proceeding for `len` bytes.
1622
1623 >>> df = spark.createDataFrame([("SPARK_SQL", "CORE")], ("x", "y"))
1624 >>> df.select(overlay("x", "y", 7).alias("overlayed")).show()
1625 +----------+
1626 | overlayed|
1627 +----------+
1628 |SPARK_CORE|
1629 +----------+
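
A sketch of the optional ``len`` argument (marked ``# doctest: +SKIP`` since the output shown
is illustrative rather than verified):

>>> df.select(overlay("x", "y", 7, 0).alias("inserted")).collect()  # doctest: +SKIP
[Row(inserted=u'SPARK_CORESQL')]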
1630 """
1631 if not isinstance(pos, (int, str, Column)):
1632 raise TypeError(
1633 "pos should be an integer or a Column / column name, got {}".format(type(pos)))
1634 if len is not None and not isinstance(len, (int, str, Column)):
1635 raise TypeError(
1636 "len should be an integer or a Column / column name, got {}".format(type(len)))
1637
1638 pos = _create_column_from_literal(pos) if isinstance(pos, int) else _to_java_column(pos)
1639 len = _create_column_from_literal(len) if isinstance(len, int) else _to_java_column(len)
1640
1641 sc = SparkContext._active_spark_context
1642
1643 return Column(sc._jvm.functions.overlay(
1644 _to_java_column(src),
1645 _to_java_column(replace),
1646 pos,
1647 len
1648 ))
1649
1650
1651 @since(1.5)
1652 @ignore_unicode_prefix
1653 def substring(str, pos, len):
1654 """
1655 Substring starts at `pos` and is of length `len` when str is String type or
1656 returns the slice of byte array that starts at `pos` in byte and is of length `len`
1657 when str is Binary type.
1658
1659 .. note:: The position is not zero based, but 1 based index.
1660
1661 >>> df = spark.createDataFrame([('abcd',)], ['s',])
1662 >>> df.select(substring(df.s, 1, 2).alias('s')).collect()
1663 [Row(s=u'ab')]
1664 """
1665 sc = SparkContext._active_spark_context
1666 return Column(sc._jvm.functions.substring(_to_java_column(str), pos, len))
1667
1668
1669 @since(1.5)
1670 @ignore_unicode_prefix
1671 def substring_index(str, delim, count):
1672 """
1673 Returns the substring from string str before count occurrences of the delimiter delim.
1674 If count is positive, everything to the left of the final delimiter (counting from the left) is
1675 returned. If count is negative, everything to the right of the final delimiter (counting from the
1676 right) is returned. substring_index performs a case-sensitive match when searching for delim.
1677
1678 >>> df = spark.createDataFrame([('a.b.c.d',)], ['s'])
1679 >>> df.select(substring_index(df.s, '.', 2).alias('s')).collect()
1680 [Row(s=u'a.b')]
1681 >>> df.select(substring_index(df.s, '.', -3).alias('s')).collect()
1682 [Row(s=u'b.c.d')]
1683 """
1684 sc = SparkContext._active_spark_context
1685 return Column(sc._jvm.functions.substring_index(_to_java_column(str), delim, count))
1686
1687
1688 @ignore_unicode_prefix
1689 @since(1.5)
1690 def levenshtein(left, right):
1691 """Computes the Levenshtein distance of the two given strings.
1692
1693 >>> df0 = spark.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
1694 >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
1695 [Row(d=3)]
1696 """
1697 sc = SparkContext._active_spark_context
1698 jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right))
1699 return Column(jc)
1700
1701
1702 @since(1.5)
1703 def locate(substr, str, pos=1):
1704 """
1705 Locate the position of the first occurrence of substr in a string column, after position pos.
1706
1707 .. note:: The position is not zero based, but 1 based index. Returns 0 if substr
1708 could not be found in str.
1709
1710 :param substr: a string
1711 :param str: a Column of :class:`pyspark.sql.types.StringType`
1712 :param pos: start position (1-based index)
1713
1714 >>> df = spark.createDataFrame([('abcd',)], ['s',])
1715 >>> df.select(locate('b', df.s, 1).alias('s')).collect()
1716 [Row(s=2)]
1717 """
1718 sc = SparkContext._active_spark_context
1719 return Column(sc._jvm.functions.locate(substr, _to_java_column(str), pos))
1720
1721
1722 @since(1.5)
1723 @ignore_unicode_prefix
1724 def lpad(col, len, pad):
1725 """
1726 Left-pad the string column to width `len` with `pad`.
1727
1728 >>> df = spark.createDataFrame([('abcd',)], ['s',])
1729 >>> df.select(lpad(df.s, 6, '#').alias('s')).collect()
1730 [Row(s=u'##abcd')]
1731 """
1732 sc = SparkContext._active_spark_context
1733 return Column(sc._jvm.functions.lpad(_to_java_column(col), len, pad))
1734
1735
1736 @since(1.5)
1737 @ignore_unicode_prefix
1738 def rpad(col, len, pad):
1739 """
1740 Right-pad the string column to width `len` with `pad`.
1741
1742 >>> df = spark.createDataFrame([('abcd',)], ['s',])
1743 >>> df.select(rpad(df.s, 6, '#').alias('s')).collect()
1744 [Row(s=u'abcd##')]
1745 """
1746 sc = SparkContext._active_spark_context
1747 return Column(sc._jvm.functions.rpad(_to_java_column(col), len, pad))
1748
1749
1750 @since(1.5)
1751 @ignore_unicode_prefix
1752 def repeat(col, n):
1753 """
1754 Repeats a string column n times, and returns it as a new string column.
1755
1756 >>> df = spark.createDataFrame([('ab',)], ['s',])
1757 >>> df.select(repeat(df.s, 3).alias('s')).collect()
1758 [Row(s=u'ababab')]
1759 """
1760 sc = SparkContext._active_spark_context
1761 return Column(sc._jvm.functions.repeat(_to_java_column(col), n))
1762
1763
1764 @since(1.5)
1765 @ignore_unicode_prefix
1766 def split(str, pattern, limit=-1):
1767 """
1768 Splits str around matches of the given pattern.
1769
1770 :param str: a string expression to split
1771 :param pattern: a string representing a regular expression. The regex string should be
1772 a Java regular expression.
1773 :param limit: an integer which controls the number of times `pattern` is applied.
1774
1775 * ``limit > 0``: The resulting array's length will not be more than `limit`, and the
1776 resulting array's last entry will contain all input beyond the last
1777 matched pattern.
1778 * ``limit <= 0``: `pattern` will be applied as many times as possible, and the resulting
1779 array can be of any size.
1780
1781 .. versionchanged:: 3.0
1782 `split` now takes an optional `limit` field. If not provided, default limit value is -1.
1783
1784 >>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',])
1785 >>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect()
1786 [Row(s=[u'one', u'twoBthreeC'])]
1787 >>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect()
1788 [Row(s=[u'one', u'two', u'three', u''])]
1789 """
1790 sc = SparkContext._active_spark_context
1791 return Column(sc._jvm.functions.split(_to_java_column(str), pattern, limit))
1792
1793
1794 @ignore_unicode_prefix
1795 @since(1.5)
1796 def regexp_extract(str, pattern, idx):
1797 r"""Extract a specific group matched by a Java regex, from the specified string column.
1798 If the regex did not match, or the specified group did not match, an empty string is returned.
1799
1800 >>> df = spark.createDataFrame([('100-200',)], ['str'])
1801 >>> df.select(regexp_extract('str', r'(\d+)-(\d+)', 1).alias('d')).collect()
1802 [Row(d=u'100')]
1803 >>> df = spark.createDataFrame([('foo',)], ['str'])
1804 >>> df.select(regexp_extract('str', r'(\d+)', 1).alias('d')).collect()
1805 [Row(d=u'')]
1806 >>> df = spark.createDataFrame([('aaaac',)], ['str'])
1807 >>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect()
1808 [Row(d=u'')]
1809 """
1810 sc = SparkContext._active_spark_context
1811 jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx)
1812 return Column(jc)
1813
1814
1815 @ignore_unicode_prefix
1816 @since(1.5)
1817 def regexp_replace(str, pattern, replacement):
1818 r"""Replace all substrings of the specified string value that match regexp with rep.
1819
1820 >>> df = spark.createDataFrame([('100-200',)], ['str'])
1821 >>> df.select(regexp_replace('str', r'(\d+)', '--').alias('d')).collect()
1822 [Row(d=u'-----')]
1823 """
1824 sc = SparkContext._active_spark_context
1825 jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement)
1826 return Column(jc)
1827
1828
1829 @ignore_unicode_prefix
1830 @since(1.5)
1831 def initcap(col):
1832 """Translate the first letter of each word to upper case in the sentence.
1833
1834 >>> spark.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect()
1835 [Row(v=u'Ab Cd')]
1836 """
1837 sc = SparkContext._active_spark_context
1838 return Column(sc._jvm.functions.initcap(_to_java_column(col)))
1839
1840
1841 @since(1.5)
1842 @ignore_unicode_prefix
1843 def soundex(col):
1844 """
1845 Returns the SoundEx encoding for a string.
1846
1847 >>> df = spark.createDataFrame([("Peters",),("Uhrbach",)], ['name'])
1848 >>> df.select(soundex(df.name).alias("soundex")).collect()
1849 [Row(soundex=u'P362'), Row(soundex=u'U612')]
1850 """
1851 sc = SparkContext._active_spark_context
1852 return Column(sc._jvm.functions.soundex(_to_java_column(col)))
1853
1854
1855 @ignore_unicode_prefix
1856 @since(1.5)
1857 def bin(col):
1858 """Returns the string representation of the binary value of the given column.
1859
1860 >>> df.select(bin(df.age).alias('c')).collect()
1861 [Row(c=u'10'), Row(c=u'101')]
1862 """
1863 sc = SparkContext._active_spark_context
1864 jc = sc._jvm.functions.bin(_to_java_column(col))
1865 return Column(jc)
1866
1867
1868 @ignore_unicode_prefix
1869 @since(1.5)
1870 def hex(col):
1871 """Computes hex value of the given column, which could be :class:`pyspark.sql.types.StringType`,
1872 :class:`pyspark.sql.types.BinaryType`, :class:`pyspark.sql.types.IntegerType` or
1873 :class:`pyspark.sql.types.LongType`.
1874
1875 >>> spark.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
1876 [Row(hex(a)=u'414243', hex(b)=u'3')]
1877 """
1878 sc = SparkContext._active_spark_context
1879 jc = sc._jvm.functions.hex(_to_java_column(col))
1880 return Column(jc)
1881
1882
1883 @ignore_unicode_prefix
1884 @since(1.5)
1885 def unhex(col):
1886 """Inverse of hex. Interprets each pair of characters as a hexadecimal number
1887 and converts to the byte representation of number.
1888
1889 >>> spark.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
1890 [Row(unhex(a)=bytearray(b'ABC'))]
1891 """
1892 sc = SparkContext._active_spark_context
1893 return Column(sc._jvm.functions.unhex(_to_java_column(col)))
1894
1895
1896 @ignore_unicode_prefix
1897 @since(1.5)
1898 def length(col):
1899 """Computes the character length of string data or number of bytes of binary data.
1900 The length of character data includes the trailing spaces. The length of binary data
1901 includes binary zeros.
1902
1903 >>> spark.createDataFrame([('ABC ',)], ['a']).select(length('a').alias('length')).collect()
1904 [Row(length=4)]
1905 """
1906 sc = SparkContext._active_spark_context
1907 return Column(sc._jvm.functions.length(_to_java_column(col)))
1908
1909
1910 @ignore_unicode_prefix
1911 @since(1.5)
1912 def translate(srcCol, matching, replace):
1913 """A function translate any character in the `srcCol` by a character in `matching`.
1914 The characters in `replace` is corresponding to the characters in `matching`.
1915 The translate will happen when any character in the string matching with the character
1916 in the `matching`.
1917
1918 >>> spark.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123") \\
1919 ... .alias('r')).collect()
1920 [Row(r=u'1a2s3ae')]
1921 """
1922 sc = SparkContext._active_spark_context
1923 return Column(sc._jvm.functions.translate(_to_java_column(srcCol), matching, replace))
1924
1925
1926
1927
1928 @ignore_unicode_prefix
1929 @since(2.0)
1930 def create_map(*cols):
1931 """Creates a new map column.
1932
1933 :param cols: list of column names (string) or list of :class:`Column` expressions that are
1934 grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...).
1935
1936 >>> df.select(create_map('name', 'age').alias("map")).collect()
1937 [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})]
1938 >>> df.select(create_map([df.name, df.age]).alias("map")).collect()
1939 [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})]
1940 """
1941 sc = SparkContext._active_spark_context
1942 if len(cols) == 1 and isinstance(cols[0], (list, set)):
1943 cols = cols[0]
1944 jc = sc._jvm.functions.map(_to_seq(sc, cols, _to_java_column))
1945 return Column(jc)
1946
1947
1948 @since(2.4)
1949 def map_from_arrays(col1, col2):
1950 """Creates a new map from two arrays.
1951
1952 :param col1: name of column containing a set of keys. All elements should not be null
1953 :param col2: name of column containing a set of values
1954
1955 >>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v'])
1956 >>> df.select(map_from_arrays(df.k, df.v).alias("map")).show()
1957 +----------------+
1958 | map|
1959 +----------------+
1960 |[2 -> a, 5 -> b]|
1961 +----------------+
1962 """
1963 sc = SparkContext._active_spark_context
1964 return Column(sc._jvm.functions.map_from_arrays(_to_java_column(col1), _to_java_column(col2)))
1965
1966
1967 @since(1.4)
1968 def array(*cols):
1969 """Creates a new array column.
1970
1971 :param cols: list of column names (string) or list of :class:`Column` expressions that have
1972 the same data type.
1973
1974 >>> df.select(array('age', 'age').alias("arr")).collect()
1975 [Row(arr=[2, 2]), Row(arr=[5, 5])]
1976 >>> df.select(array([df.age, df.age]).alias("arr")).collect()
1977 [Row(arr=[2, 2]), Row(arr=[5, 5])]
1978 """
1979 sc = SparkContext._active_spark_context
1980 if len(cols) == 1 and isinstance(cols[0], (list, set)):
1981 cols = cols[0]
1982 jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column))
1983 return Column(jc)
1984
1985
1986 @since(1.5)
1987 def array_contains(col, value):
1988 """
1989 Collection function: returns null if the array is null, true if the array contains the
1990 given value, and false otherwise.
1991
1992 :param col: name of column containing array
1993 :param value: value or column to check for in array
1994
1995 >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
1996 >>> df.select(array_contains(df.data, "a")).collect()
1997 [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]
1998 >>> df.select(array_contains(df.data, lit("a"))).collect()
1999 [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]
2000 """
2001 sc = SparkContext._active_spark_context
2002 value = value._jc if isinstance(value, Column) else value
2003 return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
2004
2005
2006 @since(2.4)
2007 def arrays_overlap(a1, a2):
2008 """
2009 Collection function: returns true if the arrays contain any common non-null element; if not,
2010 returns null if both the arrays are non-empty and any of them contains a null element; returns
2011 false otherwise.
2012
2013 >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y'])
2014 >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect()
2015 [Row(overlap=True), Row(overlap=False)]
2016 """
2017 sc = SparkContext._active_spark_context
2018 return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2)))
2019
2020
2021 @since(2.4)
2022 def slice(x, start, length):
2023 """
2024 Collection function: returns an array containing all the elements in `x` from index `start`
2025 (array indices start at 1, or from the end if `start` is negative) with the specified `length`.
2026
2027 :param x: the array to be sliced
2028 :param start: the starting index
2029 :param length: the length of the slice
2030
2031 >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
2032 >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect()
2033 [Row(sliced=[2, 3]), Row(sliced=[5])]
2034 """
2035 sc = SparkContext._active_spark_context
2036 return Column(sc._jvm.functions.slice(_to_java_column(x), start, length))
2037
2038
2039 @ignore_unicode_prefix
2040 @since(2.4)
2041 def array_join(col, delimiter, null_replacement=None):
2042 """
2043 Concatenates the elements of `col` using the `delimiter`. Null values are replaced with
2044 `null_replacement` if set, otherwise they are ignored.
2045
2046 >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data'])
2047 >>> df.select(array_join(df.data, ",").alias("joined")).collect()
2048 [Row(joined=u'a,b,c'), Row(joined=u'a')]
2049 >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect()
2050 [Row(joined=u'a,b,c'), Row(joined=u'a,NULL')]
2051 """
2052 sc = SparkContext._active_spark_context
2053 if null_replacement is None:
2054 return Column(sc._jvm.functions.array_join(_to_java_column(col), delimiter))
2055 else:
2056 return Column(sc._jvm.functions.array_join(
2057 _to_java_column(col), delimiter, null_replacement))
2058
2059
2060 @since(1.5)
2061 @ignore_unicode_prefix
2062 def concat(*cols):
2063 """
2064 Concatenates multiple input columns together into a single column.
2065 The function works with strings, binary and compatible array columns.
2066
2067 >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
2068 >>> df.select(concat(df.s, df.d).alias('s')).collect()
2069 [Row(s=u'abcd123')]
2070
2071 >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c'])
2072 >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect()
2073 [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]
2074 """
2075 sc = SparkContext._active_spark_context
2076 return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
2077
2078
2079 @since(2.4)
2080 def array_position(col, value):
2081 """
2082 Collection function: Locates the position of the first occurrence of the given value
2083 in the given array. Returns null if either of the arguments is null.
2084
2085 .. note:: The position is not zero based, but 1 based index. Returns 0 if the given
2086 value could not be found in the array.
2087
2088 >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data'])
2089 >>> df.select(array_position(df.data, "a")).collect()
2090 [Row(array_position(data, a)=3), Row(array_position(data, a)=0)]
2091 """
2092 sc = SparkContext._active_spark_context
2093 return Column(sc._jvm.functions.array_position(_to_java_column(col), value))
2094
2095
2096 @ignore_unicode_prefix
2097 @since(2.4)
2098 def element_at(col, extraction):
2099 """
2100 Collection function: Returns element of array at given index in extraction if col is array.
2101 Returns value for the given key in extraction if col is map.
2102
2103 :param col: name of column containing array or map
2104 :param extraction: index to check for in array or key to check for in map
2105
2106 .. note:: The position is not zero based, but 1 based index.
2107
2108 >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
2109 >>> df.select(element_at(df.data, 1)).collect()
2110 [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)]
2111
2112 >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data'])
2113 >>> df.select(element_at(df.data, lit("a"))).collect()
2114 [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)]
2115 """
2116 sc = SparkContext._active_spark_context
2117 return Column(sc._jvm.functions.element_at(
2118 _to_java_column(col), lit(extraction)._jc))
2119
2120
2121 @since(2.4)
2122 def array_remove(col, element):
2123 """
2124 Collection function: Remove all elements that are equal to element from the given array.
2125
2126 :param col: name of column containing array
2127 :param element: element to be removed from the array
2128
2129 >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data'])
2130 >>> df.select(array_remove(df.data, 1)).collect()
2131 [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])]
2132 """
2133 sc = SparkContext._active_spark_context
2134 return Column(sc._jvm.functions.array_remove(_to_java_column(col), element))
2135
2136
2137 @since(2.4)
2138 def array_distinct(col):
2139 """
2140 Collection function: removes duplicate values from the array.
2141
2142 :param col: name of column or expression
2143
2144 >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data'])
2145 >>> df.select(array_distinct(df.data)).collect()
2146 [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])]
2147 """
2148 sc = SparkContext._active_spark_context
2149 return Column(sc._jvm.functions.array_distinct(_to_java_column(col)))
2150
2151
2152 @ignore_unicode_prefix
2153 @since(2.4)
2154 def array_intersect(col1, col2):
2155 """
2156 Collection function: returns an array of the elements in the intersection of col1 and col2,
2157 without duplicates.
2158
2159 :param col1: name of column containing array
2160 :param col2: name of column containing array
2161
2162 >>> from pyspark.sql import Row
2163 >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
2164 >>> df.select(array_intersect(df.c1, df.c2)).collect()
2165 [Row(array_intersect(c1, c2)=[u'a', u'c'])]
2166 """
2167 sc = SparkContext._active_spark_context
2168 return Column(sc._jvm.functions.array_intersect(_to_java_column(col1), _to_java_column(col2)))
2169
2170
2171 @ignore_unicode_prefix
2172 @since(2.4)
2173 def array_union(col1, col2):
2174 """
2175 Collection function: returns an array of the elements in the union of col1 and col2,
2176 without duplicates.
2177
2178 :param col1: name of column containing array
2179 :param col2: name of column containing array
2180
2181 >>> from pyspark.sql import Row
2182 >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
2183 >>> df.select(array_union(df.c1, df.c2)).collect()
2184 [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])]
2185 """
2186 sc = SparkContext._active_spark_context
2187 return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2)))
2188
2189
2190 @ignore_unicode_prefix
2191 @since(2.4)
2192 def array_except(col1, col2):
2193 """
2194 Collection function: returns an array of the elements in col1 but not in col2,
2195 without duplicates.
2196
2197 :param col1: name of column containing array
2198 :param col2: name of column containing array
2199
2200 >>> from pyspark.sql import Row
2201 >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
2202 >>> df.select(array_except(df.c1, df.c2)).collect()
2203 [Row(array_except(c1, c2)=[u'b'])]
2204 """
2205 sc = SparkContext._active_spark_context
2206 return Column(sc._jvm.functions.array_except(_to_java_column(col1), _to_java_column(col2)))
2207
2208
2209 @since(1.4)
2210 def explode(col):
2211 """
2212 Returns a new row for each element in the given array or map.
2213 Uses the default column name `col` for elements in the array and
2214 `key` and `value` for elements in the map unless specified otherwise.
2215
2216 >>> from pyspark.sql import Row
2217 >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
2218 >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
2219 [Row(anInt=1), Row(anInt=2), Row(anInt=3)]
2220
2221 >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show()
2222 +---+-----+
2223 |key|value|
2224 +---+-----+
2225 | a| b|
2226 +---+-----+
2227 """
2228 sc = SparkContext._active_spark_context
2229 jc = sc._jvm.functions.explode(_to_java_column(col))
2230 return Column(jc)
2231
2232
2233 @since(2.1)
2234 def posexplode(col):
2235 """
2236 Returns a new row for each element with position in the given array or map.
2237 Uses the default column name `pos` for position, and `col` for elements in the
2238 array and `key` and `value` for elements in the map unless specified otherwise.
2239
2240 >>> from pyspark.sql import Row
2241 >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
2242 >>> eDF.select(posexplode(eDF.intlist)).collect()
2243 [Row(pos=0, col=1), Row(pos=1, col=2), Row(pos=2, col=3)]
2244
2245 >>> eDF.select(posexplode(eDF.mapfield)).show()
2246 +---+---+-----+
2247 |pos|key|value|
2248 +---+---+-----+
2249 | 0| a| b|
2250 +---+---+-----+
2251 """
2252 sc = SparkContext._active_spark_context
2253 jc = sc._jvm.functions.posexplode(_to_java_column(col))
2254 return Column(jc)
2255
2256
2257 @since(2.3)
2258 def explode_outer(col):
2259 """
2260 Returns a new row for each element in the given array or map.
2261 Unlike explode, if the array/map is null or empty then null is produced.
2262 Uses the default column name `col` for elements in the array and
2263 `key` and `value` for elements in the map unless specified otherwise.
2264
2265 >>> df = spark.createDataFrame(
2266 ... [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
2267 ... ("id", "an_array", "a_map")
2268 ... )
2269 >>> df.select("id", "an_array", explode_outer("a_map")).show()
2270 +---+----------+----+-----+
2271 | id| an_array| key|value|
2272 +---+----------+----+-----+
2273 | 1|[foo, bar]| x| 1.0|
2274 | 2| []|null| null|
2275 | 3| null|null| null|
2276 +---+----------+----+-----+
2277
2278 >>> df.select("id", "a_map", explode_outer("an_array")).show()
2279 +---+----------+----+
2280 | id| a_map| col|
2281 +---+----------+----+
2282 | 1|[x -> 1.0]| foo|
2283 | 1|[x -> 1.0]| bar|
2284 | 2| []|null|
2285 | 3| null|null|
2286 +---+----------+----+
2287 """
2288 sc = SparkContext._active_spark_context
2289 jc = sc._jvm.functions.explode_outer(_to_java_column(col))
2290 return Column(jc)
2291
2292
2293 @since(2.3)
2294 def posexplode_outer(col):
2295 """
2296 Returns a new row for each element with position in the given array or map.
2297 Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced.
2298 Uses the default column name `pos` for position, and `col` for elements in the
2299 array and `key` and `value` for elements in the map unless specified otherwise.
2300
2301 >>> df = spark.createDataFrame(
2302 ... [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
2303 ... ("id", "an_array", "a_map")
2304 ... )
2305 >>> df.select("id", "an_array", posexplode_outer("a_map")).show()
2306 +---+----------+----+----+-----+
2307 | id| an_array| pos| key|value|
2308 +---+----------+----+----+-----+
2309 | 1|[foo, bar]| 0| x| 1.0|
2310 | 2| []|null|null| null|
2311 | 3| null|null|null| null|
2312 +---+----------+----+----+-----+
2313 >>> df.select("id", "a_map", posexplode_outer("an_array")).show()
2314 +---+----------+----+----+
2315 | id| a_map| pos| col|
2316 +---+----------+----+----+
2317 | 1|[x -> 1.0]| 0| foo|
2318 | 1|[x -> 1.0]| 1| bar|
2319 | 2| []|null|null|
2320 | 3| null|null|null|
2321 +---+----------+----+----+
2322 """
2323 sc = SparkContext._active_spark_context
2324 jc = sc._jvm.functions.posexplode_outer(_to_java_column(col))
2325 return Column(jc)
2326
2327
2328 @ignore_unicode_prefix
2329 @since(1.6)
2330 def get_json_object(col, path):
2331 """
2332 Extracts a json object from a json string based on the specified json path, and returns the
2333 json string of the extracted json object. It will return null if the input json string is invalid.
2334
2335 :param col: string column in json format
2336 :param path: path to the json object to extract
2337
2338 >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')]
2339 >>> df = spark.createDataFrame(data, ("key", "jstring"))
2340 >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \\
2341 ... get_json_object(df.jstring, '$.f2').alias("c1") ).collect()
2342 [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)]
2343 """
2344 sc = SparkContext._active_spark_context
2345 jc = sc._jvm.functions.get_json_object(_to_java_column(col), path)
2346 return Column(jc)
2347
2348
2349 @ignore_unicode_prefix
2350 @since(1.6)
2351 def json_tuple(col, *fields):
2352 """Creates a new row for a json column according to the given field names.
2353
2354 :param col: string column in json format
2355 :param fields: list of fields to extract
2356
2357 >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')]
2358 >>> df = spark.createDataFrame(data, ("key", "jstring"))
2359 >>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect()
2360 [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)]
2361 """
2362 sc = SparkContext._active_spark_context
2363 jc = sc._jvm.functions.json_tuple(_to_java_column(col), _to_seq(sc, fields))
2364 return Column(jc)
2365
2366
2367 @ignore_unicode_prefix
2368 @since(2.1)
2369 def from_json(col, schema, options={}):
2370 """
2371 Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType`
2372 keys, a :class:`StructType`, or an :class:`ArrayType` with
2373 the specified schema. Returns `null` in the case of an unparseable string.
2374
2375 :param col: string column in json format
2376 :param schema: a StructType or ArrayType of StructType to use when parsing the json column.
2377 :param options: options to control parsing. accepts the same options as the json datasource
2378
2379 .. note:: Since Spark 2.3, the DDL-formatted string or a JSON format string is also
2380 supported for ``schema``.
2381
2382 >>> from pyspark.sql.types import *
2383 >>> data = [(1, '''{"a": 1}''')]
2384 >>> schema = StructType([StructField("a", IntegerType())])
2385 >>> df = spark.createDataFrame(data, ("key", "value"))
2386 >>> df.select(from_json(df.value, schema).alias("json")).collect()
2387 [Row(json=Row(a=1))]
2388 >>> df.select(from_json(df.value, "a INT").alias("json")).collect()
2389 [Row(json=Row(a=1))]
2390 >>> df.select(from_json(df.value, "MAP<STRING,INT>").alias("json")).collect()
2391 [Row(json={u'a': 1})]
2392 >>> data = [(1, '''[{"a": 1}]''')]
2393 >>> schema = ArrayType(StructType([StructField("a", IntegerType())]))
2394 >>> df = spark.createDataFrame(data, ("key", "value"))
2395 >>> df.select(from_json(df.value, schema).alias("json")).collect()
2396 [Row(json=[Row(a=1)])]
2397 >>> schema = schema_of_json(lit('''{"a": 0}'''))
2398 >>> df.select(from_json(df.value, schema).alias("json")).collect()
2399 [Row(json=Row(a=None))]
2400 >>> data = [(1, '''[1, 2, 3]''')]
2401 >>> schema = ArrayType(IntegerType())
2402 >>> df = spark.createDataFrame(data, ("key", "value"))
2403 >>> df.select(from_json(df.value, schema).alias("json")).collect()
2404 [Row(json=[1, 2, 3])]
2405 """
2406
2407 sc = SparkContext._active_spark_context
2408 if isinstance(schema, DataType):
2409 schema = schema.json()
2410 elif isinstance(schema, Column):
2411 schema = _to_java_column(schema)
2412 jc = sc._jvm.functions.from_json(_to_java_column(col), schema, _options_to_str(options))
2413 return Column(jc)
2414
2415
2416 @ignore_unicode_prefix
2417 @since(2.1)
2418 def to_json(col, options={}):
2419 """
2420 Converts a column containing a :class:`StructType`, :class:`ArrayType` or a :class:`MapType`
2421 into a JSON string. Throws an exception, in the case of an unsupported type.
2422
2423 :param col: name of column containing a struct, an array or a map.
2424 :param options: options to control converting. accepts the same options as the JSON datasource.
2425 Additionally the function supports the `pretty` option which enables
2426 pretty JSON generation.
2427
2428 >>> from pyspark.sql import Row
2429 >>> from pyspark.sql.types import *
2430 >>> data = [(1, Row(name='Alice', age=2))]
2431 >>> df = spark.createDataFrame(data, ("key", "value"))
2432 >>> df.select(to_json(df.value).alias("json")).collect()
2433 [Row(json=u'{"age":2,"name":"Alice"}')]
2434 >>> data = [(1, [Row(name='Alice', age=2), Row(name='Bob', age=3)])]
2435 >>> df = spark.createDataFrame(data, ("key", "value"))
2436 >>> df.select(to_json(df.value).alias("json")).collect()
2437 [Row(json=u'[{"age":2,"name":"Alice"},{"age":3,"name":"Bob"}]')]
2438 >>> data = [(1, {"name": "Alice"})]
2439 >>> df = spark.createDataFrame(data, ("key", "value"))
2440 >>> df.select(to_json(df.value).alias("json")).collect()
2441 [Row(json=u'{"name":"Alice"}')]
2442 >>> data = [(1, [{"name": "Alice"}, {"name": "Bob"}])]
2443 >>> df = spark.createDataFrame(data, ("key", "value"))
2444 >>> df.select(to_json(df.value).alias("json")).collect()
2445 [Row(json=u'[{"name":"Alice"},{"name":"Bob"}]')]
2446 >>> data = [(1, ["Alice", "Bob"])]
2447 >>> df = spark.createDataFrame(data, ("key", "value"))
2448 >>> df.select(to_json(df.value).alias("json")).collect()
2449 [Row(json=u'["Alice","Bob"]')]
2450 """
2451
2452 sc = SparkContext._active_spark_context
2453 jc = sc._jvm.functions.to_json(_to_java_column(col), _options_to_str(options))
2454 return Column(jc)
2455
2456
2457 @ignore_unicode_prefix
2458 @since(2.4)
2459 def schema_of_json(json, options={}):
2460 """
2461 Parses a JSON string and infers its schema in DDL format.
2462
2463 :param json: a JSON string or a string literal containing a JSON string.
2464 :param options: options to control parsing. accepts the same options as the JSON datasource
2465
2466 .. versionchanged:: 3.0
2467 It accepts `options` parameter to control schema inferring.
2468
2469 >>> df = spark.range(1)
2470 >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect()
2471 [Row(json=u'struct<a:bigint>')]
2472 >>> schema = schema_of_json('{a: 1}', {'allowUnquotedFieldNames':'true'})
2473 >>> df.select(schema.alias("json")).collect()
2474 [Row(json=u'struct<a:bigint>')]
2475 """
2476 if isinstance(json, basestring):
2477 col = _create_column_from_literal(json)
2478 elif isinstance(json, Column):
2479 col = _to_java_column(json)
2480 else:
2481 raise TypeError("schema argument should be a column or string")
2482
2483 sc = SparkContext._active_spark_context
2484 jc = sc._jvm.functions.schema_of_json(col, _options_to_str(options))
2485 return Column(jc)
2486
2487
2488 @ignore_unicode_prefix
2489 @since(3.0)
2490 def schema_of_csv(csv, options={}):
2491 """
2492 Parses a CSV string and infers its schema in DDL format.
2493
2494 :param csv: a CSV string or a string literal containing a CSV string.
2495 :param options: options to control parsing. accepts the same options as the CSV datasource
2496
2497 >>> df = spark.range(1)
2498 >>> df.select(schema_of_csv(lit('1|a'), {'sep':'|'}).alias("csv")).collect()
2499 [Row(csv=u'struct<_c0:int,_c1:string>')]
2500 >>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect()
2501 [Row(csv=u'struct<_c0:int,_c1:string>')]
2502 """
2503 if isinstance(csv, basestring):
2504 col = _create_column_from_literal(csv)
2505 elif isinstance(csv, Column):
2506 col = _to_java_column(csv)
2507 else:
2508 raise TypeError("schema argument should be a column or string")
2509
2510 sc = SparkContext._active_spark_context
2511 jc = sc._jvm.functions.schema_of_csv(col, _options_to_str(options))
2512 return Column(jc)
2513
2514
2515 @ignore_unicode_prefix
2516 @since(3.0)
2517 def to_csv(col, options={}):
2518 """
2519 Converts a column containing a :class:`StructType` into a CSV string.
2520 Throws an exception, in the case of an unsupported type.
2521
2522 :param col: name of column containing a struct.
2523 :param options: options to control converting. accepts the same options as the CSV datasource.
2524
2525 >>> from pyspark.sql import Row
2526 >>> data = [(1, Row(name='Alice', age=2))]
2527 >>> df = spark.createDataFrame(data, ("key", "value"))
2528 >>> df.select(to_csv(df.value).alias("csv")).collect()
2529 [Row(csv=u'2,Alice')]
2530 """
2531
2532 sc = SparkContext._active_spark_context
2533 jc = sc._jvm.functions.to_csv(_to_java_column(col), _options_to_str(options))
2534 return Column(jc)
2535
2536
2537 @since(1.5)
2538 def size(col):
2539 """
2540 Collection function: returns the length of the array or map stored in the column.
2541
2542 :param col: name of column or expression
2543
2544 >>> df = spark.createDataFrame([([1, 2, 3],),([1],),([],)], ['data'])
2545 >>> df.select(size(df.data)).collect()
2546 [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)]
2547 """
2548 sc = SparkContext._active_spark_context
2549 return Column(sc._jvm.functions.size(_to_java_column(col)))
2550
2551
2552 @since(2.4)
2553 def array_min(col):
2554 """
2555 Collection function: returns the minimum value of the array.
2556
2557 :param col: name of column or expression
2558
2559 >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
2560 >>> df.select(array_min(df.data).alias('min')).collect()
2561 [Row(min=1), Row(min=-1)]
2562 """
2563 sc = SparkContext._active_spark_context
2564 return Column(sc._jvm.functions.array_min(_to_java_column(col)))
2565
2566
2567 @since(2.4)
2568 def array_max(col):
2569 """
2570 Collection function: returns the maximum value of the array.
2571
2572 :param col: name of column or expression
2573
2574 >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
2575 >>> df.select(array_max(df.data).alias('max')).collect()
2576 [Row(max=3), Row(max=10)]
2577 """
2578 sc = SparkContext._active_spark_context
2579 return Column(sc._jvm.functions.array_max(_to_java_column(col)))
2580
2581
2582 @since(1.5)
2583 def sort_array(col, asc=True):
2584 """
2585 Collection function: sorts the input array in ascending or descending order according
2586 to the natural ordering of the array elements. Null elements will be placed at the beginning
2587 of the returned array in ascending order or at the end of the returned array in descending
2588 order.
2589
2590 :param col: name of column or expression
2591
2592 >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data'])
2593 >>> df.select(sort_array(df.data).alias('r')).collect()
2594 [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])]
2595 >>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
2596 [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])]
2597 """
2598 sc = SparkContext._active_spark_context
2599 return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))
2600
2601
2602 @since(2.4)
2603 def array_sort(col):
2604 """
2605 Collection function: sorts the input array in ascending order. The elements of the input array
2606 must be orderable. Null elements will be placed at the end of the returned array.
2607
2608 :param col: name of column or expression
2609
2610 >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data'])
2611 >>> df.select(array_sort(df.data).alias('r')).collect()
2612 [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])]
2613 """
2614 sc = SparkContext._active_spark_context
2615 return Column(sc._jvm.functions.array_sort(_to_java_column(col)))
2616
2617
2618 @since(2.4)
2619 def shuffle(col):
2620 """
2621 Collection function: Generates a random permutation of the given array.
2622
2623 .. note:: The function is non-deterministic.
2624
2625 :param col: name of column or expression
2626
2627 >>> df = spark.createDataFrame([([1, 20, 3, 5],), ([1, 20, None, 3],)], ['data'])
2628 >>> df.select(shuffle(df.data).alias('s')).collect() # doctest: +SKIP
2629 [Row(s=[3, 1, 5, 20]), Row(s=[20, None, 3, 1])]
2630 """
2631 sc = SparkContext._active_spark_context
2632 return Column(sc._jvm.functions.shuffle(_to_java_column(col)))
2633
2634
2635 @since(1.5)
2636 @ignore_unicode_prefix
2637 def reverse(col):
2638 """
2639 Collection function: returns a reversed string or an array with reverse order of elements.
2640
2641 :param col: name of column or expression
2642
2643 >>> df = spark.createDataFrame([('Spark SQL',)], ['data'])
2644 >>> df.select(reverse(df.data).alias('s')).collect()
2645 [Row(s=u'LQS krapS')]
2646 >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data'])
2647 >>> df.select(reverse(df.data).alias('r')).collect()
2648 [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])]
2649 """
2650 sc = SparkContext._active_spark_context
2651 return Column(sc._jvm.functions.reverse(_to_java_column(col)))
2652
2653
2654 @since(2.4)
2655 def flatten(col):
2656 """
2657 Collection function: creates a single array from an array of arrays.
2658 If a structure of nested arrays is deeper than two levels,
2659 only one level of nesting is removed.
2660
2661 :param col: name of column or expression
2662
2663 >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data'])
2664 >>> df.select(flatten(df.data).alias('r')).collect()
2665 [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)]
2666 """
2667 sc = SparkContext._active_spark_context
2668 return Column(sc._jvm.functions.flatten(_to_java_column(col)))
2669
2670
2671 @since(2.3)
2672 def map_keys(col):
2673 """
2674 Collection function: Returns an unordered array containing the keys of the map.
2675
2676 :param col: name of column or expression
2677
2678 >>> from pyspark.sql.functions import map_keys
2679 >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
2680 >>> df.select(map_keys("data").alias("keys")).show()
2681 +------+
2682 | keys|
2683 +------+
2684 |[1, 2]|
2685 +------+
2686 """
2687 sc = SparkContext._active_spark_context
2688 return Column(sc._jvm.functions.map_keys(_to_java_column(col)))
2689
2690
2691 @since(2.3)
2692 def map_values(col):
2693 """
2694 Collection function: Returns an unordered array containing the values of the map.
2695
2696 :param col: name of column or expression
2697
2698 >>> from pyspark.sql.functions import map_values
2699 >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
2700 >>> df.select(map_values("data").alias("values")).show()
2701 +------+
2702 |values|
2703 +------+
2704 |[a, b]|
2705 +------+
2706 """
2707 sc = SparkContext._active_spark_context
2708 return Column(sc._jvm.functions.map_values(_to_java_column(col)))
2709
2710
2711 @since(3.0)
2712 def map_entries(col):
2713 """
2714 Collection function: Returns an unordered array of all entries in the given map.
2715
2716 :param col: name of column or expression
2717
2718 >>> from pyspark.sql.functions import map_entries
2719 >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
2720 >>> df.select(map_entries("data").alias("entries")).show()
2721 +----------------+
2722 | entries|
2723 +----------------+
2724 |[[1, a], [2, b]]|
2725 +----------------+
2726 """
2727 sc = SparkContext._active_spark_context
2728 return Column(sc._jvm.functions.map_entries(_to_java_column(col)))
2729
2730
2731 @since(2.4)
2732 def map_from_entries(col):
2733 """
2734 Collection function: Returns a map created from the given array of entries.
2735
2736 :param col: name of column or expression
2737
2738 >>> from pyspark.sql.functions import map_from_entries
2739 >>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data")
2740 >>> df.select(map_from_entries("data").alias("map")).show()
2741 +----------------+
2742 | map|
2743 +----------------+
2744 |[1 -> a, 2 -> b]|
2745 +----------------+
2746 """
2747 sc = SparkContext._active_spark_context
2748 return Column(sc._jvm.functions.map_from_entries(_to_java_column(col)))
2749
2750
2751 @ignore_unicode_prefix
2752 @since(2.4)
2753 def array_repeat(col, count):
2754 """
2755 Collection function: creates an array containing a column repeated count times.
2756
2757 >>> df = spark.createDataFrame([('ab',)], ['data'])
2758 >>> df.select(array_repeat(df.data, 3).alias('r')).collect()
2759 [Row(r=[u'ab', u'ab', u'ab'])]
2760 """
2761 sc = SparkContext._active_spark_context
2762 return Column(sc._jvm.functions.array_repeat(
2763 _to_java_column(col),
2764 _to_java_column(count) if isinstance(count, Column) else count
2765 ))
2766
2767
2768 @since(2.4)
2769 def arrays_zip(*cols):
2770 """
2771 Collection function: Returns a merged array of structs in which the N-th struct contains all
2772 N-th values of input arrays.
2773
2774 :param cols: columns of arrays to be merged.
2775
2776 >>> from pyspark.sql.functions import arrays_zip
2777 >>> df = spark.createDataFrame([(([1, 2, 3], [2, 3, 4]))], ['vals1', 'vals2'])
2778 >>> df.select(arrays_zip(df.vals1, df.vals2).alias('zipped')).collect()
2779 [Row(zipped=[Row(vals1=1, vals2=2), Row(vals1=2, vals2=3), Row(vals1=3, vals2=4)])]
2780 """
2781 sc = SparkContext._active_spark_context
2782 return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column)))
2783
2784
2785 @since(2.4)
2786 def map_concat(*cols):
2787 """Returns the union of all the given maps.
2788
2789 :param cols: list of column names (string) or list of :class:`Column` expressions
2790
2791 >>> from pyspark.sql.functions import map_concat
2792 >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c') as map2")
2793 >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False)
2794 +------------------------+
2795 |map3 |
2796 +------------------------+
2797 |[1 -> a, 2 -> b, 3 -> c]|
2798 +------------------------+
2799 """
2800 sc = SparkContext._active_spark_context
2801 if len(cols) == 1 and isinstance(cols[0], (list, set)):
2802 cols = cols[0]
2803 jc = sc._jvm.functions.map_concat(_to_seq(sc, cols, _to_java_column))
2804 return Column(jc)
2805
2806
2807 @since(2.4)
2808 def sequence(start, stop, step=None):
2809 """
2810 Generate a sequence of integers from `start` to `stop`, incrementing by `step`.
2811 If `step` is not set, it defaults to 1 if `start` is less than or equal to `stop`,
2812 and to -1 otherwise.
2813
2814 >>> df1 = spark.createDataFrame([(-2, 2)], ('C1', 'C2'))
2815 >>> df1.select(sequence('C1', 'C2').alias('r')).collect()
2816 [Row(r=[-2, -1, 0, 1, 2])]
2817 >>> df2 = spark.createDataFrame([(4, -4, -2)], ('C1', 'C2', 'C3'))
2818 >>> df2.select(sequence('C1', 'C2', 'C3').alias('r')).collect()
2819 [Row(r=[4, 2, 0, -2, -4])]
2820 """
2821 sc = SparkContext._active_spark_context
2822 if step is None:
2823 return Column(sc._jvm.functions.sequence(_to_java_column(start), _to_java_column(stop)))
2824 else:
2825 return Column(sc._jvm.functions.sequence(
2826 _to_java_column(start), _to_java_column(stop), _to_java_column(step)))
2827
2828
2829 @ignore_unicode_prefix
2830 @since(3.0)
2831 def from_csv(col, schema, options={}):
2832 """
2833 Parses a column containing a CSV string to a row with the specified schema.
2834 Returns `null`, in the case of an unparseable string.
2835
2836 :param col: string column in CSV format
2837 :param schema: a string with schema in DDL format to use when parsing the CSV column.
2838 :param options: options to control parsing. accepts the same options as the CSV datasource
2839
2840 >>> data = [("1,2,3",)]
2841 >>> df = spark.createDataFrame(data, ("value",))
2842 >>> df.select(from_csv(df.value, "a INT, b INT, c INT").alias("csv")).collect()
2843 [Row(csv=Row(a=1, b=2, c=3))]
2844 >>> value = data[0][0]
2845 >>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect()
2846 [Row(csv=Row(_c0=1, _c1=2, _c2=3))]
2847 >>> data = [(" abc",)]
2848 >>> df = spark.createDataFrame(data, ("value",))
2849 >>> options = {'ignoreLeadingWhiteSpace': True}
2850 >>> df.select(from_csv(df.value, "s string", options).alias("csv")).collect()
2851 [Row(csv=Row(s=u'abc'))]
2852 """
2853
2854 sc = SparkContext._active_spark_context
2855 if isinstance(schema, basestring):
2856 schema = _create_column_from_literal(schema)
2857 elif isinstance(schema, Column):
2858 schema = _to_java_column(schema)
2859 else:
2860 raise TypeError("schema argument should be a column or string")
2861
2862 jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, _options_to_str(options))
2863 return Column(jc)
2864
2865
2866
2867
2868 @since(1.3)
2869 def udf(f=None, returnType=StringType()):
2870 """Creates a user defined function (UDF).
2871
2872 .. note:: The user-defined functions are considered deterministic by default. Due to
2873 optimization, duplicate invocations may be eliminated or the function may even be invoked
2874 more times than it is present in the query. If your function is not deterministic, call
2875 `asNondeterministic` on the user defined function. E.g.:
2876
2877 >>> from pyspark.sql.types import IntegerType
2878 >>> import random
2879 >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
2880
2881 .. note:: The user-defined functions do not support conditional expressions or short-circuiting
2882 in boolean expressions; all branches end up being evaluated internally. If the function
2883 can fail on special rows, the workaround is to incorporate the condition into the function itself.
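
For instance, a sketch of incorporating the null check into the function itself (illustrative
only; the names are arbitrary):

>>> from pyspark.sql.types import StringType
>>> def to_upper_or_none(s):
...     return s.upper() if s is not None else None
>>> to_upper_udf = udf(to_upper_or_none, StringType())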
2884
2885 .. note:: The user-defined functions do not take keyword arguments on the calling side.
2886
2887 :param f: python function if used as a standalone function
2888 :param returnType: the return type of the user-defined function. The value can be either a
2889 :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
2890
2891 >>> from pyspark.sql.types import IntegerType
2892 >>> slen = udf(lambda s: len(s), IntegerType())
2893 >>> @udf
2894 ... def to_upper(s):
2895 ... if s is not None:
2896 ... return s.upper()
2897 ...
2898 >>> @udf(returnType=IntegerType())
2899 ... def add_one(x):
2900 ... if x is not None:
2901 ... return x + 1
2902 ...
2903 >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
2904 >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")).show()
2905 +----------+--------------+------------+
2906 |slen(name)|to_upper(name)|add_one(age)|
2907 +----------+--------------+------------+
2908 | 8| JOHN DOE| 22|
2909 +----------+--------------+------------+
2910 """
2911
2912
2913 # Decorator usage: ``@udf`` passes the wrapped function directly, while ``@udf(...)`` or
2914 # ``udf(returnType=...)`` passes only a return type here, and a partial is returned that
2915 # wraps the function on the next call.
2944 if f is None or isinstance(f, (str, DataType)):
2945 # If a DataType or a DDL-formatted type string was passed as the positional argument,
2946 # use it as the return type.
2947 return_type = f or returnType
2948 return functools.partial(_create_udf, returnType=return_type,
2949 evalType=PythonEvalType.SQL_BATCHED_UDF)
2950 else:
2951 return _create_udf(f=f, returnType=returnType,
2952 evalType=PythonEvalType.SQL_BATCHED_UDF)
2953
2954
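# Exclude names that are helpers imported above (e.g. ``map`` from itertools on Python 2, and
# the ``since``/``ignore_unicode_prefix`` decorators) from the public API list built below.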
2955 blacklist = ['map', 'since', 'ignore_unicode_prefix']
2956 __all__ = [k for k, v in globals().items()
2957 if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist]
2958 __all__ += ["PandasUDFType"]
2959 __all__.sort()
2960
2961
2962 def _test():
2963 import doctest
2964 from pyspark.sql import Row, SparkSession
2965 import pyspark.sql.functions
2966 globs = pyspark.sql.functions.__dict__.copy()
2967 spark = SparkSession.builder\
2968 .master("local[4]")\
2969 .appName("sql.functions tests")\
2970 .getOrCreate()
2971 sc = spark.sparkContext
2972 globs['sc'] = sc
2973 globs['spark'] = spark
2974 globs['df'] = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)])
2975 (failure_count, test_count) = doctest.testmod(
2976 pyspark.sql.functions, globs=globs,
2977 optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
2978 spark.stop()
2979 if failure_count:
2980 sys.exit(-1)
2981
2982
2983 if __name__ == "__main__":
2984 _test()