#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
import shutil
import tempfile
import time

from pyspark.sql import Row
from pyspark.sql.functions import lit
from pyspark.sql.types import *
from pyspark.testing.sqlutils import ReusedSQLTestCase, UTCOffsetTimezone


class SerdeTests(ReusedSQLTestCase):

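    # Round-trip a Row containing an array of nested Rows and a map whose
    # values are Rows, checking fields both via DataFrame rows and via the
    # underlying RDD.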
    def test_serialize_nested_array_and_map(self):
        d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
        rdd = self.sc.parallelize(d)
        df = self.spark.createDataFrame(rdd)
        row = df.head()
        self.assertEqual(1, len(row.l))
        self.assertEqual(1, row.l[0].a)
        self.assertEqual("2", row.d["key"].d)

        l = df.rdd.map(lambda x: x.l).first()
        self.assertEqual(1, len(l))
        self.assertEqual('s', l[0].b)

        d = df.rdd.map(lambda x: x.d).first()
        self.assertEqual(1, len(d))
        self.assertEqual(1.0, d["key"].c)

        row = df.rdd.map(lambda x: x.d["key"]).first()
        self.assertEqual(1.0, row.c)
        self.assertEqual("2", row.d)

    def test_select_null_literal(self):
        df = self.spark.sql("select null as col")
        self.assertEqual(Row(col=None), df.first())

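    # Rows (structs) are valid as map keys, not only as map values.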
    def test_struct_in_map(self):
        d = [Row(m={Row(i=1): Row(s="")})]
        df = self.sc.parallelize(d).toDF()
        k, v = list(df.head().m.items())[0]
        self.assertEqual(1, k.i)
        self.assertEqual("", v.s)

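    # Filters against Python date/datetime literals should behave like the
    # corresponding Python comparisons.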
    def test_filter_with_datetime(self):
        time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
        date = time.date()
        row = Row(date=date, time=time)
        df = self.spark.createDataFrame([row])
        self.assertEqual(1, df.filter(df.date == date).count())
        self.assertEqual(1, df.filter(df.time == time).count())
        self.assertEqual(0, df.filter(df.date > date).count())
        self.assertEqual(0, df.filter(df.time > time).count())

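    # UTCOffsetTimezone (imported from pyspark.testing.sqlutils above) is a
    # fixed-offset tzinfo helper; conceptually something like:
    #
    #     class UTCOffsetTimezone(datetime.tzinfo):
    #         def __init__(self, offset=0):              # offset in hours
    #             self.OFFSET = datetime.timedelta(hours=offset)
    #         def utcoffset(self, dt):
    #             return self.OFFSET
    #         def dst(self, dt):
    #             return datetime.timedelta(0)
    #
    # dt1 and dt2 below share the same wall-clock time, but dt2 is at UTC+1 and
    # therefore one hour earlier as an absolute instant, so dt1 > dt2.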
    def test_filter_with_datetime_timezone(self):
        dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0))
        dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1))
        row = Row(date=dt1)
        df = self.spark.createDataFrame([row])
        self.assertEqual(0, df.filter(df.date == dt2).count())
        self.assertEqual(1, df.filter(df.date > dt2).count())
        self.assertEqual(0, df.filter(df.date < dt2).count())

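    # TimestampType stores an absolute instant (internally UTC) and renders it
    # in the session's local time on collect, so a tz-aware datetime and a
    # naive local datetime denoting the same instant compare equal after a
    # round trip.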
    def test_time_with_timezone(self):
        day = datetime.date.today()
        now = datetime.datetime.now()
        ts = time.mktime(now.timetuple())
        # UTCOffsetTimezone must live in a module (here pyspark.testing.sqlutils,
        # imported at the top): a class defined in __main__ is not serializable.
        utc = UTCOffsetTimezone()
        utcnow = datetime.datetime.utcfromtimestamp(ts)  # without microseconds
        # rebuild utcnow with now's microseconds and the UTC tzinfo attached
        # (keeping year, month, day, hour, minute, second)
        utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
        df = self.spark.createDataFrame([(day, now, utcnow)])
        day1, now1, utcnow1 = df.first()
        self.assertEqual(day1, day)
        self.assertEqual(now, now1)
        self.assertEqual(now, utcnow1)

    # regression test for SPARK-19561: timestamps at the Unix epoch must
    # survive both createDataFrame and lit()
    def test_datetime_at_epoch(self):
        epoch = datetime.datetime.fromtimestamp(0)
        df = self.spark.createDataFrame([Row(date=epoch)])
        first = df.select('date', lit(epoch).alias('lit_date')).first()
        self.assertEqual(first['date'], epoch)
        self.assertEqual(first['lit_date'], epoch)

    def test_decimal(self):
        from decimal import Decimal
        schema = StructType([StructField("decimal", DecimalType(10, 5))])
        df = self.spark.createDataFrame([(Decimal("3.14159"),)], schema)
        row = df.select(df.decimal + 1).first()
        self.assertEqual(row[0], Decimal("4.14159"))
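        # DataFrameWriter.parquet() refuses to write to an existing path by
        # default (SaveMode "error"), so take a unique temp name but remove
        # the directory before writing.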
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        df.write.parquet(tmpPath)
        df2 = self.spark.read.parquet(tmpPath)
        row = df2.first()
        self.assertEqual(row[0], Decimal("3.14159"))

    def test_BinaryType_serialization(self):
        # Pyrolite <= 4.9 could not serialize BinaryType under Python 3 (SPARK-17808).
        # The empty bytearray is a regression test for SPARK-21534.
        schema = StructType([StructField('mybytes', BinaryType())])
        data = [[bytearray(b'here is my data')],
                [bytearray(b'and here is some more')],
                [bytearray(b'')]]
        df = self.spark.createDataFrame(data, schema=schema)
        df.collect()

    def test_int_array_serialization(self):
        # This test is sensitive to parallelism, hence the explicit numSlices=12;
        # no element of any collected array should have been dropped to None.
        data = self.spark.sparkContext.parallelize([[1, 2, 3, 4]] * 100, numSlices=12)
        df = self.spark.createDataFrame(data, "array<integer>")
        self.assertEqual(len(list(filter(lambda r: None in r.value, df.collect()))), 0)

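    # bytes, like bytearray, is accepted as input for BinaryType; collected
    # binary values come back as bytearray.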
    def test_bytes_as_binary_type(self):
        df = self.spark.createDataFrame([[b"abcd"]], "col binary")
        self.assertEqual(df.first().col, bytearray(b'abcd'))


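# When run directly, prefer xmlrunner (JUnit-style XML reports under
# target/test-reports) and fall back to the stock unittest runner if it is
# not installed.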
if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_serde import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)