0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019 import sys
0020
0021 from pyspark.sql import Column, Row
0022 from pyspark.sql.types import *
0023 from pyspark.sql.utils import AnalysisException
0024 from pyspark.testing.sqlutils import ReusedSQLTestCase
0025
0026
class ColumnTests(ReusedSQLTestCase):
    """Unit tests for `Column`: naming, operators, accessors and bitwise ops."""

    def test_column_name_encoding(self):
        """Ensure that created columns has `str` type consistently."""
        df = self.spark.createDataFrame([('Alice', 1)], ['name', u'age'])
        columns = df.columns
        self.assertEqual(columns, ['name', 'age'])
        for name in columns:
            self.assertTrue(isinstance(name, str))

    def test_and_in_expression(self):
        """Columns must be combined with &, |, ~; Python's and/or/not raise."""
        df = self.df
        self.assertEqual(4, df.filter((df.key <= 10) & (df.value <= "2")).count())
        self.assertRaises(ValueError, lambda: (df.key <= 10) and (df.value <= "2"))
        self.assertEqual(14, df.filter((df.key <= 3) | (df.value < "2")).count())
        self.assertRaises(ValueError, lambda: df.key <= 3 or df.value < "2")
        self.assertEqual(99, df.filter(~(df.key == 1)).count())
        self.assertRaises(ValueError, lambda: not df.key == 1)

    def test_validate_column_types(self):
        """_to_java_column accepts strings and Columns; everything else raises."""
        from pyspark.sql.functions import udf, to_json
        from pyspark.sql.column import _to_java_column

        # Valid inputs: plain str, unicode str, and an actual Column.
        for valid in ["a", u"a", self.spark.range(1).id]:
            self.assertTrue("Column" in _to_java_column(valid).getClass().toString())

        self.assertRaisesRegexp(
            TypeError,
            "Invalid argument, not a string or column",
            lambda: _to_java_column(1))

        class A():
            pass

        # Arbitrary objects and lists are rejected with TypeError.
        for invalid in [A(), []]:
            self.assertRaises(TypeError, lambda: _to_java_column(invalid))

        self.assertRaisesRegexp(
            TypeError,
            "Invalid argument, not a string or column",
            lambda: udf(lambda x: x)(None))
        self.assertRaises(TypeError, lambda: to_json(1))

    def test_column_operators(self):
        """Arithmetic, comparison, boolean and string operators all yield Columns."""
        key_col = self.df.key
        val_col = self.df.value
        c = key_col == val_col
        self.assertTrue(isinstance((- key_col - 1 - 2) % 3 * 2.5 / 3.5, Column))
        reflected = [(1 + key_col), (1 - key_col), (1 * key_col), (1 / key_col),
                     (1 % key_col), (1 ** key_col), (key_col ** 1)]
        self.assertTrue(all(isinstance(col, Column) for col in reflected))
        comparisons = [key_col == 5, key_col != 0, key_col > 3, key_col < 4,
                       key_col >= 0, key_col <= 7]
        self.assertTrue(all(isinstance(col, Column) for col in comparisons))
        booleans = [(key_col & key_col), (key_col | key_col), (~key_col)]
        self.assertTrue(all(isinstance(col, Column) for col in booleans))
        string_ops = [val_col.contains('a'), val_col.like('a'), val_col.rlike('a'),
                      val_col.asc(), val_col.desc(), val_col.startswith('a'),
                      val_col.endswith('a'), key_col.eqNullSafe(val_col)]
        self.assertTrue(all(isinstance(col, Column) for col in string_ops))
        self.assertTrue(isinstance(key_col.cast(LongType()), Column))
        # `in` requires evaluating the column to a bool, which is not allowed.
        self.assertRaisesRegexp(ValueError,
                                "Cannot apply 'in' operator against a column",
                                lambda: 1 in val_col)

    def test_column_accessor(self):
        """Slicing, indexing and field access on a column produce Columns."""
        from pyspark.sql.functions import col

        self.assertIsInstance(col("foo")[1:3], Column)
        self.assertIsInstance(col("foo")[0], Column)
        self.assertIsInstance(col("foo")["bar"], Column)
        # A slice step is not supported.
        self.assertRaises(ValueError, lambda: col("foo")[0:10:2])

    def test_column_select(self):
        """Selecting by star, by Column, and after a filter returns expected rows."""
        df = self.df
        self.assertEqual(self.testData, df.select("*").collect())
        self.assertEqual(self.testData, df.select(df.key, df.value).collect())
        self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())

    def test_access_column(self):
        """Columns are reachable by attribute, name and position; bad keys raise."""
        df = self.df
        self.assertTrue(isinstance(df.key, Column))
        self.assertTrue(isinstance(df['key'], Column))
        self.assertTrue(isinstance(df[0], Column))
        self.assertRaises(IndexError, lambda: df[2])
        self.assertRaises(AnalysisException, lambda: df["bad_key"])
        self.assertRaises(TypeError, lambda: df[{}])

    def test_column_name_with_non_ascii(self):
        """Non-ASCII column names survive schema creation, display and selection."""
        if sys.version >= '3':
            columnName = "数量"
            self.assertTrue(isinstance(columnName, str))
        else:
            # Python 2: build a unicode column name explicitly.
            columnName = unicode("数量", "utf-8")
            self.assertTrue(isinstance(columnName, unicode))
        schema = StructType([StructField(columnName, LongType(), True)])
        df = self.spark.createDataFrame([(1,)], schema)
        self.assertEqual(schema, df.schema)
        self.assertEqual("DataFrame[数量: bigint]", str(df))
        self.assertEqual([("数量", 'bigint')], df.dtypes)
        self.assertEqual(1, df.select("数量").first()[0])
        self.assertEqual(1, df.select(df["数量"]).first()[0])

    def test_field_accessor(self):
        """Array, struct and map fields are accessible via [] and dotted names."""
        df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
        self.assertEqual(1, df.select(df.l[0]).first()[0])
        self.assertEqual(1, df.select(df.r["a"]).first()[0])
        self.assertEqual(1, df.select(df["r.a"]).first()[0])
        self.assertEqual("b", df.select(df.r["b"]).first()[0])
        self.assertEqual("b", df.select(df["r.b"]).first()[0])
        self.assertEqual("v", df.select(df.d["k"]).first()[0])

    def test_bitwise_operations(self):
        """bitwiseAND/OR/XOR and bitwiseNOT match Python's integer operators."""
        from pyspark.sql import functions
        df = self.spark.createDataFrame([Row(a=170, b=75)])
        cases = [(df.a.bitwiseAND(df.b), 170 & 75, '(a & b)'),
                 (df.a.bitwiseOR(df.b), 170 | 75, '(a | b)'),
                 (df.a.bitwiseXOR(df.b), 170 ^ 75, '(a ^ b)'),
                 (functions.bitwiseNOT(df.b), ~75, '~b')]
        for expr, expected, label in cases:
            result = df.select(expr).collect()[0].asDict()
            self.assertEqual(expected, result[label])
0147
0148
if __name__ == "__main__":
    # Re-export the test classes so unittest discovery sees them by name.
    import unittest
    from pyspark.sql.tests.test_column import *

    try:
        import xmlrunner
    except ImportError:
        # No XML reporting available; fall back to the default text runner.
        runner = None
    else:
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=runner, verbosity=2)