#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
import sys

from pyspark.sql import Row
from pyspark.sql.functions import udf, input_file_name
from pyspark.testing.sqlutils import ReusedSQLTestCase


class FunctionsTests(ReusedSQLTestCase):

    def test_explode(self):
        from pyspark.sql.functions import explode, explode_outer, posexplode_outer
        d = [
            Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}),
            Row(a=1, intlist=[], mapfield={}),
            Row(a=1, intlist=None, mapfield=None),
        ]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        result = data.select(explode(data.intlist).alias("a")).select("a").collect()
        self.assertEqual(result[0][0], 1)
        self.assertEqual(result[1][0], 2)
        self.assertEqual(result[2][0], 3)

        result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
        self.assertEqual(result[0][0], "a")
        self.assertEqual(result[0][1], "b")

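        # The *_outer variants keep rows whose collection is empty or null,
        # emitting nulls for the generated columns instead of dropping the row.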
        result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()]
        self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)])

        result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()]
        self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)])

        result = [x[0] for x in data.select(explode_outer("intlist")).collect()]
        self.assertEqual(result, [1, 2, 3, None, None])

        result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()]
        self.assertEqual(result, [('a', 'b'), (None, None), (None, None)])

    def test_basic_functions(self):
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
        df = self.spark.read.json(rdd)
        df.count()
        df.collect()
        df.schema

        # cache and checkpoint
        self.assertFalse(df.is_cached)
        df.persist()
        df.unpersist(True)
        df.cache()
        self.assertTrue(df.is_cached)
        self.assertEqual(2, df.count())

        with self.tempView("temp"):
            df.createOrReplaceTempView("temp")
            df = self.spark.sql("select foo from temp")
            df.count()
            df.collect()

    def test_corr(self):
        import math
        df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
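        # a and b are monotonically related but not linearly, so the Pearson
        # correlation is high yet strictly below 1.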
        corr = df.stat.corr(u"a", "b")
        self.assertTrue(abs(corr - 0.95734012) < 1e-6)

    def test_sampleby(self):
        df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(100)]).toDF()
        sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
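        # With a fixed seed the stratified sample is deterministic, so the
        # exact row count can be asserted.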
        self.assertTrue(sampled.count() == 35)

    def test_cov(self):
        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
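        # cov(a, 2a) = 2 * var(a); the sample variance of 0..9 is 55/6,
        # so the expected covariance is 55/3.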
        cov = df.stat.cov(u"a", "b")
        self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)

    def test_crosstab(self):
        df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
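        # For i in 1..6 every (i % 3, i % 2) combination occurs exactly once,
        # so each cell of the contingency table should be 1.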
        ct = df.stat.crosstab(u"a", "b").collect()
        ct = sorted(ct, key=lambda x: x[0])
        for i, row in enumerate(ct):
            self.assertEqual(row[0], str(i))
            self.assertEqual(row[1], 1)
            self.assertEqual(row[2], 1)

    def test_math_functions(self):
        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
        from pyspark.sql import functions
        import math

        def get_values(l):
            return [j[0] for j in l]

        def assert_close(a, b):
            c = get_values(b)
            diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
            # Assert here rather than returning a bool, since callers ignore the return value.
            self.assertTrue(all(diff), "values are not all close: %s vs %s" % (a, c))
        assert_close([math.cos(i) for i in range(10)],
                     df.select(functions.cos(df.a)).collect())
        assert_close([math.cos(i) for i in range(10)],
                     df.select(functions.cos("a")).collect())
        assert_close([math.sin(i) for i in range(10)],
                     df.select(functions.sin(df.a)).collect())
        assert_close([math.sin(i) for i in range(10)],
                     df.select(functions.sin(df['a'])).collect())
        assert_close([math.pow(i, 2 * i) for i in range(10)],
                     df.select(functions.pow(df.a, df.b)).collect())
        assert_close([math.pow(i, 2) for i in range(10)],
                     df.select(functions.pow(df.a, 2)).collect())
        assert_close([math.pow(i, 2) for i in range(10)],
                     df.select(functions.pow(df.a, 2.0)).collect())
        assert_close([math.hypot(i, 2 * i) for i in range(10)],
                     df.select(functions.hypot(df.a, df.b)).collect())
        assert_close([math.hypot(i, 2 * i) for i in range(10)],
                     df.select(functions.hypot("a", u"b")).collect())
        assert_close([math.hypot(i, 2) for i in range(10)],
                     df.select(functions.hypot("a", 2)).collect())
        assert_close([math.hypot(i, 2) for i in range(10)],
                     df.select(functions.hypot(df.a, 2)).collect())

    def test_rand_functions(self):
        df = self.df
        from pyspark.sql import functions
        rnd = df.select('key', functions.rand()).collect()
        for row in rnd:
            assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
        rndn = df.select('key', functions.randn(5)).collect()
        for row in rndn:
            assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]

        # If the specified seed is 0, we should use it.
        # https://issues.apache.org/jira/browse/SPARK-9691
        rnd1 = df.select('key', functions.rand(0)).collect()
        rnd2 = df.select('key', functions.rand(0)).collect()
        self.assertEqual(sorted(rnd1), sorted(rnd2))

        rndn1 = df.select('key', functions.randn(0)).collect()
        rndn2 = df.select('key', functions.randn(0)).collect()
        self.assertEqual(sorted(rndn1), sorted(rndn2))

    def test_string_functions(self):
        from pyspark.sql import functions
        from pyspark.sql.functions import col, lit, _string_functions
        df = self.spark.createDataFrame([['nick']], schema=['name'])
        self.assertRaisesRegexp(
            TypeError,
            "must be the same type",
            lambda: df.select(col('name').substr(0, lit(1))))
        if sys.version_info.major == 2:
            self.assertRaises(
                TypeError,
                lambda: df.select(col('name').substr(long(0), long(1))))

        for name in _string_functions.keys():
            self.assertEqual(
                df.select(getattr(functions, name)("name")).first()[0],
                df.select(getattr(functions, name)(col("name"))).first()[0])

    def test_array_contains_function(self):
        from pyspark.sql.functions import array_contains

        df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data'])
        actual = df.select(array_contains(df.data, "1").alias('b')).collect()
        self.assertEqual([Row(b=True), Row(b=False)], actual)

    def test_between_function(self):
        df = self.sc.parallelize([
            Row(a=1, b=2, c=3),
            Row(a=2, b=1, c=3),
            Row(a=4, b=1, c=4)]).toDF()
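        # between() is inclusive on both bounds, so a=4 with c=4 matches.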
        self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
                         df.filter(df.a.between(df.b, df.c)).collect())

    def test_dayofweek(self):
        from pyspark.sql.functions import dayofweek
        dt = datetime.datetime(2017, 11, 6)
        df = self.spark.createDataFrame([Row(date=dt)])
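        # 2017-11-06 is a Monday; dayofweek() is 1-based starting from Sunday,
        # so Monday maps to 2.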
        row = df.select(dayofweek(df.date)).first()
        self.assertEqual(row[0], 2)

    def test_expr(self):
        from pyspark.sql import functions
        row = Row(a="length string", b=75)
        df = self.spark.createDataFrame([row])
        result = df.select(functions.expr("length(a)")).collect()[0].asDict()
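        # "length string" contains 13 characters.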
        self.assertEqual(13, result["length(a)"])

    # add test for SPARK-10577 (test broadcast join hint)
    def test_functions_broadcast(self):
        from pyspark.sql.functions import broadcast

        df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
        df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))

        # equijoin - should be converted into broadcast join
        plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan()
        self.assertEqual(1, plan1.toString().count("BroadcastHashJoin"))

        # no join key -- should not be a broadcast join
        plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan()
        self.assertEqual(0, plan2.toString().count("BroadcastHashJoin"))

        # planner should not crash without a join
        broadcast(df1)._jdf.queryExecution().executedPlan()

    def test_first_last_ignorenulls(self):
        from pyspark.sql import functions
        df = self.spark.range(0, 100)
        df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
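        # Every id divisible by 3 becomes null, so with ignoreNulls the first and
        # last values are 1 and 98; without it they are the nulls at ids 0 and 99.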
        df3 = df2.select(functions.first(df2.id, False).alias('a'),
                         functions.first(df2.id, True).alias('b'),
                         functions.last(df2.id, False).alias('c'),
                         functions.last(df2.id, True).alias('d'))
        self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())

    def test_approxQuantile(self):
        df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
        for f in ["a", u"a"]:
            aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
            self.assertTrue(isinstance(aq, list))
            self.assertEqual(len(aq), 3)
            self.assertTrue(all(isinstance(q, float) for q in aq))
        aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1)
        self.assertTrue(isinstance(aqs, list))
        self.assertEqual(len(aqs), 2)
        self.assertTrue(isinstance(aqs[0], list))
        self.assertEqual(len(aqs[0]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqs[0]))
        self.assertTrue(isinstance(aqs[1], list))
        self.assertEqual(len(aqs[1]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
        aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1)
        self.assertTrue(isinstance(aqt, list))
        self.assertEqual(len(aqt), 2)
        self.assertTrue(isinstance(aqt[0], list))
        self.assertEqual(len(aqt[0]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqt[0]))
        self.assertTrue(isinstance(aqt[1], list))
        self.assertEqual(len(aqt[1]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqt[1]))
        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1))
        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1))
        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))

    def test_sort_with_nulls_order(self):
        from pyspark.sql import functions

        df = self.spark.createDataFrame(
            [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"])
        self.assertEquals(
            df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(),
            [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')])
        self.assertEquals(
            df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(),
            [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)])
        self.assertEquals(
            df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(),
            [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')])
        self.assertEquals(
            df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(),
            [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)])

    def test_input_file_name_reset_for_rdd(self):
        rdd = self.sc.textFile('python/test_support/hello/hello.txt').map(lambda x: {'data': x})
        df = self.spark.createDataFrame(rdd, "data STRING")
        df.select(input_file_name().alias('file')).collect()

        non_file_df = self.spark.range(100).select(input_file_name())

        results = non_file_df.collect()
        self.assertTrue(len(results) == 100)

        # [SPARK-24605]: if everything was properly reset after the last job, this should return
        # empty string rather than the file read in the last job.
        for result in results:
            self.assertEqual(result[0], '')

    def test_array_repeat(self):
        from pyspark.sql.functions import array_repeat, lit

        df = self.spark.range(1)

        self.assertEquals(
            df.select(array_repeat("id", 3)).toDF("val").collect(),
            df.select(array_repeat("id", lit(3))).toDF("val").collect(),
        )

    def test_input_file_name_udf(self):
        df = self.spark.read.text('python/test_support/hello/hello.txt')
        df = df.select(udf(lambda x: x)("value"), input_file_name().alias('file'))
        file_name = df.collect()[0].file
        self.assertTrue("python/test_support/hello/hello.txt" in file_name)

    def test_overlay(self):
        from pyspark.sql.functions import col, lit, overlay
        from itertools import chain
        import re

        actual = list(chain.from_iterable([
            re.findall("(overlay\\(.*\\))", str(x)) for x in [
                overlay(col("foo"), col("bar"), 1),
                overlay("x", "y", 3),
                overlay(col("x"), col("y"), 1, 3),
                overlay("x", "y", 2, 5),
                overlay("x", "y", lit(11)),
                overlay("x", "y", lit(2), lit(5)),
            ]
        ]))

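        # When the length argument is omitted it defaults to -1 in the
        # generated expression, which is what the strings below encode.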
        expected = [
            "overlay(foo, bar, 1, -1)",
            "overlay(x, y, 3, -1)",
            "overlay(x, y, 1, 3)",
            "overlay(x, y, 2, 5)",
            "overlay(x, y, 11, -1)",
            "overlay(x, y, 2, 5)",
        ]

        self.assertListEqual(actual, expected)


if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_functions import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)