0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import datetime
0019 import shutil
0020 import tempfile
0021 import time
0022
0023 from pyspark.sql import Row
0024 from pyspark.sql.functions import lit
0025 from pyspark.sql.types import *
0026 from pyspark.testing.sqlutils import ReusedSQLTestCase, UTCOffsetTimezone
0027
0028
class SerdeTests(ReusedSQLTestCase):
    """Round-trip serialization tests for Python values passed through Spark SQL.

    Covers nested structs in arrays/maps, NULL literals, struct-typed map keys,
    date/time values (naive and timezone-aware), decimals, binary, and int
    arrays.  Relies on the ``spark`` session and ``sc`` context provided by
    ``ReusedSQLTestCase``.
    """

    def test_serialize_nested_array_and_map(self):
        """Structs nested inside arrays and map values survive a round trip."""
        d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
        rdd = self.sc.parallelize(d)
        df = self.spark.createDataFrame(rdd)
        row = df.head()
        self.assertEqual(1, len(row.l))
        self.assertEqual(1, row.l[0].a)
        self.assertEqual("2", row.d["key"].d)

        # Pull the nested values back out through the RDD API as well.
        # (local renamed from ``l`` to avoid the ambiguous single-letter name)
        lst = df.rdd.map(lambda x: x.l).first()
        self.assertEqual(1, len(lst))
        self.assertEqual('s', lst[0].b)

        mapping = df.rdd.map(lambda x: x.d).first()
        self.assertEqual(1, len(mapping))
        self.assertEqual(1.0, mapping["key"].c)

        row = df.rdd.map(lambda x: x.d["key"]).first()
        self.assertEqual(1.0, row.c)
        self.assertEqual("2", row.d)

    def test_select_null_literal(self):
        """A SQL NULL literal deserializes to Python ``None``."""
        df = self.spark.sql("select null as col")
        self.assertEqual(Row(col=None), df.first())

    def test_struct_in_map(self):
        """Struct values may be used as map keys and round-trip intact."""
        d = [Row(m={Row(i=1): Row(s="")})]
        df = self.sc.parallelize(d).toDF()
        k, v = list(df.head().m.items())[0]
        self.assertEqual(1, k.i)
        self.assertEqual("", v.s)

    def test_filter_with_datetime(self):
        """Naive datetime/date literals work in filter comparisons."""
        # Named ``dt`` rather than ``time`` so the module-level ``time``
        # import is not shadowed.
        dt = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
        date = dt.date()
        row = Row(date=date, time=dt)
        df = self.spark.createDataFrame([row])
        self.assertEqual(1, df.filter(df.date == date).count())
        self.assertEqual(1, df.filter(df.time == dt).count())
        self.assertEqual(0, df.filter(df.date > date).count())
        self.assertEqual(0, df.filter(df.time > dt).count())

    def test_filter_with_datetime_timezone(self):
        """Timezone-aware datetimes compare by instant, not wall-clock time."""
        dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0))
        dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1))
        row = Row(date=dt1)
        df = self.spark.createDataFrame([row])
        # Same wall-clock time at UTC+1 is one hour earlier than at UTC+0.
        self.assertEqual(0, df.filter(df.date == dt2).count())
        self.assertEqual(1, df.filter(df.date > dt2).count())
        self.assertEqual(0, df.filter(df.date < dt2).count())

    def test_time_with_timezone(self):
        """An aware UTC datetime and the equivalent naive local datetime agree."""
        day = datetime.date.today()
        now = datetime.datetime.now()
        ts = time.mktime(now.timetuple())

        # UTCOffsetTimezone is already imported at module level (no need to
        # re-import it here).  Build the same instant as ``now`` expressed in
        # UTC with an explicit tzinfo; mktime truncates sub-second precision,
        # so the microseconds are re-attached from ``now``.
        utc = UTCOffsetTimezone()
        utcnow = datetime.datetime.utcfromtimestamp(ts)
        utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
        df = self.spark.createDataFrame([(day, now, utcnow)])
        day1, now1, utcnow1 = df.first()
        self.assertEqual(day1, day)
        self.assertEqual(now, now1)
        self.assertEqual(now, utcnow1)

    def test_datetime_at_epoch(self):
        """The Unix epoch round-trips both as data and as a ``lit`` column."""
        epoch = datetime.datetime.fromtimestamp(0)
        df = self.spark.createDataFrame([Row(date=epoch)])
        first = df.select('date', lit(epoch).alias('lit_date')).first()
        self.assertEqual(first['date'], epoch)
        self.assertEqual(first['lit_date'], epoch)

    def test_decimal(self):
        """DecimalType survives arithmetic and a parquet write/read round trip."""
        from decimal import Decimal
        schema = StructType([StructField("decimal", DecimalType(10, 5))])
        df = self.spark.createDataFrame([(Decimal("3.14159"),)], schema)
        row = df.select(df.decimal + 1).first()
        self.assertEqual(row[0], Decimal("4.14159"))
        # The parquet writer requires a non-existent target path, so remove
        # the freshly created temp dir before writing -- and clean it up
        # afterwards so the test does not leak a directory per run.
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        try:
            df.write.parquet(tmpPath)
            df2 = self.spark.read.parquet(tmpPath)
            row = df2.first()
            self.assertEqual(row[0], Decimal("3.14159"))
        finally:
            shutil.rmtree(tmpPath, ignore_errors=True)

    def test_BinaryType_serialization(self):
        """Binary columns serialize without error, including the empty value."""
        schema = StructType([StructField('mybytes', BinaryType())])
        data = [[bytearray(b'here is my data')],
                [bytearray(b'and here is some more')],
                [bytearray(b'')]]
        df = self.spark.createDataFrame(data, schema=schema)
        df.collect()

    def test_int_array_serialization(self):
        """Integer arrays deserialize with no elements turned into None."""
        data = self.spark.sparkContext.parallelize([[1, 2, 3, 4]] * 100, numSlices=12)
        df = self.spark.createDataFrame(data, "array<integer>")
        self.assertEqual(len(list(filter(lambda r: None in r.value, df.collect()))), 0)

    def test_bytes_as_binary_type(self):
        """Python ``bytes`` input maps to BinaryType and reads back as bytearray."""
        df = self.spark.createDataFrame([[b"abcd"]], "col binary")
        self.assertEqual(df.first().col, bytearray(b'abcd'))
0138
0139
if __name__ == "__main__":
    import unittest

    from pyspark.sql.tests.test_serde import *

    # Prefer the XML-emitting runner (for CI report collection) when
    # available; otherwise fall back to the default text runner.
    try:
        import xmlrunner
    except ImportError:
        runner = None
    else:
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=runner, verbosity=2)