0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import datetime
0019 import sys
0020
0021 from pyspark.sql import Row
0022 from pyspark.sql.functions import udf, input_file_name
0023 from pyspark.testing.sqlutils import ReusedSQLTestCase
0024
0025
class FunctionsTests(ReusedSQLTestCase):
    """Tests for DataFrame column functions exposed in ``pyspark.sql.functions``."""

    def test_explode(self):
        """explode/explode_outer/posexplode_outer on arrays and maps,
        covering non-empty, empty and null collections."""
        from pyspark.sql.functions import explode, explode_outer, posexplode_outer
        d = [
            Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}),
            Row(a=1, intlist=[], mapfield={}),
            Row(a=1, intlist=None, mapfield=None),
        ]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        result = data.select(explode(data.intlist).alias("a")).select("a").collect()
        self.assertEqual(result[0][0], 1)
        self.assertEqual(result[1][0], 2)
        self.assertEqual(result[2][0], 3)

        result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
        self.assertEqual(result[0][0], "a")
        self.assertEqual(result[0][1], "b")

        # The *_outer variants emit a row of nulls for empty/null collections
        # instead of dropping the input row entirely.
        result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()]
        self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)])

        result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()]
        self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)])

        result = [x[0] for x in data.select(explode_outer("intlist")).collect()]
        self.assertEqual(result, [1, 2, 3, None, None])

        result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()]
        self.assertEqual(result, [('a', 'b'), (None, None), (None, None)])

    def test_basic_functions(self):
        """Basic DataFrame actions, cache-state tracking and a temp-view SQL
        round trip."""
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
        df = self.spark.read.json(rdd)
        df.count()
        df.collect()
        df.schema

        # is_cached must track persist/unpersist/cache calls.
        self.assertFalse(df.is_cached)
        df.persist()
        df.unpersist(True)
        df.cache()
        self.assertTrue(df.is_cached)
        self.assertEqual(2, df.count())

        with self.tempView("temp"):
            df.createOrReplaceTempView("temp")
            df = self.spark.sql("select foo from temp")
            df.count()
            df.collect()

    def test_corr(self):
        """Pearson correlation of a column with its square root."""
        import math
        df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
        corr = df.stat.corr(u"a", "b")
        self.assertTrue(abs(corr - 0.95734012) < 1e-6)

    def test_sampleby(self):
        """Stratified sampling with a fixed seed yields a deterministic count."""
        df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(100)]).toDF()
        sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
        # assertEqual (rather than assertTrue on a boolean) reports the actual
        # count on failure.
        self.assertEqual(sampled.count(), 35)

    def test_cov(self):
        """Sample covariance of two linearly related columns."""
        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
        cov = df.stat.cov(u"a", "b")
        self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)

    def test_crosstab(self):
        """Contingency table of two integer columns."""
        df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
        ct = df.stat.crosstab(u"a", "b").collect()
        ct = sorted(ct, key=lambda x: x[0])
        for i, row in enumerate(ct):
            self.assertEqual(row[0], str(i))
            # Each (a, b) combination occurs exactly once.  These were
            # previously two-argument assertTrue calls, which treat the second
            # argument as a failure *message* and therefore compared nothing.
            self.assertEqual(row[1], 1)
            self.assertEqual(row[2], 1)

    def test_math_functions(self):
        """SQL math functions agree with Python's ``math`` module and accept
        Column objects, column names and literals."""
        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
        from pyspark.sql import functions
        import math

        def get_values(rows):
            # Each collected Row holds a single column; unwrap it.
            return [j[0] for j in rows]

        def assert_close(a, b):
            c = get_values(b)
            diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
            # Assert instead of returning the comparison: the previous
            # ``return sum(diff) == len(a)`` was ignored by every caller, so a
            # numeric mismatch could never fail the test.
            self.assertEqual(sum(diff), len(a), "%s != %s" % (a, c))
        assert_close([math.cos(i) for i in range(10)],
                     df.select(functions.cos(df.a)).collect())
        assert_close([math.cos(i) for i in range(10)],
                     df.select(functions.cos("a")).collect())
        assert_close([math.sin(i) for i in range(10)],
                     df.select(functions.sin(df.a)).collect())
        assert_close([math.sin(i) for i in range(10)],
                     df.select(functions.sin(df['a'])).collect())
        assert_close([math.pow(i, 2 * i) for i in range(10)],
                     df.select(functions.pow(df.a, df.b)).collect())
        assert_close([math.pow(i, 2) for i in range(10)],
                     df.select(functions.pow(df.a, 2)).collect())
        assert_close([math.pow(i, 2) for i in range(10)],
                     df.select(functions.pow(df.a, 2.0)).collect())
        assert_close([math.hypot(i, 2 * i) for i in range(10)],
                     df.select(functions.hypot(df.a, df.b)).collect())
        assert_close([math.hypot(i, 2 * i) for i in range(10)],
                     df.select(functions.hypot("a", u"b")).collect())
        assert_close([math.hypot(i, 2) for i in range(10)],
                     df.select(functions.hypot("a", 2)).collect())
        assert_close([math.hypot(i, 2) for i in range(10)],
                     df.select(functions.hypot(df.a, 2)).collect())

    def test_rand_functions(self):
        """rand()/randn() value ranges and determinism under a fixed seed."""
        df = self.df
        from pyspark.sql import functions
        rnd = df.select('key', functions.rand()).collect()
        for row in rnd:
            assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
        rndn = df.select('key', functions.randn(5)).collect()
        for row in rndn:
            # randn draws from a standard normal; |value| > 4 is vanishingly
            # unlikely for this row count.
            assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]

        # With an explicit seed the generated column is deterministic across
        # evaluations.
        rnd1 = df.select('key', functions.rand(0)).collect()
        rnd2 = df.select('key', functions.rand(0)).collect()
        self.assertEqual(sorted(rnd1), sorted(rnd2))

        rndn1 = df.select('key', functions.randn(0)).collect()
        rndn2 = df.select('key', functions.randn(0)).collect()
        self.assertEqual(sorted(rndn1), sorted(rndn2))

    def test_string_functions(self):
        """String functions accept both a column name and a Column, and
        Column.substr rejects mixed argument types."""
        from pyspark.sql import functions
        from pyspark.sql.functions import col, lit, _string_functions
        df = self.spark.createDataFrame([['nick']], schema=['name'])
        self.assertRaisesRegexp(
            TypeError,
            "must be the same type",
            lambda: df.select(col('name').substr(0, lit(1))))
        if sys.version_info.major == 2:
            # Python 2 only: ``long`` arguments must also be rejected.
            self.assertRaises(
                TypeError,
                lambda: df.select(col('name').substr(long(0), long(1))))

        for name in _string_functions.keys():
            self.assertEqual(
                df.select(getattr(functions, name)("name")).first()[0],
                df.select(getattr(functions, name)(col("name"))).first()[0])

    def test_array_contains_function(self):
        """array_contains over a non-empty and an empty array."""
        from pyspark.sql.functions import array_contains

        df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data'])
        actual = df.select(array_contains(df.data, "1").alias('b')).collect()
        self.assertEqual([Row(b=True), Row(b=False)], actual)

    def test_between_function(self):
        """Column.between works with column-valued bounds."""
        df = self.sc.parallelize([
            Row(a=1, b=2, c=3),
            Row(a=2, b=1, c=3),
            Row(a=4, b=1, c=4)]).toDF()
        self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
                         df.filter(df.a.between(df.b, df.c)).collect())

    def test_dayofweek(self):
        """dayofweek returns 2 (Monday) for 2017-11-06."""
        from pyspark.sql.functions import dayofweek
        dt = datetime.datetime(2017, 11, 6)
        df = self.spark.createDataFrame([Row(date=dt)])
        row = df.select(dayofweek(df.date)).first()
        self.assertEqual(row[0], 2)

    def test_expr(self):
        """expr parses a SQL expression string into a Column."""
        from pyspark.sql import functions
        row = Row(a="length string", b=75)
        df = self.spark.createDataFrame([row])
        result = df.select(functions.expr("length(a)")).collect()[0].asDict()
        self.assertEqual(13, result["length(a)"])

    def test_functions_broadcast(self):
        """The broadcast() hint influences the physical join plan."""
        from pyspark.sql.functions import broadcast

        df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
        df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))

        # An equi-join with an explicit broadcast hint plans exactly one
        # BroadcastHashJoin.
        plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan()
        self.assertEqual(1, plan1.toString().count("BroadcastHashJoin"))

        # A cross join has no join keys, so no hash join is planned even with
        # the hint.
        plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan()
        self.assertEqual(0, plan2.toString().count("BroadcastHashJoin"))

        # Planning a lone broadcast-hinted DataFrame must not raise.
        broadcast(df1)._jdf.queryExecution().executedPlan()

    def test_first_last_ignorenulls(self):
        """first/last with and without the ignorenulls flag."""
        from pyspark.sql import functions
        df = self.spark.range(0, 100)
        df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
        df3 = df2.select(functions.first(df2.id, False).alias('a'),
                         functions.first(df2.id, True).alias('b'),
                         functions.last(df2.id, False).alias('c'),
                         functions.last(df2.id, True).alias('d'))
        self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())

    def test_approxQuantile(self):
        """approxQuantile accepts str/unicode column names and list/tuple of
        columns, and validates its column arguments."""
        df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
        for f in ["a", u"a"]:
            aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
            self.assertTrue(isinstance(aq, list))
            self.assertEqual(len(aq), 3)
            self.assertTrue(all(isinstance(q, float) for q in aq))
        aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1)
        self.assertTrue(isinstance(aqs, list))
        self.assertEqual(len(aqs), 2)
        self.assertTrue(isinstance(aqs[0], list))
        self.assertEqual(len(aqs[0]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqs[0]))
        self.assertTrue(isinstance(aqs[1], list))
        self.assertEqual(len(aqs[1]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
        aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1)
        self.assertTrue(isinstance(aqt, list))
        self.assertEqual(len(aqt), 2)
        self.assertTrue(isinstance(aqt[0], list))
        self.assertEqual(len(aqt[0]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqt[0]))
        self.assertTrue(isinstance(aqt[1], list))
        self.assertEqual(len(aqt[1]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqt[1]))
        # Non-string column arguments are rejected.
        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1))
        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1))
        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))

    def test_sort_with_nulls_order(self):
        """asc/desc_nulls_first/last control null placement in ordering."""
        from pyspark.sql import functions

        df = self.spark.createDataFrame(
            [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"])
        # assertEqual: assertEquals is a deprecated alias.
        self.assertEqual(
            df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(),
            [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')])
        self.assertEqual(
            df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(),
            [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)])
        self.assertEqual(
            df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(),
            [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')])
        self.assertEqual(
            df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(),
            [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)])

    def test_input_file_name_reset_for_rdd(self):
        """input_file_name is empty for rows that did not come from a file,
        even after a file-based query ran earlier in the session."""
        rdd = self.sc.textFile('python/test_support/hello/hello.txt').map(lambda x: {'data': x})
        df = self.spark.createDataFrame(rdd, "data STRING")
        df.select(input_file_name().alias('file')).collect()

        non_file_df = self.spark.range(100).select(input_file_name())

        results = non_file_df.collect()
        self.assertTrue(len(results) == 100)

        # A non-file source must yield an empty file name for every row.
        for result in results:
            self.assertEqual(result[0], '')

    def test_array_repeat(self):
        """array_repeat accepts the count as an int or as a literal Column."""
        from pyspark.sql.functions import array_repeat, lit

        df = self.spark.range(1)

        # assertEqual: assertEquals is a deprecated alias.
        self.assertEqual(
            df.select(array_repeat("id", 3)).toDF("val").collect(),
            df.select(array_repeat("id", lit(3))).toDF("val").collect(),
        )

    def test_input_file_name_udf(self):
        """input_file_name still reports the source file when combined with a
        Python UDF in the same projection."""
        df = self.spark.read.text('python/test_support/hello/hello.txt')
        df = df.select(udf(lambda x: x)("value"), input_file_name().alias('file'))
        file_name = df.collect()[0].file
        self.assertTrue("python/test_support/hello/hello.txt" in file_name)

    def test_overlay(self):
        """overlay builds the expected expression string for every supported
        combination of Column/str/int/literal arguments."""
        from pyspark.sql.functions import col, lit, overlay
        from itertools import chain
        import re

        actual = list(chain.from_iterable([
            re.findall("(overlay\\(.*\\))", str(x)) for x in [
                overlay(col("foo"), col("bar"), 1),
                overlay("x", "y", 3),
                overlay(col("x"), col("y"), 1, 3),
                overlay("x", "y", 2, 5),
                overlay("x", "y", lit(11)),
                overlay("x", "y", lit(2), lit(5)),
            ]
        ]))

        expected = [
            "overlay(foo, bar, 1, -1)",
            "overlay(x, y, 3, -1)",
            "overlay(x, y, 1, 3)",
            "overlay(x, y, 2, 5)",
            "overlay(x, y, 11, -1)",
            "overlay(x, y, 2, 5)",
        ]

        self.assertListEqual(actual, expected)
0339
0340
if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_functions import *

    # Emit XML reports (useful for CI) when xmlrunner is installed; otherwise
    # fall back to the default text runner.
    try:
        import xmlrunner
    except ImportError:
        xmlrunner = None
    runner = None
    if xmlrunner is not None:
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=runner, verbosity=2)