0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019 import array
0020 import ctypes
0021 import datetime
0022 import os
0023 import pickle
0024 import sys
0025 import unittest
0026
0027 from pyspark.sql import Row
0028 from pyspark.sql.functions import col, UserDefinedFunction
0029 from pyspark.sql.types import *
0030 from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings, \
0031 _array_unsigned_int_typecode_ctype_mappings, _infer_type, _make_type_verifier, _merge_type
0032 from pyspark.testing.sqlutils import ReusedSQLTestCase, ExamplePointUDT, PythonOnlyUDT, \
0033 ExamplePoint, PythonOnlyPoint, MyObject
0034
0035
0036 class TypesTests(ReusedSQLTestCase):
0037
def test_apply_schema_to_row(self):
    # A schema inferred from JSON must be re-applicable both to the rows it
    # came from and to Row objects built by hand.
    json_rdd = self.sc.parallelize(["""{"a":2}"""])
    inferred = self.spark.read.json(json_rdd)

    # Send the rows through an identity map and attach the same schema again.
    identity_rows = inferred.rdd.map(lambda x: x)
    reapplied = self.spark.createDataFrame(identity_rows, inferred.schema)
    self.assertEqual(inferred.collect(), reapplied.collect())

    # Rows created in Python are accepted under the inferred schema as well.
    generated_rows = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
    generated = self.spark.createDataFrame(generated_rows, inferred.schema)
    self.assertEqual(10, generated.count())
0046
def test_infer_schema_to_local(self):
    # Schema inference over a local list and over an RDD holding the same
    # dicts must agree.
    records = [{"a": 1}, {"b": "coffee"}]
    records_rdd = self.sc.parallelize(records)
    local_df = self.spark.createDataFrame(records)
    distributed_df = self.spark.createDataFrame(records_rdd, samplingRatio=1.0)
    self.assertEqual(local_df.schema, distributed_df.schema)

    # The inferred schema can then be applied to an RDD of Row objects.
    row_rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
    from_rows = self.spark.createDataFrame(row_rdd, local_df.schema)
    self.assertEqual(10, from_rows.count())
0057
def test_apply_schema_to_dict_and_rows(self):
    # An explicit schema must be applicable to dicts (local list and RDD)
    # and to Row objects (RDD and local list), both with and without
    # schema verification.
    schema = StructType().add("b", StringType()).add("a", IntegerType())
    dict_data = [{"a": 1}, {"b": "coffee"}]
    dict_rdd = self.sc.parallelize(dict_data)
    for verify in [False, True]:
        # The dict inputs keep their own names: reusing one variable for
        # dicts and rows meant the verifySchema=True pass silently ran on
        # the Row data left over from the previous iteration.
        df = self.spark.createDataFrame(dict_data, schema, verifySchema=verify)
        df2 = self.spark.createDataFrame(dict_rdd, schema, verifySchema=verify)
        self.assertEqual(df.schema, df2.schema)

        row_rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
        df3 = self.spark.createDataFrame(row_rdd, schema, verifySchema=verify)
        self.assertEqual(10, df3.count())

        row_data = [Row(a=x, b=str(x)) for x in range(10)]
        df4 = self.spark.createDataFrame(row_data, schema, verifySchema=verify)
        self.assertEqual(10, df4.count())
0073
def test_create_dataframe_schema_mismatch(self):
    # The rows carry one field while the schema declares two. The test
    # expects createDataFrame() to return normally and the failure to
    # surface when the DataFrame is evaluated by show().
    # (An unused ``input`` local that shadowed the builtin was dropped.)
    rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
    schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
    df = self.spark.createDataFrame(rdd, schema)
    self.assertRaises(Exception, lambda: df.show())
