
# -*- encoding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

from pyspark.sql import Column, Row
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase


class ColumnTests(ReusedSQLTestCase):
    """Unit tests for pyspark.sql.Column: construction, operators, accessors,
    field access, and bitwise functions."""

    def test_column_name_encoding(self):
        """Column names coming back from createDataFrame are always plain `str`,
        even when a name was supplied as a unicode literal."""
        df = self.spark.createDataFrame([('Alice', 1)], ['name', u'age'])
        columns = df.columns
        self.assertEqual(columns, ['name', 'age'])
        for name in columns:
            self.assertIsInstance(name, str)

    def test_and_in_expression(self):
        """Bitwise &, |, ~ combine Columns; the Python keywords and/or/not raise."""
        df = self.df
        self.assertEqual(4, df.filter((df.key <= 10) & (df.value <= "2")).count())
        self.assertRaises(ValueError, lambda: (df.key <= 10) and (df.value <= "2"))
        self.assertEqual(14, df.filter((df.key <= 3) | (df.value < "2")).count())
        self.assertRaises(ValueError, lambda: df.key <= 3 or df.value < "2")
        self.assertEqual(99, df.filter(~(df.key == 1)).count())
        self.assertRaises(ValueError, lambda: not df.key == 1)

    def test_validate_column_types(self):
        """_to_java_column accepts strings and Columns and rejects other types."""
        from pyspark.sql.functions import udf, to_json
        from pyspark.sql.column import _to_java_column

        # Valid inputs: plain str, unicode str, and an actual Column.
        for good in ["a", u"a", self.spark.range(1).id]:
            self.assertIn("Column", _to_java_column(good).getClass().toString())

        self.assertRaisesRegexp(
            TypeError,
            "Invalid argument, not a string or column",
            lambda: _to_java_column(1))

        class A():
            pass

        # Arbitrary objects and lists are rejected as well.
        for bad in [A(), []]:
            self.assertRaises(TypeError, lambda: _to_java_column(bad))

        # The same validation applies through higher-level entry points.
        self.assertRaisesRegexp(
            TypeError,
            "Invalid argument, not a string or column",
            lambda: udf(lambda x: x)(None))
        self.assertRaises(TypeError, lambda: to_json(1))

    def test_column_operators(self):
        """Arithmetic, comparison, logical and string operators all yield Columns."""
        int_col = self.df.key
        str_col = self.df.value
        mixed = int_col == str_col  # built but unused, mirroring the original test
        self.assertIsInstance((- int_col - 1 - 2) % 3 * 2.5 / 3.5, Column)
        # Reflected (r-hand) operators and exponentiation in both directions.
        reflected = ((1 + int_col), (1 - int_col), (1 * int_col), (1 / int_col),
                     (1 % int_col), (1 ** int_col), (int_col ** 1))
        for expr in reflected:
            self.assertIsInstance(expr, Column)
        comparisons = [int_col == 5, int_col != 0, int_col > 3, int_col < 4,
                       int_col >= 0, int_col <= 7]
        for expr in comparisons:
            self.assertIsInstance(expr, Column)
        for expr in ((int_col & int_col), (int_col | int_col), (~int_col)):
            self.assertIsInstance(expr, Column)
        string_exprs = (str_col.contains('a'), str_col.like('a'), str_col.rlike('a'),
                        str_col.asc(), str_col.desc(), str_col.startswith('a'),
                        str_col.endswith('a'), int_col.eqNullSafe(str_col))
        for expr in string_exprs:
            self.assertIsInstance(expr, Column)
        self.assertIsInstance(int_col.cast(LongType()), Column)
        self.assertRaisesRegexp(ValueError,
                                "Cannot apply 'in' operator against a column",
                                lambda: 1 in str_col)

    def test_column_accessor(self):
        """Slices, integer and string item access all produce Columns; a slice
        with a step is rejected."""
        from pyspark.sql.functions import col

        foo = col("foo")
        self.assertIsInstance(foo[1:3], Column)
        self.assertIsInstance(foo[0], Column)
        self.assertIsInstance(foo["bar"], Column)
        self.assertRaises(ValueError, lambda: foo[0:10:2])

    def test_column_select(self):
        """Selecting '*' or explicit Columns round-trips the test data."""
        df = self.df
        self.assertEqual(df.select("*").collect(), self.testData)
        self.assertEqual(df.select(df.key, df.value).collect(), self.testData)
        self.assertEqual(df.where(df.key == 1).select(df.value).collect(),
                         [Row(value='1')])

    def test_access_column(self):
        """Columns are reachable by attribute, name and index; bad keys raise."""
        df = self.df
        for column in [df.key, df['key'], df[0]]:
            self.assertIsInstance(column, Column)
        self.assertRaises(IndexError, lambda: df[2])
        self.assertRaises(AnalysisException, lambda: df["bad_key"])
        self.assertRaises(TypeError, lambda: df[{}])

    def test_column_name_with_non_ascii(self):
        """A non-ASCII column name survives schema creation, str(), dtypes and
        selection by both name and Column."""
        if sys.version >= '3':
            column_name = "数量"
            self.assertIsInstance(column_name, str)
        else:
            # Python 2 path: build an explicit unicode object.
            column_name = unicode("数量", "utf-8")
            self.assertIsInstance(column_name, unicode)
        schema = StructType([StructField(column_name, LongType(), True)])
        df = self.spark.createDataFrame([(1,)], schema)
        self.assertEqual(df.schema, schema)
        self.assertEqual(str(df), "DataFrame[数量: bigint]")
        self.assertEqual(df.dtypes, [("数量", 'bigint')])
        self.assertEqual(df.select("数量").first()[0], 1)
        self.assertEqual(df.select(df["数量"]).first()[0], 1)

    def test_field_accessor(self):
        """Array, struct and map fields are reachable via [] and dotted names."""
        df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
        self.assertEqual(df.select(df.l[0]).first()[0], 1)
        self.assertEqual(df.select(df.r["a"]).first()[0], 1)
        self.assertEqual(df.select(df["r.a"]).first()[0], 1)
        self.assertEqual(df.select(df.r["b"]).first()[0], "b")
        self.assertEqual(df.select(df["r.b"]).first()[0], "b")
        self.assertEqual(df.select(df.d["k"]).first()[0], "v")

    def test_bitwise_operations(self):
        """bitwiseAND/OR/XOR and bitwiseNOT agree with Python's &, |, ^ and ~."""
        from pyspark.sql import functions

        df = self.spark.createDataFrame([Row(a=170, b=75)])

        def first_as_dict(expr):
            # Collect a one-column projection and return its single row as a dict.
            return df.select(expr).collect()[0].asDict()

        self.assertEqual(first_as_dict(df.a.bitwiseAND(df.b))['(a & b)'], 170 & 75)
        self.assertEqual(first_as_dict(df.a.bitwiseOR(df.b))['(a | b)'], 170 | 75)
        self.assertEqual(first_as_dict(df.a.bitwiseXOR(df.b))['(a ^ b)'], 170 ^ 75)
        self.assertEqual(first_as_dict(functions.bitwiseNOT(df.b))['~b'], ~75)


if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_column import *

    # Prefer the XML-report runner when available; fall back to the default
    # text runner otherwise.
    try:
        import xmlrunner
    except ImportError:
        testRunner = None
    else:
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=testRunner, verbosity=2)