# -*- encoding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import array
import ctypes
import datetime
import os
import pickle
import sys
import unittest

from pyspark.sql import Row
from pyspark.sql.functions import col, UserDefinedFunction
from pyspark.sql.types import *
from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings, \
    _array_unsigned_int_typecode_ctype_mappings, _infer_type, _make_type_verifier, _merge_type
from pyspark.testing.sqlutils import ReusedSQLTestCase, ExamplePointUDT, PythonOnlyUDT, \
    ExamplePoint, PythonOnlyPoint, MyObject


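# Tests for the PySpark SQL type system: schema inference and application,
# user-defined types (UDTs), type merging, array typecode handling, and the
# value verifiers produced by _make_type_verifier.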
class TypesTests(ReusedSQLTestCase):

    def test_apply_schema_to_row(self):
        df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
        df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema)
        self.assertEqual(df.collect(), df2.collect())

        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
        df3 = self.spark.createDataFrame(rdd, df.schema)
        self.assertEqual(10, df3.count())

    def test_infer_schema_to_local(self):
        input = [{"a": 1}, {"b": "coffee"}]
        rdd = self.sc.parallelize(input)
        df = self.spark.createDataFrame(input)
        df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
        self.assertEqual(df.schema, df2.schema)

        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
        df3 = self.spark.createDataFrame(rdd, df.schema)
        self.assertEqual(10, df3.count())

    def test_apply_schema_to_dict_and_rows(self):
        schema = StructType().add("b", StringType()).add("a", IntegerType())
        input = [{"a": 1}, {"b": "coffee"}]
        rdd = self.sc.parallelize(input)
        for verify in [False, True]:
            df = self.spark.createDataFrame(input, schema, verifySchema=verify)
            df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
            self.assertEqual(df.schema, df2.schema)

            rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
            df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
            self.assertEqual(10, df3.count())
            input = [Row(a=x, b=str(x)) for x in range(10)]
            df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
            self.assertEqual(10, df4.count())

    def test_create_dataframe_schema_mismatch(self):
        input = [Row(a=1)]
        rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
        schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
        df = self.spark.createDataFrame(rdd, schema)
        self.assertRaises(Exception, lambda: df.show())

    def test_infer_schema(self):
        d = [Row(l=[], d={}, s=None),
             Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
        rdd = self.sc.parallelize(d)
        df = self.spark.createDataFrame(rdd)
        self.assertEqual([], df.rdd.map(lambda r: r.l).first())
        self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())

        with self.tempView("test"):
            df.createOrReplaceTempView("test")
            result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'")
            self.assertEqual(1, result.head()[0])

        df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
        self.assertEqual(df.schema, df2.schema)
        self.assertEqual({}, df2.rdd.map(lambda r: r.d).first())
        self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect())

        with self.tempView("test2"):
            df2.createOrReplaceTempView("test2")
            result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
            self.assertEqual(1, result.head()[0])

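    # Each entry in `data` below lines up positionally with one entry in each
    # `expected` list: first the inferred Catalyst type's simpleString(), then
    # the value that comes back from collect().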
    def test_infer_schema_specification(self):
        from decimal import Decimal

        class A(object):
            def __init__(self):
                self.a = 1

        data = [
            True,
            1,
            "a",
            u"a",
            datetime.date(1970, 1, 1),
            datetime.datetime(1970, 1, 1, 0, 0),
            1.0,
            array.array("d", [1]),
            [1],
            (1, ),
            {"a": 1},
            bytearray(1),
            Decimal(1),
            Row(a=1),
            Row("a")(1),
            A(),
        ]

        df = self.spark.createDataFrame([data])
        actual = list(map(lambda x: x.dataType.simpleString(), df.schema))
        expected = [
            'boolean',
            'bigint',
            'string',
            'string',
            'date',
            'timestamp',
            'double',
            'array<double>',
            'array<bigint>',
            'struct<_1:bigint>',
            'map<string,bigint>',
            'binary',
            'decimal(38,18)',
            'struct<a:bigint>',
            'struct<a:bigint>',
            'struct<a:bigint>',
        ]
        self.assertEqual(actual, expected)

        actual = list(df.first())
        expected = [
            True,
            1,
            'a',
            u"a",
            datetime.date(1970, 1, 1),
            datetime.datetime(1970, 1, 1, 0, 0),
            1.0,
            [1.0],
            [1],
            Row(_1=1),
            {"a": 1},
            bytearray(b'\x00'),
            Decimal('1.000000000000000000'),
            Row(a=1),
            Row(a=1),
            Row(a=1),
        ]
        self.assertEqual(actual, expected)

    def test_infer_schema_not_enough_names(self):
        df = self.spark.createDataFrame([["a", "b"]], ["col1"])
        self.assertEqual(df.columns, ['col1', '_2'])

    def test_infer_schema_fails(self):
        with self.assertRaisesRegexp(TypeError, 'field a'):
            self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]),
                                       schema=["a", "b"], samplingRatio=0.99)

    def test_infer_nested_schema(self):
        NestedRow = Row("f1", "f2")
        nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
                                          NestedRow([2, 3], {"row2": 2.0})])
        df = self.spark.createDataFrame(nestedRdd1)
        self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])

        nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
                                          NestedRow([[2, 3], [3, 4]], [2, 3])])
        df = self.spark.createDataFrame(nestedRdd2)
        self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])

        from collections import namedtuple
        CustomRow = namedtuple('CustomRow', 'field1 field2')
        rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
                                   CustomRow(field1=2, field2="row2"),
                                   CustomRow(field1=3, field2="row3")])
        df = self.spark.createDataFrame(rdd)
        self.assertEqual(Row(field1=1, field2=u'row1'), df.first())

    def test_create_dataframe_from_dict_respects_schema(self):
        df = self.spark.createDataFrame([{'a': 1}], ["b"])
        self.assertEqual(df.columns, ['b'])

    def test_negative_decimal(self):
        try:
            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true")
            df = self.spark.createDataFrame([(1, ), (11, )], ["value"])
            ret = df.select(col("value").cast(DecimalType(1, -1))).collect()
            actual = list(map(lambda r: int(r.value), ret))
            self.assertEqual(actual, [0, 10])
        finally:
            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false")

    def test_create_dataframe_from_objects(self):
        data = [MyObject(1, "1"), MyObject(2, "2")]
        df = self.spark.createDataFrame(data)
        self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
        self.assertEqual(df.first(), Row(key=1, value="1"))

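    # The input tuple below lines up positionally with the StructType fields;
    # every field is declared non-nullable except the trailing null1 column.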
    def test_apply_schema(self):
        from datetime import date, datetime
        rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
                                    date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
                                    {"a": 1}, (2,), [1, 2, 3], None)])
        schema = StructType([
            StructField("byte1", ByteType(), False),
            StructField("byte2", ByteType(), False),
            StructField("short1", ShortType(), False),
            StructField("short2", ShortType(), False),
            StructField("int1", IntegerType(), False),
            StructField("float1", FloatType(), False),
            StructField("date1", DateType(), False),
            StructField("time1", TimestampType(), False),
            StructField("map1", MapType(StringType(), IntegerType(), False), False),
            StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
            StructField("list1", ArrayType(ByteType(), False), False),
            StructField("null1", DoubleType(), True)])
        df = self.spark.createDataFrame(rdd, schema)
        results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1,
                             x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
        r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
             datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
        self.assertEqual(r, results.first())

        with self.tempView("table2"):
            df.createOrReplaceTempView("table2")
            r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
                               "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
                               "float1 + 1.5 as float1 FROM table2").first()

            self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))

    def test_convert_row_to_dict(self):
        row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
        self.assertEqual(1, row.asDict()['l'][0].a)
        df = self.sc.parallelize([row]).toDF()

        with self.tempView("test"):
            df.createOrReplaceTempView("test")
            row = self.spark.sql("select l, d from test").head()
            self.assertEqual(1, row.asDict()["l"][0].a)
            self.assertEqual(1.0, row.asDict()['d']['key'].c)

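    # A UDT must survive two round-trips: pickling, and serialization to JSON
    # followed by parsing on the JVM side (parseDataType) and back.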
    def test_udt(self):
        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier

        def check_datatype(datatype):
            pickled = pickle.loads(pickle.dumps(datatype))
            assert datatype == pickled
            scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json())
            python_datatype = _parse_datatype_json_string(scala_datatype.json())
            assert datatype == python_datatype

        check_datatype(ExamplePointUDT())
        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
                                          StructField("point", ExamplePointUDT(), False)])
        check_datatype(structtype_with_udt)
        p = ExamplePoint(1.0, 2.0)
        self.assertEqual(_infer_type(p), ExamplePointUDT())
        _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
        self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0]))

        check_datatype(PythonOnlyUDT())
        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
                                          StructField("point", PythonOnlyUDT(), False)])
        check_datatype(structtype_with_udt)
        p = PythonOnlyPoint(1.0, 2.0)
        self.assertEqual(_infer_type(p), PythonOnlyUDT())
        _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
        self.assertRaises(
            ValueError,
            lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0]))

    def test_simple_udt_in_df(self):
        schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
        df = self.spark.createDataFrame(
            [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
            schema=schema)
        df.collect()

    def test_nested_udt_in_df(self):
        schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
        df = self.spark.createDataFrame(
            [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
            schema=schema)
        df.collect()

        schema = StructType().add("key", LongType()).add("val",
                                                         MapType(LongType(), PythonOnlyUDT()))
        df = self.spark.createDataFrame(
            [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
            schema=schema)
        df.collect()

    def test_complex_nested_udt_in_df(self):
        from pyspark.sql.functions import udf

        schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
        df = self.spark.createDataFrame(
            [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
            schema=schema)
        df.collect()

        gd = df.groupby("key").agg({"val": "collect_list"})
        gd.collect()
        # Use a distinct name so the UDF instance does not shadow the imported
        # `udf` helper.
        struct_udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
        gd.select(struct_udf(*gd)).collect()

    def test_udt_with_none(self):
        df = self.spark.range(0, 10, 1, 1)

        def myudf(x):
            if x > 0:
                return PythonOnlyPoint(float(x), float(x))

        self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
        rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
        self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])

    def test_infer_schema_with_udt(self):
        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        df = self.spark.createDataFrame([row])
        schema = df.schema
        field = [f for f in schema.fields if f.name == "point"][0]
        self.assertEqual(type(field.dataType), ExamplePointUDT)

        with self.tempView("labeled_point"):
            df.createOrReplaceTempView("labeled_point")
            point = self.spark.sql("SELECT point FROM labeled_point").head().point
            self.assertEqual(point, ExamplePoint(1.0, 2.0))

        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
        df = self.spark.createDataFrame([row])
        schema = df.schema
        field = [f for f in schema.fields if f.name == "point"][0]
        self.assertEqual(type(field.dataType), PythonOnlyUDT)

        with self.tempView("labeled_point"):
            df.createOrReplaceTempView("labeled_point")
            point = self.spark.sql("SELECT point FROM labeled_point").head().point
            self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))

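    # Same UDT checks as above, but with an explicit schema rather than
    # relying on inference.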
    def test_apply_schema_with_udt(self):
        row = (1.0, ExamplePoint(1.0, 2.0))
        schema = StructType([StructField("label", DoubleType(), False),
                             StructField("point", ExamplePointUDT(), False)])
        df = self.spark.createDataFrame([row], schema)
        point = df.head().point
        self.assertEqual(point, ExamplePoint(1.0, 2.0))

        row = (1.0, PythonOnlyPoint(1.0, 2.0))
        schema = StructType([StructField("label", DoubleType(), False),
                             StructField("point", PythonOnlyUDT(), False)])
        df = self.spark.createDataFrame([row], schema)
        point = df.head().point
        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))

    def test_udf_with_udt(self):
        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        df = self.spark.createDataFrame([row])
        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])

        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
        df = self.spark.createDataFrame([row])
        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
        udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
        self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])

    def test_parquet_with_udt(self):
        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        df0 = self.spark.createDataFrame([row])
        output_dir = os.path.join(self.tempdir.name, "labeled_point")
        df0.write.parquet(output_dir)
        df1 = self.spark.read.parquet(output_dir)
        point = df1.head().point
        self.assertEqual(point, ExamplePoint(1.0, 2.0))

        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
        df0 = self.spark.createDataFrame([row])
        df0.write.parquet(output_dir, mode='overwrite')
        df1 = self.spark.read.parquet(output_dir)
        point = df1.head().point
        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))

    def test_union_with_udt(self):
        row1 = (1.0, ExamplePoint(1.0, 2.0))
        row2 = (2.0, ExamplePoint(3.0, 4.0))
        schema = StructType([StructField("label", DoubleType(), False),
                             StructField("point", ExamplePointUDT(), False)])
        df1 = self.spark.createDataFrame([row1], schema)
        df2 = self.spark.createDataFrame([row2], schema)

        result = df1.union(df2).orderBy("label").collect()
        self.assertEqual(
            result,
            [
                Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
                Row(label=2.0, point=ExamplePoint(3.0, 4.0))
            ]
        )

    def test_cast_to_string_with_udt(self):
        from pyspark.sql.functions import col
        row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
        schema = StructType([StructField("point", ExamplePointUDT(), False),
                             StructField("pypoint", PythonOnlyUDT(), False)])
        df = self.spark.createDataFrame([row], schema)

        result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head()
        self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]'))

    def test_struct_type(self):
        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
        struct2 = StructType([StructField("f1", StringType(), True),
                              StructField("f2", StringType(), True, None)])
        self.assertEqual(struct1.fieldNames(), struct2.names)
        self.assertEqual(struct1, struct2)

        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
        struct2 = StructType([StructField("f1", StringType(), True)])
        self.assertNotEqual(struct1.fieldNames(), struct2.names)
        self.assertNotEqual(struct1, struct2)

        struct1 = (StructType().add(StructField("f1", StringType(), True))
                   .add(StructField("f2", StringType(), True, None)))
        struct2 = StructType([StructField("f1", StringType(), True),
                              StructField("f2", StringType(), True, None)])
        self.assertEqual(struct1.fieldNames(), struct2.names)
        self.assertEqual(struct1, struct2)

        struct1 = (StructType().add(StructField("f1", StringType(), True))
                   .add(StructField("f2", StringType(), True, None)))
        struct2 = StructType([StructField("f1", StringType(), True)])
        self.assertNotEqual(struct1.fieldNames(), struct2.names)
        self.assertNotEqual(struct1, struct2)

        # Catch exception raised during improper construction
        self.assertRaises(ValueError, lambda: StructType().add("name"))

        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
        for field in struct1:
            self.assertIsInstance(field, StructField)

        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
        self.assertEqual(len(struct1), 2)

        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
        self.assertIs(struct1["f1"], struct1.fields[0])
        self.assertIs(struct1[0], struct1.fields[0])
        self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
        self.assertRaises(KeyError, lambda: struct1["f9"])
        self.assertRaises(IndexError, lambda: struct1[9])
        self.assertRaises(TypeError, lambda: struct1[9.9])

    def test_parse_datatype_string(self):
        from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
        for k, t in _all_atomic_types.items():
            if t != NullType:
                self.assertEqual(t(), _parse_datatype_string(k))
        self.assertEqual(IntegerType(), _parse_datatype_string("int"))
        self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1  ,1)"))
        self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
        self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
        self.assertEqual(
            ArrayType(IntegerType()),
            _parse_datatype_string("array<int >"))
        self.assertEqual(
            MapType(IntegerType(), DoubleType()),
            _parse_datatype_string("map< int, double  >"))
        self.assertEqual(
            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
            _parse_datatype_string("struct<a:int, c:double >"))
        self.assertEqual(
            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
            _parse_datatype_string("a:int, c:double"))
        self.assertEqual(
            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
            _parse_datatype_string("a INT, c DOUBLE"))

    def test_metadata_null(self):
        schema = StructType([StructField("f1", StringType(), True, None),
                             StructField("f2", StringType(), True, {'a': None})])
        rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
        self.spark.createDataFrame(rdd, schema)

    def test_access_nested_types(self):
        df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
        self.assertEqual(1, df.select(df.l[0]).first()[0])
        self.assertEqual(1, df.select(df.l.getItem(0)).first()[0])
        self.assertEqual(1, df.select(df.r.a).first()[0])
        self.assertEqual("b", df.select(df.r.getField("b")).first()[0])
        self.assertEqual("v", df.select(df.d["k"]).first()[0])
        self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])

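    # Python ints infer as LongType regardless of magnitude; even values that
    # do not fit in 64 bits (e.g. 2**71) are still inferred as LongType.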
    def test_infer_long_type(self):
        longrow = [Row(f1='a', f2=100000000000000)]
        df = self.sc.parallelize(longrow).toDF()
        self.assertEqual(df.schema.fields[1].dataType, LongType())

        # Writing this out as Parquet has caused issues in the past as well.
        output_dir = os.path.join(self.tempdir.name, "infer_long_type")
        df.write.parquet(output_dir)
        df1 = self.spark.read.parquet(output_dir)
        self.assertEqual('a', df1.first().f1)
        self.assertEqual(100000000000000, df1.first().f2)

        self.assertEqual(_infer_type(1), LongType())
        self.assertEqual(_infer_type(2**10), LongType())
        self.assertEqual(_infer_type(2**20), LongType())
        self.assertEqual(_infer_type(2**31 - 1), LongType())
        self.assertEqual(_infer_type(2**31), LongType())
        self.assertEqual(_infer_type(2**61), LongType())
        self.assertEqual(_infer_type(2**71), LongType())

    @unittest.skipIf(sys.version_info[0] < 3, "only Python 3 infers bytes as binary type")
    def test_infer_binary_type(self):
        binaryrow = [Row(f1='a', f2=b"abcd")]
        df = self.sc.parallelize(binaryrow).toDF()
        self.assertEqual(df.schema.fields[1].dataType, BinaryType())

        # Writing this out as Parquet has caused issues in the past as well.
        output_dir = os.path.join(self.tempdir.name, "infer_binary_type")
        df.write.parquet(output_dir)
        df1 = self.spark.read.parquet(output_dir)
        self.assertEqual('a', df1.first().f1)
        self.assertEqual(b"abcd", df1.first().f2)

        self.assertEqual(_infer_type(b""), BinaryType())
        self.assertEqual(_infer_type(b"1234"), BinaryType())

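    # _merge_type resolves the common type of two inferred schemas: NullType
    # merges to the other side, identical types merge to themselves, and
    # conflicting types raise a TypeError whose message names the path to the
    # conflict (e.g. 'key of map element in array field f1').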
    def test_merge_type(self):
        self.assertEqual(_merge_type(LongType(), NullType()), LongType())
        self.assertEqual(_merge_type(NullType(), LongType()), LongType())

        self.assertEqual(_merge_type(LongType(), LongType()), LongType())

        self.assertEqual(_merge_type(
            ArrayType(LongType()),
            ArrayType(LongType())
        ), ArrayType(LongType()))
        with self.assertRaisesRegexp(TypeError, 'element in array'):
            _merge_type(ArrayType(LongType()), ArrayType(DoubleType()))

        self.assertEqual(_merge_type(
            MapType(StringType(), LongType()),
            MapType(StringType(), LongType())
        ), MapType(StringType(), LongType()))
        with self.assertRaisesRegexp(TypeError, 'key of map'):
            _merge_type(
                MapType(StringType(), LongType()),
                MapType(DoubleType(), LongType()))
        with self.assertRaisesRegexp(TypeError, 'value of map'):
            _merge_type(
                MapType(StringType(), LongType()),
                MapType(StringType(), DoubleType()))

        self.assertEqual(_merge_type(
            StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
            StructType([StructField("f1", LongType()), StructField("f2", StringType())])
        ), StructType([StructField("f1", LongType()), StructField("f2", StringType())]))
        with self.assertRaisesRegexp(TypeError, 'field f1'):
            _merge_type(
                StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
                StructType([StructField("f1", DoubleType()), StructField("f2", StringType())]))

        self.assertEqual(_merge_type(
            StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
            StructType([StructField("f1", StructType([StructField("f2", LongType())]))])
        ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))]))
        with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'):
            _merge_type(
                StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
                StructType([StructField("f1", StructType([StructField("f2", StringType())]))]))

        self.assertEqual(_merge_type(
            StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]),
            StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])
        ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]))
        with self.assertRaisesRegexp(TypeError, 'element in array field f1'):
            _merge_type(
                StructType([
                    StructField("f1", ArrayType(LongType())),
                    StructField("f2", StringType())]),
                StructType([
                    StructField("f1", ArrayType(DoubleType())),
                    StructField("f2", StringType())]))

        self.assertEqual(_merge_type(
            StructType([
                StructField("f1", MapType(StringType(), LongType())),
                StructField("f2", StringType())]),
            StructType([
                StructField("f1", MapType(StringType(), LongType())),
                StructField("f2", StringType())])
        ), StructType([
            StructField("f1", MapType(StringType(), LongType())),
            StructField("f2", StringType())]))
        with self.assertRaisesRegexp(TypeError, 'value of map field f1'):
            _merge_type(
                StructType([
                    StructField("f1", MapType(StringType(), LongType())),
                    StructField("f2", StringType())]),
                StructType([
                    StructField("f1", MapType(StringType(), DoubleType())),
                    StructField("f2", StringType())]))

        self.assertEqual(_merge_type(
            StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
            StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])
        ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]))
        with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'):
            _merge_type(
                StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
                StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))])
            )

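    # array.array support: every typecode Spark claims to support must
    # round-trip through a DataFrame, and every other typecode must be
    # rejected with a TypeError.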
    # test for SPARK-16542
    def test_array_types(self):
        # This test needs to make sure that the Scala type selected is at
        # least as large as the Python type. This is necessary because
        # Python's array types depend on the C implementation on the machine,
        # so there is no machine-independent correspondence between Python's
        # array types and Scala types.
        # See: https://docs.python.org/2/library/array.html

        def assertCollectSuccess(typecode, value):
            row = Row(myarray=array.array(typecode, [value]))
            df = self.spark.createDataFrame([row])
            self.assertEqual(df.first()["myarray"][0], value)

        # supported string types
        #
        # String typecodes in Python's array module are "u" for Py_UNICODE
        # and "c" for char. "u" is scheduled for removal in Python 4, and "c"
        # is not supported in Python 3.
        supported_string_types = []
        if sys.version_info[0] < 4:
            supported_string_types += ['u']
            # test unicode
            assertCollectSuccess('u', u'a')
        if sys.version_info[0] < 3:
            supported_string_types += ['c']
            # test string
            assertCollectSuccess('c', 'a')

        # supported float and double
        #
        # Test max, min, and precision for float and double, assuming IEEE 754
        # floating-point format.
        supported_fractional_types = ['f', 'd']
        assertCollectSuccess('f', ctypes.c_float(1e+38).value)
        assertCollectSuccess('f', ctypes.c_float(1e-38).value)
        assertCollectSuccess('f', ctypes.c_float(1.123456).value)
        assertCollectSuccess('d', sys.float_info.max)
        assertCollectSuccess('d', sys.float_info.min)
        assertCollectSuccess('d', sys.float_info.epsilon)

        # supported signed int types
        #
        # The size of C types changes with the implementation, so we need to
        # make sure there is no overflow error on the platform running this
        # test.
        supported_signed_int_types = list(
            set(_array_signed_int_typecode_ctype_mappings.keys())
            .intersection(set(_array_type_mappings.keys())))
        for t in supported_signed_int_types:
            ctype = _array_signed_int_typecode_ctype_mappings[t]
            max_val = 2 ** (ctypes.sizeof(ctype) * 8 - 1)
            assertCollectSuccess(t, max_val - 1)
            assertCollectSuccess(t, -max_val)

        # supported unsigned int types
        #
        # The JVM does not have unsigned types, so we need to be very careful
        # to make sure there is no overflow error.
        supported_unsigned_int_types = list(
            set(_array_unsigned_int_typecode_ctype_mappings.keys())
            .intersection(set(_array_type_mappings.keys())))
        for t in supported_unsigned_int_types:
            ctype = _array_unsigned_int_typecode_ctype_mappings[t]
            assertCollectSuccess(t, 2 ** (ctypes.sizeof(ctype) * 8) - 1)

        # all supported types
        #
        # Make sure the types tested above:
        # 1. are all supported types
        # 2. cover all supported types
        supported_types = (supported_string_types +
                           supported_fractional_types +
                           supported_signed_int_types +
                           supported_unsigned_int_types)
        self.assertEqual(set(supported_types), set(_array_type_mappings.keys()))

        # all unsupported types
        #
        # The keys of _array_type_mappings form the complete list of supported
        # typecodes; any typecode not in it is considered unsupported.
        # `array.typecodes` is not available in Python 2, so hard-code the
        # full list there.
        if sys.version_info[0] < 3:
            all_types = set(['c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'f', 'd'])
        else:
            all_types = set(array.typecodes)
        unsupported_types = all_types - set(supported_types)
        # test unsupported types
        for t in unsupported_types:
            with self.assertRaises(TypeError):
                a = array.array(t)
                self.spark.createDataFrame([Row(myarray=a)]).collect()


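# Regression tests for individual DataType behaviors: equality, pickling,
# and internal/external value conversion.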
class DataTypeTests(unittest.TestCase):
    # regression test for SPARK-6055
    def test_data_type_eq(self):
        lt = LongType()
        lt2 = pickle.loads(pickle.dumps(LongType()))
        self.assertEqual(lt, lt2)

    # regression test for SPARK-7978
    def test_decimal_type(self):
        t1 = DecimalType()
        t2 = DecimalType(10, 2)
        self.assertTrue(t2 is not t1)
        self.assertNotEqual(t1, t2)
        t3 = DecimalType(8)
        self.assertNotEqual(t2, t3)

    # regression test for SPARK-10392
    def test_datetype_equal_zero(self):
        dt = DateType()
        self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1))

    # regression test for SPARK-17035
    def test_timestamp_microsecond(self):
        tst = TimestampType()
        self.assertEqual(tst.toInternal(datetime.datetime.max) % 1000000, 999999)

    # regression test for SPARK-23299
    def test_row_without_column_name(self):
        row = Row("Alice", 11)
        self.assertEqual(repr(row), "<Row('Alice', 11)>")

        # test __repr__ with unicode values
        if sys.version_info.major >= 3:
            self.assertEqual(repr(Row("数", "量")), "<Row('数', '量')>")
        else:
            self.assertEqual(repr(Row(u"数", u"量")), r"<Row(u'\u6570', u'\u91cf')>")

    def test_empty_row(self):
        row = Row()
        self.assertEqual(len(row), 0)

    def test_struct_field_type_name(self):
        struct_field = StructField("a", IntegerType())
        self.assertRaises(TypeError, struct_field.typeName)

    def test_invalid_create_row(self):
        row_class = Row("c1", "c2")
        self.assertRaises(ValueError, lambda: row_class(1, 2, 3))


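# Tests for the verifier closures produced by _make_type_verifier: they raise
# TypeError when a value's Python type is wrong for the declared data type,
# and ValueError when the value is out of range or null where nulls are not
# allowed.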
class DataTypeVerificationTests(unittest.TestCase):

    def test_verify_type_exception_msg(self):
        self.assertRaisesRegexp(
            ValueError,
            "test_name",
            lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None))

        schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
        self.assertRaisesRegexp(
            TypeError,
            "field b in field a",
            lambda: _make_type_verifier(schema)([["data"]]))

    def test_verify_type_ok_nullable(self):
        obj = None
        types = [IntegerType(), FloatType(), StringType(), StructType([])]
        for data_type in types:
            try:
                _make_type_verifier(data_type, nullable=True)(obj)
            except Exception:
                self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type))

    def test_verify_type_not_nullable(self):
        import array
        import datetime
        import decimal

        schema = StructType([
            StructField('s', StringType(), nullable=False),
            StructField('i', IntegerType(), nullable=True)])

        class MyObj:
            def __init__(self, **kwargs):
                for k, v in kwargs.items():
                    setattr(self, k, v)

        # obj, data_type
        success_spec = [
            # String
            ("", StringType()),
            (u"", StringType()),
            (1, StringType()),
            (1.0, StringType()),
            ([], StringType()),
            ({}, StringType()),

            # UDT
            (ExamplePoint(1.0, 2.0), ExamplePointUDT()),

            # Boolean
            (True, BooleanType()),

            # Byte
            (-(2**7), ByteType()),
            (2**7 - 1, ByteType()),

            # Short
            (-(2**15), ShortType()),
            (2**15 - 1, ShortType()),

            # Integer
            (-(2**31), IntegerType()),
            (2**31 - 1, IntegerType()),

            # Long
            (-(2**63), LongType()),
            (2**63 - 1, LongType()),

            # Float & Double
            (1.0, FloatType()),
            (1.0, DoubleType()),

            # Decimal
            (decimal.Decimal("1.0"), DecimalType()),

            # Binary
            (bytearray([1, 2]), BinaryType()),

            # Date/Timestamp
            (datetime.date(2000, 1, 2), DateType()),
            (datetime.datetime(2000, 1, 2, 3, 4), DateType()),
            (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),

            # Array
            ([], ArrayType(IntegerType())),
            (["1", None], ArrayType(StringType(), containsNull=True)),
            ([1, 2], ArrayType(IntegerType())),
            ((1, 2), ArrayType(IntegerType())),
            (array.array('h', [1, 2]), ArrayType(IntegerType())),

            # Map
            ({}, MapType(StringType(), IntegerType())),
            ({"a": 1}, MapType(StringType(), IntegerType())),
            ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)),

            # Struct
            ({"s": "a", "i": 1}, schema),
            ({"s": "a", "i": None}, schema),
            ({"s": "a"}, schema),
            ({"s": "a", "f": 1.0}, schema),
            (Row(s="a", i=1), schema),
            (Row(s="a", i=None), schema),
            (Row(s="a", i=1, f=1.0), schema),
            (["a", 1], schema),
            (["a", None], schema),
            (("a", 1), schema),
            (MyObj(s="a", i=1), schema),
            (MyObj(s="a", i=None), schema),
            (MyObj(s="a"), schema),
        ]

        # obj, data_type, exception class
        failure_spec = [
            # String (match anything but None)
            (None, StringType(), ValueError),

            # UDT
            (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),

            # Boolean
            (1, BooleanType(), TypeError),
            ("True", BooleanType(), TypeError),
            ([1], BooleanType(), TypeError),

            # Byte
            (-(2**7) - 1, ByteType(), ValueError),
            (2**7, ByteType(), ValueError),
            ("1", ByteType(), TypeError),
            (1.0, ByteType(), TypeError),

            # Short
            (-(2**15) - 1, ShortType(), ValueError),
            (2**15, ShortType(), ValueError),

            # Integer
            (-(2**31) - 1, IntegerType(), ValueError),
            (2**31, IntegerType(), ValueError),

            # Float & Double
            (1, FloatType(), TypeError),
            (1, DoubleType(), TypeError),

            # Decimal
            (1.0, DecimalType(), TypeError),
            (1, DecimalType(), TypeError),
            ("1.0", DecimalType(), TypeError),

            # Binary
            (1, BinaryType(), TypeError),

            # Date/Timestamp
            ("2000-01-02", DateType(), TypeError),
            (946811040, TimestampType(), TypeError),

            # Array
            (["1", None], ArrayType(StringType(), containsNull=False), ValueError),
            ([1, "2"], ArrayType(IntegerType()), TypeError),

            # Map
            ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError),
            ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError),
            ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False),
             ValueError),

            # Struct
            ({"s": "a", "i": "1"}, schema, TypeError),
            (Row(s="a"), schema, ValueError),     # a Row can't have missing fields
            (Row(s="a", i="1"), schema, TypeError),
            (["a"], schema, ValueError),
            (["a", "1"], schema, TypeError),
            (MyObj(s="a", i="1"), schema, TypeError),
            (MyObj(s=None, i="1"), schema, ValueError),
        ]

        # Check success cases
        for obj, data_type in success_spec:
            try:
                _make_type_verifier(data_type, nullable=False)(obj)
            except Exception:
                self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type))

        # Check failure cases
        for obj, data_type, exp in failure_spec:
            msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
            with self.assertRaises(exp, msg=msg):
                _make_type_verifier(data_type, nullable=False)(obj)

    @unittest.skipIf(sys.version_info[:2] < (3, 6), "Create Row without sorting fields")
    def test_row_without_field_sorting(self):
        sorting_enabled_tmp = Row._row_field_sorting_enabled
        Row._row_field_sorting_enabled = False

        r = Row(b=1, a=2)
        TestRow = Row("b", "a")
        expected = TestRow(1, 2)

        self.assertEqual(r, expected)
        self.assertEqual(repr(r), "Row(b=1, a=2)")
        Row._row_field_sorting_enabled = sorting_enabled_tmp


if __name__ == "__main__":
    from pyspark.sql.tests.test_types import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)