0080
def test_infer_schema(self):
    # The first row only has empty containers and None; the second row
    # supplies nested values for the same fields.
    rows = [Row(l=[], d={}, s=None),
            Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
    rdd = self.sc.parallelize(rows)

    # Default inference.
    inferred = self.spark.createDataFrame(rdd)
    first_list = inferred.rdd.map(lambda r: r.l).first()
    self.assertEqual([], first_list)
    strings = inferred.rdd.map(lambda r: r.s).collect()
    self.assertEqual([None, ""], strings)

    with self.tempView("test"):
        inferred.createOrReplaceTempView("test")
        matched = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'")
        self.assertEqual(1, matched.head()[0])

    # Inference with samplingRatio=1.0 must give the same schema and data.
    sampled = self.spark.createDataFrame(rdd, samplingRatio=1.0)
    self.assertEqual(inferred.schema, sampled.schema)
    first_dict = sampled.rdd.map(lambda r: r.d).first()
    self.assertEqual({}, first_dict)
    sampled_strings = sampled.rdd.map(lambda r: r.s).collect()
    self.assertEqual([None, ""], sampled_strings)

    with self.tempView("test2"):
        sampled.createOrReplaceTempView("test2")
        matched = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
        self.assertEqual(1, matched.head()[0])
0103
def test_infer_schema_specification(self):
    # Pins, for one value of each supported Python kind, the type that is
    # inferred for it and the value that is read back after collection.
    from decimal import Decimal

    class A(object):
        def __init__(self):
            self.a = 1

    # Each case: (input value, simpleString of inferred type, value read back).
    cases = [
        (True, 'boolean', True),
        (1, 'bigint', 1),
        ("a", 'string', 'a'),
        (u"a", 'string', u"a"),
        (datetime.date(1970, 1, 1), 'date', datetime.date(1970, 1, 1)),
        (datetime.datetime(1970, 1, 1, 0, 0), 'timestamp',
         datetime.datetime(1970, 1, 1, 0, 0)),
        (1.0, 'double', 1.0),
        (array.array("d", [1]), 'array<double>', [1.0]),
        ([1], 'array<bigint>', [1]),
        ((1, ), 'struct<_1:bigint>', Row(_1=1)),
        ({"a": 1}, 'map<string,bigint>', {"a": 1}),
        (bytearray(1), 'binary', bytearray(b'\x00')),
        (Decimal(1), 'decimal(38,18)', Decimal('1.000000000000000000')),
        (Row(a=1), 'struct<a:bigint>', Row(a=1)),
        (Row("a")(1), 'struct<a:bigint>', Row(a=1)),
        (A(), 'struct<a:bigint>', Row(a=1)),
    ]
    data = [case[0] for case in cases]
    expected_types = [case[1] for case in cases]
    expected_values = [case[2] for case in cases]

    # All values go into a single row, one column per case.
    df = self.spark.createDataFrame([data])

    actual_types = [field.dataType.simpleString() for field in df.schema]
    self.assertEqual(actual_types, expected_types)

    actual_values = list(df.first())
    self.assertEqual(actual_values, expected_values)
0172
def test_infer_schema_not_enough_names(self):
    # Two values per row but only one column name: the second column is
    # expected to come back as '_2'.
    rows = [["a", "b"]]
    names = ["col1"]
    frame = self.spark.createDataFrame(rows, names)
    self.assertEqual(frame.columns, ['col1', '_2'])
0176
def test_infer_schema_fails(self):
    # Column "a" holds an int in one row and a str in the other; the
    # TypeError message is expected to mention that field.
    with self.assertRaisesRegexp(TypeError, 'field a'):
        conflicting = self.spark.sparkContext.parallelize([[1, 1], ["x", 1]])
        self.spark.createDataFrame(conflicting, schema=["a", "b"], samplingRatio=0.99)
0181
def test_infer_nested_schema(self):
    # Inference over rows that contain lists, dicts and nested lists, and
    # over namedtuple instances.
    from collections import namedtuple

    NestedRow = Row("f1", "f2")

    # A list column next to a dict column.
    list_and_dict = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
                                         NestedRow([2, 3], {"row2": 2.0})])
    first = self.spark.createDataFrame(list_and_dict).collect()[0]
    self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), first)

    # A list-of-lists column next to a list column.
    nested_lists = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
                                        NestedRow([[2, 3], [3, 4]], [2, 3])])
    first = self.spark.createDataFrame(nested_lists).collect()[0]
    self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), first)

    # namedtuple records keep their field names.
    CustomRow = namedtuple('CustomRow', 'field1 field2')
    tuples = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
                                  CustomRow(field1=2, field2="row2"),
                                  CustomRow(field1=3, field2="row3")])
    from_tuples = self.spark.createDataFrame(tuples)
    self.assertEqual(Row(field1=1, field2=u'row1'), from_tuples.first())
0201
def test_create_dataframe_from_dict_respects_schema(self):
    # The column list passed as schema wins over the dict's own key.
    records = [{'a': 1}]
    frame = self.spark.createDataFrame(records, ["b"])
    self.assertEqual(frame.columns, ['b'])
0205
def test_negative_decimal(self):
    # Casts to a decimal with scale -1 while the legacy flag is on; the
    # flag is switched off again afterwards however the assertions go.
    try:
        self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true")
        values = self.spark.createDataFrame([(1, ), (11, )], ["value"])
        negative_scale = DecimalType(1, -1)
        collected = values.select(col("value").cast(negative_scale)).collect()
        as_ints = [int(r.value) for r in collected]
        self.assertEqual(as_ints, [0, 10])
    finally:
        self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false")
0215
def test_create_dataframe_from_objects(self):
    # Plain Python objects are turned into rows with "key" and "value"
    # columns.
    objects = [MyObject(1, "1"), MyObject(2, "2")]
    frame = self.spark.createDataFrame(objects)
    expected_dtypes = [("key", "bigint"), ("value", "string")]
    self.assertEqual(frame.dtypes, expected_dtypes)
    self.assertEqual(frame.first(), Row(key=1, value="1"))
0221
def test_apply_schema(self):
    # Applies an explicit schema with narrow numeric, temporal and nested
    # types to one tuple, then checks the values through the RDD API and
    # through SQL arithmetic.
    from datetime import date, datetime

    values = (127, -128, -32768, 32767, 2147483647, 1.0,
              date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
              {"a": 1}, (2,), [1, 2, 3], None)
    rdd = self.sc.parallelize([values])

    # Every field except the trailing "null1" is declared non-nullable.
    required = [
        ("byte1", ByteType()),
        ("byte2", ByteType()),
        ("short1", ShortType()),
        ("short2", ShortType()),
        ("int1", IntegerType()),
        ("float1", FloatType()),
        ("date1", DateType()),
        ("time1", TimestampType()),
        ("map1", MapType(StringType(), IntegerType(), False)),
        ("struct1", StructType([StructField("b", ShortType(), False)])),
        ("list1", ArrayType(ByteType(), False)),
    ]
    fields = [StructField(name, data_type, False) for name, data_type in required]
    fields.append(StructField("null1", DoubleType(), True))
    schema = StructType(fields)

    df = self.spark.createDataFrame(rdd, schema)
    flattened = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1,
                                      x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1,
                                      x.null1))
    expected = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
                datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
    self.assertEqual(expected, flattened.first())

    with self.tempView("table2"):
        df.createOrReplaceTempView("table2")
        query = ("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, "
                 "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, "
                 "float1 + 1.5 as float1 FROM table2")
        shifted = self.spark.sql(query).first()

        self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(shifted))
0254
def test_convert_row_to_dict(self):
    # asDict() must keep nested Rows reachable, both on a local Row and on
    # a Row read back through SQL.
    local_row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
    self.assertEqual(1, local_row.asDict()['l'][0].a)
    df = self.sc.parallelize([local_row]).toDF()

    with self.tempView("test"):
        df.createOrReplaceTempView("test")
        fetched = self.spark.sql("select l, d from test").head()
        as_dict = fetched.asDict()
        self.assertEqual(1, as_dict["l"][0].a)
        self.assertEqual(1.0, as_dict['d']['key'].c)
0265
def test_udt(self):
    # For both ExamplePointUDT and PythonOnlyUDT: the type must survive
    # pickling and a JSON round trip through the JVM session, be inferred
    # from its point object, and be enforced by the type verifier.
    from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier

    def check_datatype(datatype):
        # unittest assertions replace bare ``assert`` statements so the
        # checks still run under ``python -O`` and report both operands.
        pickled = pickle.loads(pickle.dumps(datatype))
        self.assertEqual(datatype, pickled)
        scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json())
        python_datatype = _parse_datatype_json_string(scala_datatype.json())
        self.assertEqual(datatype, python_datatype)

    check_datatype(ExamplePointUDT())
    structtype_with_udt = StructType([StructField("label", DoubleType(), False),
                                      StructField("point", ExamplePointUDT(), False)])
    check_datatype(structtype_with_udt)
    p = ExamplePoint(1.0, 2.0)
    self.assertEqual(_infer_type(p), ExamplePointUDT())
    # A matching point passes verification; a plain list is rejected.
    _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
    self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0]))

    check_datatype(PythonOnlyUDT())
    structtype_with_udt = StructType([StructField("label", DoubleType(), False),
                                      StructField("point", PythonOnlyUDT(), False)])
    check_datatype(structtype_with_udt)
    p = PythonOnlyPoint(1.0, 2.0)
    self.assertEqual(_infer_type(p), PythonOnlyUDT())
    _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
    self.assertRaises(
        ValueError,
        lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0]))
0295
def test_simple_udt_in_df(self):
    # Smoke test: a DataFrame with a PythonOnlyUDT column can be collected.
    key_and_point = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
    rows = [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)]
    frame = self.spark.createDataFrame(rows, schema=key_and_point)
    frame.collect()
0302
def test_nested_udt_in_df(self):
    # Smoke test: PythonOnlyUDT values nested in an array column and in a
    # map column can be collected.
    array_schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
    array_rows = [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)]
    self.spark.createDataFrame(array_rows, schema=array_schema).collect()

    map_schema = StructType().add("key", LongType()).add("val",
                                                         MapType(LongType(), PythonOnlyUDT()))
    map_rows = [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))})
                for i in range(10)]
    self.spark.createDataFrame(map_rows, schema=map_schema).collect()
0316
def test_complex_nested_udt_in_df(self):
    # Smoke test: UDT values can be aggregated with collect_list and then
    # handed to a UDF that returns an array of (key, point) structs.
    from pyspark.sql.functions import udf

    point_schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
    points = [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)]
    df = self.spark.createDataFrame(points, schema=point_schema)
    df.collect()

    grouped = df.groupby("key").agg({"val": "collect_list"})
    grouped.collect()

    # Bound to its own name so the imported ``udf`` is not rebound.
    first_per_key = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
    grouped.select(first_per_key(*grouped)).collect()
0330
def test_udt_with_none(self):
    # A UDF declared to return a UDT may also return None.
    ids = self.spark.range(0, 10, 1, 1)

    def myudf(x):
        # Only positive ids produce a point; anything else yields None.
        if x > 0:
            return PythonOnlyPoint(float(x), float(x))
        return None

    self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
    first_two = ids.selectExpr("udf(id)").take(2)
    points = [r[0] for r in first_two]
    self.assertEqual(points, [None, PythonOnlyPoint(1, 1)])
0341
def test_infer_schema_with_udt(self):
    # For ExamplePoint and PythonOnlyPoint alike, inference must pick the
    # matching UDT for the "point" field and SQL must return an equal point.
    def check(make_point, udt_class):
        df = self.spark.createDataFrame([Row(label=1.0, point=make_point(1.0, 2.0))])
        point_field = [f for f in df.schema.fields if f.name == "point"][0]
        self.assertEqual(type(point_field.dataType), udt_class)

        with self.tempView("labeled_point"):
            df.createOrReplaceTempView("labeled_point")
            fetched = self.spark.sql("SELECT point FROM labeled_point").head().point
            self.assertEqual(fetched, make_point(1.0, 2.0))

    check(ExamplePoint, ExamplePointUDT)
    check(PythonOnlyPoint, PythonOnlyUDT)
0364
def test_apply_schema_with_udt(self):
    # With an explicit schema naming the UDT, the point read back from the
    # DataFrame must equal the one that went in.
    def check(make_point, udt_class):
        schema = StructType([StructField("label", DoubleType(), False),
                             StructField("point", udt_class(), False)])
        record = (1.0, make_point(1.0, 2.0))
        df = self.spark.createDataFrame([record], schema)
        self.assertEqual(df.head().point, make_point(1.0, 2.0))

    check(ExamplePoint, ExamplePointUDT)
    check(PythonOnlyPoint, PythonOnlyUDT)
0379
def test_udf_with_udt(self):
    # UDFs must be able to read attributes of a UDT value and to return a
    # new UDT value, for ExamplePoint and PythonOnlyPoint alike.
    example_df = self.spark.createDataFrame([Row(label=1.0, point=ExamplePoint(1.0, 2.0))])
    self.assertEqual(1.0, example_df.rdd.map(lambda r: r.point.x).first())
    read_y = UserDefinedFunction(lambda p: p.y, DoubleType())
    self.assertEqual(2.0, example_df.select(read_y(example_df.point)).first()[0])
    shift = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
    self.assertEqual(ExamplePoint(2.0, 3.0),
                     example_df.select(shift(example_df.point)).first()[0])

    python_df = self.spark.createDataFrame([Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))])
    self.assertEqual(1.0, python_df.rdd.map(lambda r: r.point.x).first())
    read_y = UserDefinedFunction(lambda p: p.y, DoubleType())
    self.assertEqual(2.0, python_df.select(read_y(python_df.point)).first()[0])
    shift = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
    self.assertEqual(PythonOnlyPoint(2.0, 3.0),
                     python_df.select(shift(python_df.point)).first()[0])
0396
def test_parquet_with_udt(self):
    # A UDT column must survive being written to and read back from Parquet.
    output_dir = os.path.join(self.tempdir.name, "labeled_point")

    def round_trip(point, **write_options):
        # Write a one-row DataFrame holding ``point`` and return the point
        # found in the first row read back from the same directory.
        written = self.spark.createDataFrame([Row(label=1.0, point=point)])
        written.write.parquet(output_dir, **write_options)
        return self.spark.read.parquet(output_dir).head().point

    self.assertEqual(round_trip(ExamplePoint(1.0, 2.0)), ExamplePoint(1.0, 2.0))

    # The second write targets the same directory, hence mode='overwrite'.
    self.assertEqual(round_trip(PythonOnlyPoint(1.0, 2.0), mode='overwrite'),
                     PythonOnlyPoint(1.0, 2.0))
0412
def test_union_with_udt(self):
    # Two DataFrames sharing a schema with a UDT column can be unioned.
    schema = StructType([StructField("label", DoubleType(), False),
                         StructField("point", ExamplePointUDT(), False)])
    first = self.spark.createDataFrame([(1.0, ExamplePoint(1.0, 2.0))], schema)
    second = self.spark.createDataFrame([(2.0, ExamplePoint(3.0, 4.0))], schema)

    combined = first.union(second).orderBy("label").collect()
    expected = [
        Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
        Row(label=2.0, point=ExamplePoint(3.0, 4.0)),
    ]
    self.assertEqual(combined, expected)
0429
def test_cast_to_string_with_udt(self):
    # Pins the string form produced when each UDT column is cast to string.
    from pyspark.sql.functions import col

    schema = StructType([StructField("point", ExamplePointUDT(), False),
                         StructField("pypoint", PythonOnlyUDT(), False)])
    record = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
    df = self.spark.createDataFrame([record], schema)

    point_text = col('point').cast('string')
    pypoint_text = col('pypoint').cast('string')
    result = df.select(point_text, pypoint_text).head()
    self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]'))
0439
def test_struct_type(self):
    # Covers building a StructType through add(), equality against the
    # list constructor, iteration, len() and the indexing forms.
    def added_by_name():
        return StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)

    def added_as_fields():
        return (StructType().add(StructField("f1", StringType(), True))
                .add(StructField("f2", StringType(), True, None)))

    both_fields = StructType([StructField("f1", StringType(), True),
                              StructField("f2", StringType(), True, None)])
    only_f1 = StructType([StructField("f1", StringType(), True)])

    # Either way of calling add() equals the two-field struct and differs
    # from the one-field struct.
    for build in (added_by_name, added_as_fields):
        built = build()
        self.assertEqual(built.fieldNames(), both_fields.names)
        self.assertEqual(built, both_fields)
        self.assertNotEqual(built.fieldNames(), only_f1.names)
        self.assertNotEqual(built, only_f1)

    # add() given only a field name is rejected.
    self.assertRaises(ValueError, lambda: StructType().add("name"))

    # Iterating yields StructField objects.
    for field in added_by_name():
        self.assertIsInstance(field, StructField)

    # len() counts the fields.
    self.assertEqual(len(added_by_name()), 2)

    # Lookup by name, by position and by slice; bad keys raise by kind.
    struct = added_by_name()
    self.assertIs(struct["f1"], struct.fields[0])
    self.assertIs(struct[0], struct.fields[0])
    self.assertEqual(struct[0:1], StructType(struct.fields[0:1]))
    self.assertRaises(KeyError, lambda: struct["f9"])
    self.assertRaises(IndexError, lambda: struct[9])
    self.assertRaises(TypeError, lambda: struct[9.9])
0482
def test_parse_datatype_string(self):
    # Every atomic type name except the one mapped to NullType must parse
    # back to its type, and the listed spellings must parse as shown.
    from pyspark.sql.types import _all_atomic_types, _parse_datatype_string

    for type_name, type_class in _all_atomic_types.items():
        if type_class == NullType:
            continue
        self.assertEqual(type_class(), _parse_datatype_string(type_name))

    a_int_c_double = [StructField("a", IntegerType()), StructField("c", DoubleType())]
    # (expected type, text to parse); the odd spacing is part of the input.
    cases = [
        (IntegerType(), "int"),
        (DecimalType(1, 1), "decimal(1 ,1)"),
        (DecimalType(10, 1), "decimal( 10,1 )"),
        (DecimalType(11, 1), "decimal(11,1)"),
        (ArrayType(IntegerType()), "array<int >"),
        (MapType(IntegerType(), DoubleType()), "map< int, double >"),
        (StructType(a_int_c_double), "struct<a:int, c:double >"),
        (StructType(a_int_c_double), "a:int, c:double"),
        (StructType(a_int_c_double), "a INT, c DOUBLE"),
    ]
    for expected, text in cases:
        self.assertEqual(expected, _parse_datatype_string(text))
0507
def test_metadata_null(self):
    # Smoke test: a schema whose field metadata is None, or holds a None
    # value, is accepted by createDataFrame.
    fields = [StructField("f1", StringType(), True, None),
              StructField("f2", StringType(), True, {'a': None})]
    schema = StructType(fields)
    pairs = self.sc.parallelize([["a", "b"], ["c", "d"]])
    self.spark.createDataFrame(pairs, schema)
0513
def test_access_nested_types(self):
    # List, struct and map columns can be addressed with indexing,
    # attribute access, getItem() and getField().
    df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()

    # (expected value, column expression)
    cases = [
        (1, df.l[0]),
        (1, df.l.getItem(0)),
        (1, df.r.a),
        ("b", df.r.getField("b")),
        ("v", df.d["k"]),
        ("v", df.d.getItem("k")),
    ]
    for expected, column in cases:
        self.assertEqual(expected, df.select(column).first()[0])
0522
def test_infer_long_type(self):
    # Python ints are inferred as LongType, in a DataFrame and directly.
    rows = [Row(f1='a', f2=100000000000000)]
    df = self.sc.parallelize(rows).toDF()
    self.assertEqual(df.schema.fields[1].dataType, LongType())

    # The value must also survive a Parquet round trip.
    output_dir = os.path.join(self.tempdir.name, "infer_long_type")
    df.write.parquet(output_dir)
    reloaded = self.spark.read.parquet(output_dir)
    self.assertEqual('a', reloaded.first().f1)
    self.assertEqual(100000000000000, reloaded.first().f2)

    # _infer_type answers LongType for each of these magnitudes.
    for value in (1, 2**10, 2**20, 2**31 - 1, 2**31, 2**61, 2**71):
        self.assertEqual(_infer_type(value), LongType())
0542
0543 @unittest.skipIf(sys.version < "3", "only Python 3 infers bytes as binary type")
0544 def test_infer_binary_type(self):
0545 binaryrow = [Row(f1='a', f2=b"abcd")]
0546 df = self.sc.parallelize(binaryrow).toDF()
0547 self.assertEqual(df.schema.fields[1].dataType, BinaryType())
0548
0549
0550 output_dir = os.path.join(self.tempdir.name, "infer_binary_type")
0551 df.write.parquet(output_dir)
0552 df1 = self.spark.read.parquet(output_dir)
0553 self.assertEqual('a', df1.first().f1)
0554 self.assertEqual(b"abcd", df1.first().f2)
0555
0556 self.assertEqual(_infer_type(b""), BinaryType())
0557 self.assertEqual(_infer_type(b"1234"), BinaryType())
0558
def test_merge_type(self):
    # _merge_type must return the shared type for matching inputs and raise
    # a TypeError whose message names the path to the conflicting part.
    def with_f2(f1_type):
        # struct<f1: f1_type, f2: string>
        return StructType([StructField("f1", f1_type), StructField("f2", StringType())])

    def only_f1(f1_type):
        # struct<f1: f1_type>
        return StructType([StructField("f1", f1_type)])

    def string_map(value_type):
        # map<string, value_type>
        return MapType(StringType(), value_type)

    # NullType on either side yields the other type.
    self.assertEqual(_merge_type(LongType(), NullType()), LongType())
    self.assertEqual(_merge_type(NullType(), LongType()), LongType())

    self.assertEqual(_merge_type(LongType(), LongType()), LongType())

    # Arrays.
    self.assertEqual(
        _merge_type(ArrayType(LongType()), ArrayType(LongType())),
        ArrayType(LongType()))
    with self.assertRaisesRegexp(TypeError, 'element in array'):
        _merge_type(ArrayType(LongType()), ArrayType(DoubleType()))

    # Maps: key and value conflicts are reported separately.
    self.assertEqual(
        _merge_type(string_map(LongType()), string_map(LongType())),
        string_map(LongType()))
    with self.assertRaisesRegexp(TypeError, 'key of map'):
        _merge_type(string_map(LongType()), MapType(DoubleType(), LongType()))
    with self.assertRaisesRegexp(TypeError, 'value of map'):
        _merge_type(string_map(LongType()), string_map(DoubleType()))

    # Flat struct.
    self.assertEqual(
        _merge_type(with_f2(LongType()), with_f2(LongType())),
        with_f2(LongType()))
    with self.assertRaisesRegexp(TypeError, 'field f1'):
        _merge_type(with_f2(LongType()), with_f2(DoubleType()))

    # Struct nested in a struct.
    self.assertEqual(
        _merge_type(only_f1(StructType([StructField("f2", LongType())])),
                    only_f1(StructType([StructField("f2", LongType())]))),
        only_f1(StructType([StructField("f2", LongType())])))
    with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'):
        _merge_type(only_f1(StructType([StructField("f2", LongType())])),
                    only_f1(StructType([StructField("f2", StringType())])))

    # Array nested in a struct.
    self.assertEqual(
        _merge_type(with_f2(ArrayType(LongType())), with_f2(ArrayType(LongType()))),
        with_f2(ArrayType(LongType())))
    with self.assertRaisesRegexp(TypeError, 'element in array field f1'):
        _merge_type(with_f2(ArrayType(LongType())), with_f2(ArrayType(DoubleType())))

    # Map nested in a struct.
    self.assertEqual(
        _merge_type(with_f2(string_map(LongType())), with_f2(string_map(LongType()))),
        with_f2(string_map(LongType())))
    with self.assertRaisesRegexp(TypeError, 'value of map field f1'):
        _merge_type(with_f2(string_map(LongType())), with_f2(string_map(DoubleType())))

    # Map nested in an array nested in a struct.
    self.assertEqual(
        _merge_type(only_f1(ArrayType(string_map(LongType()))),
                    only_f1(ArrayType(string_map(LongType())))),
        only_f1(ArrayType(string_map(LongType()))))
    with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'):
        _merge_type(only_f1(ArrayType(string_map(LongType()))),
                    only_f1(ArrayType(MapType(DoubleType(), LongType()))))
0644
0645
def test_array_types(self):
    # Each array.array typecode listed in _array_type_mappings must survive
    # createDataFrame()/first() at the probed values, the probed typecodes
    # must be exactly the mapped ones, and every remaining typecode must be
    # rejected with TypeError.
    def check_round_trip(typecode, value):
        # Store one value in a typed array column and read it back.
        frame = self.spark.createDataFrame([Row(myarray=array.array(typecode, [value]))])
        self.assertEqual(frame.first()["myarray"][0], value)

    major = sys.version_info[0]
    mapped_codes = set(_array_type_mappings.keys())

    # String typecodes: 'u' is probed below major version 4, 'c' only
    # below major version 3.
    string_codes = []
    if major < 4:
        string_codes.append('u')
        check_round_trip('u', u'a')
    if major < 3:
        string_codes.append('c')
        check_round_trip('c', 'a')

    # Fractional typecodes: large, small and ordinary magnitudes.
    fractional_codes = ['f', 'd']
    for single in (1e+38, 1e-38, 1.123456):
        # Go through c_float so the expected value is the stored precision.
        check_round_trip('f', ctypes.c_float(single).value)
    for double in (sys.float_info.max, sys.float_info.min, sys.float_info.epsilon):
        check_round_trip('d', double)

    # Signed integer typecodes: both ends of the range for the C type.
    signed_codes = list(set(_array_signed_int_typecode_ctype_mappings.keys()) & mapped_codes)
    for code in signed_codes:
        bits = ctypes.sizeof(_array_signed_int_typecode_ctype_mappings[code]) * 8
        limit = 2 ** (bits - 1)
        check_round_trip(code, limit - 1)
        check_round_trip(code, -limit)

    # Unsigned integer typecodes: the largest value for the C type.
    unsigned_codes = list(set(_array_unsigned_int_typecode_ctype_mappings.keys()) & mapped_codes)
    for code in unsigned_codes:
        bits = ctypes.sizeof(_array_unsigned_int_typecode_ctype_mappings[code]) * 8
        check_round_trip(code, 2 ** bits - 1)

    # Together the probed typecodes must be exactly the mapped ones.
    supported_codes = string_codes + fractional_codes + signed_codes + unsigned_codes
    self.assertEqual(set(supported_codes), mapped_codes)

    # Everything else must be refused.
    if major < 3:
        every_code = set(['c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'f', 'd'])
    else:
        every_code = set(array.typecodes)
    for code in every_code - set(supported_codes):
        with self.assertRaises(TypeError):
            empty = array.array(code)
            self.spark.createDataFrame([Row(myarray=empty)]).collect()
0735
0736
0737 class DataTypeTests(unittest.TestCase):
0738
def test_data_type_eq(self):
    # A LongType restored from a pickle compares equal to a fresh one.
    fresh = LongType()
    restored = pickle.loads(pickle.dumps(LongType()))
    self.assertEqual(fresh, restored)
0743
0744
def test_decimal_type(self):
    # DecimalType instances built with different arguments are distinct
    # objects and compare unequal.
    default_decimal = DecimalType()
    decimal_10_2 = DecimalType(10, 2)
    self.assertIsNot(decimal_10_2, default_decimal)
    self.assertNotEqual(default_decimal, decimal_10_2)

    decimal_8 = DecimalType(8)
    self.assertNotEqual(decimal_10_2, decimal_8)
0752
0753
def test_datetype_equal_zero(self):
    # Internal value 0 converts to the date 1970-01-01.
    epoch = datetime.date(1970, 1, 1)
    self.assertEqual(DateType().fromInternal(0), epoch)
0757
0758
def test_timestamp_microsecond(self):
    # The microsecond part of datetime.max must be kept by toInternal().
    internal = TimestampType().toInternal(datetime.datetime.max)
    self.assertEqual(internal % 1000000, 999999)
0762
0763
def test_row_without_column_name(self):
    # repr() of a Row built from positional values only.
    unnamed = Row("Alice", 11)
    self.assertEqual(repr(unnamed), "<Row('Alice', 11)>")

    # Non-ASCII values: the expected repr depends on the major version.
    if sys.version_info.major < 3:
        self.assertEqual(repr(Row(u"数", u"量")), r"<Row(u'\u6570', u'\u91cf')>")
    else:
        self.assertEqual(repr(Row("数", "量")), "<Row('数', '量')>")
0773
def test_empty_row(self):
    # A Row created without arguments has length zero.
    self.assertEqual(len(Row()), 0)
0777
def test_struct_field_type_name(self):
    # Calling typeName() on a StructField instance raises TypeError.
    field = StructField("a", IntegerType())
    with self.assertRaises(TypeError):
        field.typeName()
0781
def test_invalid_create_row(self):
    # A Row class declared with two names rejects three values.
    two_columns = Row("c1", "c2")
    with self.assertRaises(ValueError):
        two_columns(1, 2, 3)
0785
0786
0787 class DataTypeVerificationTests(unittest.TestCase):
0788
def test_verify_type_exception_msg(self):
    # The verifier's error message must carry the given name, or for
    # nested structs the path of field names down to the failing value.
    with self.assertRaisesRegexp(ValueError, "test_name"):
        _make_type_verifier(StringType(), nullable=False, name="test_name")(None)

    inner = StructType([StructField('b', IntegerType())])
    schema = StructType([StructField('a', inner)])
    with self.assertRaisesRegexp(TypeError, "field b in field a"):
        _make_type_verifier(schema)([["data"]])
0800
def test_verify_type_ok_nullable(self):
    # None must be accepted by a nullable verifier for each listed type.
    value = None
    for candidate in (IntegerType(), FloatType(), StringType(), StructType([])):
        try:
            verifier = _make_type_verifier(candidate, nullable=True)
            verifier(value)
        except Exception:
            self.fail("verify_type(%s, %s, nullable=True)" % (value, candidate))
0809
0810 def test_verify_type_not_nullable(self):
0811 import array
0812 import datetime
0813 import decimal
0814
0815 schema = StructType([
0816 StructField('s', StringType(), nullable=False),
0817 StructField('i', IntegerType(), nullable=True)])
0818
0819 class MyObj:
0820 def __init__(self, **kwargs):
0821 for k, v in kwargs.items():
0822 setattr(self, k, v)
0823
0824
0825 success_spec = [
0826
0827 ("", StringType()),
0828 (u"", StringType()),
0829 (1, StringType()),
0830 (1.0, StringType()),
0831 ([], StringType()),
0832 ({}, StringType()),
0833
0834
0835 (ExamplePoint(1.0, 2.0), ExamplePointUDT()),
0836
0837
0838 (True, BooleanType()),
0839
0840
0841 (-(2**7), ByteType()),
0842 (2**7 - 1, ByteType()),
0843
0844
0845 (-(2**15), ShortType()),
0846 (2**15 - 1, ShortType()),
0847
0848
0849 (-(2**31), IntegerType()),
0850 (2**31 - 1, IntegerType()),
0851
0852
0853 (-(2**63), LongType()),
0854 (2**63 - 1, LongType()),
0855
0856
0857 (1.0, FloatType()),
0858 (1.0, DoubleType()),
0859
0860
0861 (decimal.Decimal("1.0"), DecimalType()),
0862
0863
0864 (bytearray([1, 2]), BinaryType()),
0865
0866
0867 (datetime.date(2000, 1, 2), DateType()),
0868 (datetime.datetime(2000, 1, 2, 3, 4), DateType()),
0869 (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),
0870
0871
0872 ([], ArrayType(IntegerType())),
0873 (["1", None], ArrayType(StringType(), containsNull=True)),
0874 ([1, 2], ArrayType(IntegerType())),
0875 ((1, 2), ArrayType(IntegerType())),
0876 (array.array('h', [1, 2]), ArrayType(IntegerType())),
0877
0878
0879 ({}, MapType(StringType(), IntegerType())),
0880 ({"a": 1}, MapType(StringType(), IntegerType())),
0881 ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)),
0882
0883
0884 ({"s": "a", "i": 1}, schema),
0885 ({"s": "a", "i": None}, schema),
0886 ({"s": "a"}, schema),
0887 ({"s": "a", "f": 1.0}, schema),
0888 (Row(s="a", i=1), schema),
0889 (Row(s="a", i=None), schema),
0890 (Row(s="a", i=1, f=1.0), schema),
0891 (["a", 1], schema),
0892 (["a", None], schema),
0893 (("a", 1), schema),
0894 (MyObj(s="a", i=1), schema),
0895 (MyObj(s="a", i=None), schema),
0896 (MyObj(s="a"), schema),
0897 ]
0898
0899
0900 failure_spec = [
0901
0902 (None, StringType(), ValueError),
0903
0904
0905 (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
0906
0907
0908 (1, BooleanType(), TypeError),
0909 ("True", BooleanType(), TypeError),
0910 ([1], BooleanType(), TypeError),
0911
0912
0913 (-(2**7) - 1, ByteType(), ValueError),
0914 (2**7, ByteType(), ValueError),
0915 ("1", ByteType(), TypeError),
0916 (1.0, ByteType(), TypeError),
0917
0918
0919 (-(2**15) - 1, ShortType(), ValueError),
0920 (2**15, ShortType(), ValueError),
0921
0922
0923 (-(2**31) - 1, IntegerType(), ValueError),
0924 (2**31, IntegerType(), ValueError),
0925
0926
0927 (1, FloatType(), TypeError),
0928 (1, DoubleType(), TypeError),
0929
0930
0931 (1.0, DecimalType(), TypeError),
0932 (1, DecimalType(), TypeError),
0933 ("1.0", DecimalType(), TypeError),
0934
0935
0936 (1, BinaryType(), TypeError),
0937
0938
0939 ("2000-01-02", DateType(), TypeError),
0940 (946811040, TimestampType(), TypeError),
0941
0942
0943 (["1", None], ArrayType(StringType(), containsNull=False), ValueError),
0944 ([1, "2"], ArrayType(IntegerType()), TypeError),
0945
0946
0947 ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError),
0948 ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError),
0949 ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False),
0950 ValueError),
0951
0952
0953 ({"s": "a", "i": "1"}, schema, TypeError),
0954 (Row(s="a"), schema, ValueError),
0955 (Row(s="a", i="1"), schema, TypeError),
0956 (["a"], schema, ValueError),
0957 (["a", "1"], schema, TypeError),
0958 (MyObj(s="a", i="1"), schema, TypeError),
0959 (MyObj(s=None, i="1"), schema, ValueError),
0960 ]
0961
0962
0963 for obj, data_type in success_spec:
0964 try:
0965 _make_type_verifier(data_type, nullable=False)(obj)
0966 except Exception:
0967 self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type))
0968
0969
0970 for obj, data_type, exp in failure_spec:
0971 msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
0972 with self.assertRaises(exp, msg=msg):
0973 _make_type_verifier(data_type, nullable=False)(obj)
0974
0975 @unittest.skipIf(sys.version_info[:2] < (3, 6), "Create Row without sorting fields")
0976 def test_row_without_field_sorting(self):
0977 sorting_enabled_tmp = Row._row_field_sorting_enabled
0978 Row._row_field_sorting_enabled = False
0979
0980 r = Row(b=1, a=2)
0981 TestRow = Row("b", "a")
0982 expected = TestRow(1, 2)
0983
0984 self.assertEqual(r, expected)
0985 self.assertEqual(repr(r), "Row(b=1, a=2)")
0986 Row._row_field_sorting_enabled = sorting_enabled_tmp
0987
0988
if __name__ == "__main__":
    from pyspark.sql.tests.test_types import *

    # Default to unittest's own runner; switch to the XML-reporting runner
    # only when the optional xmlrunner package can be imported.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)