#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import pydoc
import time
import unittest

from pyspark.sql import SparkSession, Row
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException, IllegalArgumentException
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils, have_pyarrow, have_pandas, \
    pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest


class DataFrameTests(ReusedSQLTestCase):

    def test_range(self):
        self.assertEqual(self.spark.range(1, 1).count(), 0)
        self.assertEqual(self.spark.range(1, 0, -1).count(), 1)
        self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2)
        self.assertEqual(self.spark.range(-2).count(), 0)
        self.assertEqual(self.spark.range(3).count(), 3)

    def test_duplicated_column_names(self):
        df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
        row = df.select('*').first()
        self.assertEqual(1, row[0])
        self.assertEqual(2, row[1])
        self.assertEqual("Row(c=1, c=2)", str(row))

        self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
        self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
        self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())

    def test_freqItems(self):
        vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
        df = self.sc.parallelize(vals).toDF()
        items = df.stat.freqItems(("a", "b"), 0.4).collect()[0]
        self.assertTrue(1 in items[0])
        self.assertTrue(-2.0 in items[1])

    def test_help_command(self):
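        # Regression test for SPARK-5464: help() on a DataFrame or Column used to raise.
        # pydoc.render_doc() exercises the same code path as help() without printing output.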
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
        df = self.spark.read.json(rdd)

        pydoc.render_doc(df)
        pydoc.render_doc(df.foo)
        pydoc.render_doc(df.take(1))

    def test_dropna(self):
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True)])

        # shouldn't drop a row that has no null values
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', 50, 80.1)], schema).dropna().count(),
            1)

        # dropping rows with a single null value
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna().count(),
            0)
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna(how='any').count(),
            0)

        # with how='all', drop a row only if all of its values are null
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna(how='all').count(),
            1)
        self.assertEqual(self.spark.createDataFrame(
            [(None, None, None)], schema).dropna(how='all').count(),
            0)

        # how combined with subset
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
            1)
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
            0)

        # threshold: keep rows with at least thresh non-null values
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(),
            1)
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, None)], schema).dropna(thresh=2).count(),
            0)

        # threshold combined with subset
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
            1)
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
            0)

        # thresh should take precedence over how
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', 50, None)], schema).dropna(
                how='any', thresh=2, subset=['name', 'age']).count(),
            1)

    def test_fillna(self):
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True),
            StructField("spy", BooleanType(), True)])

        # fillna shouldn't change non-null values
        row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], schema).fillna(50).first()
        self.assertEqual(row.age, 10)

        # fillna with int
        row = self.spark.createDataFrame([(u'Alice', None, None, None)], schema).fillna(50).first()
        self.assertEqual(row.age, 50)
        self.assertEqual(row.height, 50.0)

        # fillna with double
        row = self.spark.createDataFrame(
            [(u'Alice', None, None, None)], schema).fillna(50.1).first()
        self.assertEqual(row.age, 50)
        self.assertEqual(row.height, 50.1)

        # fillna with bool
        row = self.spark.createDataFrame(
            [(u'Alice', None, None, None)], schema).fillna(True).first()
        self.assertEqual(row.age, None)
        self.assertEqual(row.spy, True)

        # fillna with string
        row = self.spark.createDataFrame([(None, None, None, None)], schema).fillna("hello").first()
        self.assertEqual(row.name, u"hello")
        self.assertEqual(row.age, None)

        # fillna with subset specified for numeric cols
        row = self.spark.createDataFrame(
            [(None, None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
        self.assertEqual(row.name, None)
        self.assertEqual(row.age, 50)
        self.assertEqual(row.height, None)
        self.assertEqual(row.spy, None)

        # fillna with subset specified for string cols
        row = self.spark.createDataFrame(
            [(None, None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
        self.assertEqual(row.name, "haha")
        self.assertEqual(row.age, None)
        self.assertEqual(row.height, None)
        self.assertEqual(row.spy, None)

        # fillna with subset specified for bool cols
        row = self.spark.createDataFrame(
            [(None, None, None, None)], schema).fillna(True, subset=['name', 'spy']).first()
        self.assertEqual(row.name, None)
        self.assertEqual(row.age, None)
        self.assertEqual(row.height, None)
        self.assertEqual(row.spy, True)

        # fillna with a dictionary for boolean types
        row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first()
        self.assertEqual(row.a, True)

    def test_repartitionByRange_dataframe(self):
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True)])

        df1 = self.spark.createDataFrame(
            [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema)
        df2 = self.spark.createDataFrame(
            [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema)

        # test repartitionByRange(numPartitions, *cols)
        df3 = df1.repartitionByRange(2, "name", "age")
        self.assertEqual(df3.rdd.getNumPartitions(), 2)
        self.assertEqual(df3.rdd.first(), df2.rdd.first())
        self.assertEqual(df3.rdd.take(3), df2.rdd.take(3))

        # test repartitionByRange(numPartitions, *cols) with more partitions than rows
        df4 = df1.repartitionByRange(3, "name", "age")
        self.assertEqual(df4.rdd.getNumPartitions(), 3)
        self.assertEqual(df4.rdd.first(), df2.rdd.first())
        self.assertEqual(df4.rdd.take(3), df2.rdd.take(3))

        # test repartitionByRange(*cols)
        df5 = df1.repartitionByRange("name", "age")
        self.assertEqual(df5.rdd.first(), df2.rdd.first())
        self.assertEqual(df5.rdd.take(3), df2.rdd.take(3))

    def test_replace(self):
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True)])

        # replace with int
        row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
        self.assertEqual(row.age, 20)
        self.assertEqual(row.height, 20.0)

        # replace with double
        row = self.spark.createDataFrame(
            [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first()
        self.assertEqual(row.age, 82)
        self.assertEqual(row.height, 82.1)

        # replace with string
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first()
        self.assertEqual(row.name, u"Ann")
        self.assertEqual(row.age, 10)

        # replace with subset specified by a column-name string, with an actual change
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first()
        self.assertEqual(row.age, 20)

        # replace with subset specified by a column-name string, without an actual change
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first()
        self.assertEqual(row.age, 10)

        # replace with subset: one column is replaced, a matching column outside the
        # subset stays unchanged
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first()
        self.assertEqual(row.name, u'Alice')
        self.assertEqual(row.age, 20)
        self.assertEqual(row.height, 10.0)

        # replace with subset specified but no column is replaced
        row = self.spark.createDataFrame(
            [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first()
        self.assertEqual(row.name, u'Alice')
        self.assertEqual(row.age, 10)
        self.assertEqual(row.height, None)

        # replace with lists
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first()
        self.assertTupleEqual(row, (u'Ann', 10, 80.1))

        # replace with dict
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first()
        self.assertTupleEqual(row, (u'Alice', 11, 80.1))

        # test backward compatibility with a dummy value argument
        dummy_value = 1
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first()
        self.assertTupleEqual(row, (u'Bob', 10, 80.1))

        # test dict with mixed numerics
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first()
        self.assertTupleEqual(row, (u'Alice', -10, 90.5))

        # replace with tuples
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first()
        self.assertTupleEqual(row, (u'Bob', 10, 80.1))

        # replace multiple columns
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first()
        self.assertTupleEqual(row, (u'Alice', 20, 90.0))

        # test for mixed numerics
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first()
        self.assertTupleEqual(row, (u'Alice', 20, 90.5))

        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first()
        self.assertTupleEqual(row, (u'Alice', 20, 90.5))

        # replace with boolean
        row = (self
               .spark.createDataFrame([(u'Alice', 10, 80.0)], schema)
               .selectExpr("name = 'Bob'", 'age <= 15')
               .replace(False, True).first())
        self.assertTupleEqual(row, (True, True))

        # replace a string with None and then drop the None rows
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna()
        self.assertEqual(row.count(), 0)

        # replace with number and None
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first()
        self.assertTupleEqual(row, (u'Alice', 20, None))

        # should fail if subset is not list, tuple or None
        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first()

        # should fail if to_replace and value have different lengths
        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first()

        # should fail when an unexpected type is received
        with self.assertRaises(ValueError):
            from datetime import datetime
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first()

        # should fail if mixed-type replacements are provided
        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first()

        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first()

        with self.assertRaisesRegexp(
                TypeError,
                'value argument is required when to_replace is not a dictionary.'):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first()

    def test_with_column_with_existing_name(self):
        keys = self.df.withColumn("key", self.df.key).select("key").collect()
        self.assertEqual([r.key for r in keys], list(range(100)))

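    # Iterating over a Column is not supported and should fail fast with a
    # TypeError instead of looping forever.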
    def test_column_iterator(self):

        def foo():
            for x in self.df.key:
                break

        self.assertRaises(TypeError, foo)

    def test_generic_hints(self):
        from pyspark.sql import DataFrame

        df1 = self.spark.range(10e10).toDF("id")
        df2 = self.spark.range(10e10).toDF("id")

        self.assertIsInstance(df1.hint("broadcast"), DataFrame)
        self.assertIsInstance(df1.hint("broadcast", []), DataFrame)

        # extra hint parameters are accepted even when no rule consumes them
        self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame)
        self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame)

        plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
        self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))

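    # Hint parameters of non-string types (floats, strings, lists) should be
    # propagated into the logical plan.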
    def test_extended_hint_types(self):
        from pyspark.sql import DataFrame

        df = self.spark.range(10e10).toDF("id")
        such_a_nice_list = ["itworks1", "itworks2", "itworks3"]
        hinted_df = df.hint("my awesome hint", 1.2345, "what", such_a_nice_list)
        logical_plan = hinted_df._jdf.queryExecution().logical()

        self.assertEqual(1, logical_plan.toString().count("1.2345"))
        self.assertEqual(1, logical_plan.toString().count("what"))
        self.assertEqual(3, logical_plan.toString().count("itworks"))

    def test_sample(self):
        self.assertRaisesRegexp(
            TypeError,
            "should be a bool, float and number",
            lambda: self.spark.range(1).sample())

        self.assertRaises(
            TypeError,
            lambda: self.spark.range(1).sample("a"))

        self.assertRaises(
            TypeError,
            lambda: self.spark.range(1).sample(seed="abc"))

        self.assertRaises(
            IllegalArgumentException,
            lambda: self.spark.range(1).sample(-1.0))

    def test_toDF_with_schema_string(self):
        data = [Row(key=i, value=str(i)) for i in range(100)]
        rdd = self.sc.parallelize(data, 5)

        df = rdd.toDF("key: int, value: string")
        self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
        self.assertEqual(df.collect(), data)

        # different but compatible field types can be used
        df = rdd.toDF("key: string, value: string")
        self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
        self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])

        # field names can differ
        df = rdd.toDF(" a: int, b: string ")
        self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
        self.assertEqual(df.collect(), data)

        # number of fields must match
        self.assertRaisesRegexp(Exception, "Length of object",
                                lambda: rdd.toDF("key: int").collect())

        # field type mismatches cause an exception at runtime
        self.assertRaisesRegexp(Exception, "FloatType can not accept",
                                lambda: rdd.toDF("key: float, value: string").collect())

        # flat schema values will be wrapped into a row
        df = rdd.map(lambda row: row.key).toDF("int")
        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])

        # a DataType can be used directly instead of a data type string
        df = rdd.map(lambda row: row.key).toDF(IntegerType())
        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])

    def test_join_without_on(self):
        df1 = self.spark.range(1).toDF("a")
        df2 = self.spark.range(1).toDF("b")

        with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
            self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())

        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
            actual = df1.join(df2, how="inner").collect()
            expected = [Row(a=0, b=0)]
            self.assertEqual(actual, expected)

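    # An unknown join type should be rejected eagerly with IllegalArgumentException,
    # not at execution time.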
    def test_invalid_join_method(self):
        df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"])
        df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"])
        self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type"))

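    # Cartesian products require explicit cross-join syntax unless
    # spark.sql.crossJoin.enabled is set.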
    def test_require_cross(self):
        df1 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
        df2 = self.spark.createDataFrame([(1, "1")], ("key", "value"))

        with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
            # joins without an explicit condition require cross-join syntax
            self.assertRaises(AnalysisException, lambda: df1.join(df2).collect())

            # but an explicit crossJoin works
            self.assertEqual(1, df1.crossJoin(df2).count())

    def test_cache(self):
        spark = self.spark
        with self.tempView("tab1", "tab2"):
            spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab1")
            spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab2")
            self.assertFalse(spark.catalog.isCached("tab1"))
            self.assertFalse(spark.catalog.isCached("tab2"))
            spark.catalog.cacheTable("tab1")
            self.assertTrue(spark.catalog.isCached("tab1"))
            self.assertFalse(spark.catalog.isCached("tab2"))
            spark.catalog.cacheTable("tab2")
            spark.catalog.uncacheTable("tab1")
            self.assertFalse(spark.catalog.isCached("tab1"))
            self.assertTrue(spark.catalog.isCached("tab2"))
            spark.catalog.clearCache()
            self.assertFalse(spark.catalog.isCached("tab1"))
            self.assertFalse(spark.catalog.isCached("tab2"))
            self.assertRaisesRegexp(
                AnalysisException,
                "does_not_exist",
                lambda: spark.catalog.isCached("does_not_exist"))
            self.assertRaisesRegexp(
                AnalysisException,
                "does_not_exist",
                lambda: spark.catalog.cacheTable("does_not_exist"))
            self.assertRaisesRegexp(
                AnalysisException,
                "does_not_exist",
                lambda: spark.catalog.uncacheTable("does_not_exist"))

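    # Shared helper for the toPandas tests below: converts a small DataFrame with
    # int, string, boolean, float, date and timestamp columns to pandas.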
    def _to_pandas(self):
        from datetime import datetime, date
        schema = StructType().add("a", IntegerType()).add("b", StringType())\
            .add("c", BooleanType()).add("d", FloatType())\
            .add("dt", DateType()).add("ts", TimestampType())
        data = [
            (1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
            (2, "foo", True, 5.0, None, None),
            (3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)),
            (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)),
        ]
        df = self.spark.createDataFrame(data, schema)
        return df.toPandas()

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas(self):
        import numpy as np
        pdf = self._to_pandas()
        types = pdf.dtypes
        self.assertEqual(types[0], np.int32)
        self.assertEqual(types[1], np.object)
        self.assertEqual(types[2], np.bool)
        self.assertEqual(types[3], np.float32)
        self.assertEqual(types[4], np.object)
        self.assertEqual(types[5], 'datetime64[ns]')

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas_with_duplicated_column_names(self):
        import numpy as np

        sql = "select 1 v, 1 v"
        for arrowEnabled in [False, True]:
            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrowEnabled}):
                df = self.spark.sql(sql)
                pdf = df.toPandas()
                types = pdf.dtypes
                self.assertEqual(types.iloc[0], np.int32)
                self.assertEqual(types.iloc[1], np.int32)

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas_on_cross_join(self):
        import numpy as np

        sql = """
        select t1.*, t2.* from (
            select explode(sequence(1, 3)) v
        ) t1 left join (
            select explode(sequence(1, 3)) v
        ) t2
        """
        for arrowEnabled in [False, True]:
            with self.sql_conf({"spark.sql.crossJoin.enabled": True,
                                "spark.sql.execution.arrow.pyspark.enabled": arrowEnabled}):
                df = self.spark.sql(sql)
                pdf = df.toPandas()
                types = pdf.dtypes
                self.assertEqual(types.iloc[0], np.int32)
                self.assertEqual(types.iloc[1], np.int32)

    @unittest.skipIf(have_pandas, "Required Pandas was found.")
    def test_to_pandas_required_pandas_not_found(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
                self._to_pandas()

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas_avoid_astype(self):
        import numpy as np
        schema = StructType().add("a", IntegerType()).add("b", StringType())\
            .add("c", IntegerType())
        data = [(1, "foo", 16777220), (None, "bar", None)]
        df = self.spark.createDataFrame(data, schema)
        types = df.toPandas().dtypes
        self.assertEqual(types[0], np.float64)  # doesn't convert to np.int32 due to the NaN value
        self.assertEqual(types[1], np.object)
        self.assertEqual(types[2], np.float64)

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas_from_empty_dataframe(self):
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
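            # SPARK-29188: an empty DataFrame should yield the same pandas dtypes
            # as a non-empty DataFrame with the same schema.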
            import numpy as np
            sql = """
                SELECT CAST(1 AS TINYINT) AS tinyint,
                CAST(1 AS SMALLINT) AS smallint,
                CAST(1 AS INT) AS int,
                CAST(1 AS BIGINT) AS bigint,
                CAST(0 AS FLOAT) AS float,
                CAST(0 AS DOUBLE) AS double,
                CAST(1 AS BOOLEAN) AS boolean,
                CAST('foo' AS STRING) AS string,
                CAST('2019-01-01' AS TIMESTAMP) AS timestamp
                """
            dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes
            dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes
            self.assertTrue(np.all(dtypes_when_empty_df == dtypes_when_nonempty_df))

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas_from_null_dataframe(self):
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
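            # columns containing only nulls should still map to sensible pandas
            # dtypes (floats for numeric types, object for boolean and string)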
            import numpy as np
            sql = """
                SELECT CAST(NULL AS TINYINT) AS tinyint,
                CAST(NULL AS SMALLINT) AS smallint,
                CAST(NULL AS INT) AS int,
                CAST(NULL AS BIGINT) AS bigint,
                CAST(NULL AS FLOAT) AS float,
                CAST(NULL AS DOUBLE) AS double,
                CAST(NULL AS BOOLEAN) AS boolean,
                CAST(NULL AS STRING) AS string,
                CAST(NULL AS TIMESTAMP) AS timestamp
                """
            pdf = self.spark.sql(sql).toPandas()
            types = pdf.dtypes
            self.assertEqual(types[0], np.float64)
            self.assertEqual(types[1], np.float64)
            self.assertEqual(types[2], np.float64)
            self.assertEqual(types[3], np.float64)
            self.assertEqual(types[4], np.float32)
            self.assertEqual(types[5], np.float64)
            self.assertEqual(types[6], np.object)
            self.assertEqual(types[7], np.object)
            self.assertTrue(np.can_cast(np.datetime64, types[8]))

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas_from_mixed_dataframe(self):
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
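            # dtypes should be the same whether a column contains some nulls
            # or only nulls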
            import numpy as np
            sql = """
                SELECT CAST(col1 AS TINYINT) AS tinyint,
                CAST(col2 AS SMALLINT) AS smallint,
                CAST(col3 AS INT) AS int,
                CAST(col4 AS BIGINT) AS bigint,
                CAST(col5 AS FLOAT) AS float,
                CAST(col6 AS DOUBLE) AS double,
                CAST(col7 AS BOOLEAN) AS boolean,
                CAST(col8 AS STRING) AS string,
                CAST(col9 AS TIMESTAMP) AS timestamp
                FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1),
                (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)
                """
            pdf_with_some_nulls = self.spark.sql(sql).toPandas()
            pdf_with_only_nulls = self.spark.sql(sql).filter('tinyint is null').toPandas()
            self.assertTrue(np.all(pdf_with_only_nulls.dtypes == pdf_with_some_nulls.dtypes))

    def test_create_dataframe_from_array_of_long(self):
        import array
        data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))]
        df = self.spark.createDataFrame(data)
        self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_create_dataframe_from_pandas_with_timestamp(self):
        import pandas as pd
        from datetime import datetime
        pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
                            "d": [pd.Timestamp.now().date()]}, columns=["d", "ts"])

        # types are inferred correctly without specifying a schema
        df = self.spark.createDataFrame(pdf)
        self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
        self.assertTrue(isinstance(df.schema['d'].dataType, DateType))

        # an explicit schema string is also accepted for a pandas input
        df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp")
        self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
        self.assertTrue(isinstance(df.schema['d'].dataType, DateType))

    @unittest.skipIf(have_pandas, "Required Pandas was found.")
    def test_create_dataframe_required_pandas_not_found(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    ImportError,
                    "(Pandas >= .* must be installed|No module named '?pandas'?)"):
                import pandas as pd
                from datetime import datetime
                pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
                                    "d": [pd.Timestamp.now().date()]})
                self.spark.createDataFrame(pdf)

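    # Regression test: timestamps near a daylight-saving-time transition must
    # round-trip through createDataFrame/toPandas unchanged.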
    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_create_dataframe_from_pandas_with_dst(self):
        import pandas as pd
        from pandas.util.testing import assert_frame_equal
        from datetime import datetime

        pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]})

        df = self.spark.createDataFrame(pdf)
        assert_frame_equal(pdf, df.toPandas())

        orig_env_tz = os.environ.get('TZ', None)
        try:
            tz = 'America/Los_Angeles'
            os.environ['TZ'] = tz
            time.tzset()
            with self.sql_conf({'spark.sql.session.timeZone': tz}):
                df = self.spark.createDataFrame(pdf)
                assert_frame_equal(pdf, df.toPandas())
        finally:
            del os.environ['TZ']
            if orig_env_tz is not None:
                os.environ['TZ'] = orig_env_tz
            time.tzset()

    def test_repr_behaviors(self):
        import re
        pattern = re.compile(r'^ *\|', re.MULTILINE)
        df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))

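        # eager evaluation enabled: __repr__ renders an ASCII table
        # (_repr_html_ is exercised separately below)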
        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
            expected1 = """+-----+-----+
                ||  key|value|
                |+-----+-----+
                ||    1|    1|
                ||22222|22222|
                |+-----+-----+
                |"""
            self.assertEqual(re.sub(pattern, '', expected1), df.__repr__())
            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
                expected2 = """+---+-----+
                    ||key|value|
                    |+---+-----+
                    ||  1|    1|
                    ||222|  222|
                    |+---+-----+
                    |"""
                self.assertEqual(re.sub(pattern, '', expected2), df.__repr__())
                with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
                    expected3 = """+---+-----+
                        ||key|value|
                        |+---+-----+
                        ||  1|    1|
                        |+---+-----+
                        |only showing top 1 row
                        |"""
                    self.assertEqual(re.sub(pattern, '', expected3), df.__repr__())

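        # eager evaluation enabled: _repr_html_ renders an HTML table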
        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
            expected1 = """<table border='1'>
                |<tr><th>key</th><th>value</th></tr>
                |<tr><td>1</td><td>1</td></tr>
                |<tr><td>22222</td><td>22222</td></tr>
                |</table>
                |"""
            self.assertEqual(re.sub(pattern, '', expected1), df._repr_html_())
            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
                expected2 = """<table border='1'>
                    |<tr><th>key</th><th>value</th></tr>
                    |<tr><td>1</td><td>1</td></tr>
                    |<tr><td>222</td><td>222</td></tr>
                    |</table>
                    |"""
                self.assertEqual(re.sub(pattern, '', expected2), df._repr_html_())
                with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
                    expected3 = """<table border='1'>
                        |<tr><th>key</th><th>value</th></tr>
                        |<tr><td>1</td><td>1</td></tr>
                        |</table>
                        |only showing top 1 row
                        |"""
                    self.assertEqual(re.sub(pattern, '', expected3), df._repr_html_())

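        # eager evaluation disabled: _repr_html_ returns None and __repr__
        # falls back to the schema-only string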
        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}):
            expected = "DataFrame[key: bigint, value: string]"
            self.assertEqual(None, df._repr_html_())
            self.assertEqual(expected, df.__repr__())
            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
                self.assertEqual(None, df._repr_html_())
                self.assertEqual(expected, df.__repr__())
                with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
                    self.assertEqual(None, df._repr_html_())
                    self.assertEqual(expected, df.__repr__())

    def test_to_local_iterator(self):
        df = self.spark.range(8, numPartitions=4)
        expected = df.collect()
        it = df.toLocalIterator()
        self.assertEqual(expected, list(it))

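        # test a DataFrame with an empty partition (more partitions than rows)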
        df = self.spark.range(3, numPartitions=4)
        it = df.toLocalIterator()
        expected = df.collect()
        self.assertEqual(expected, list(it))

    def test_to_local_iterator_prefetch(self):
        df = self.spark.range(8, numPartitions=4)
        expected = df.collect()
        it = df.toLocalIterator(prefetchPartitions=True)
        self.assertEqual(expected, list(it))

    def test_to_local_iterator_not_fully_consumed(self):
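        # SPARK-23961: a partially consumed local iterator must clean up gracefully.
        # Create a DataFrame large enough that writes to the local socket will
        # eventually block.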
        df = self.spark.range(1 << 20, numPartitions=2)
        it = df.toLocalIterator()
        self.assertEqual(df.take(1)[0], next(it))
        with QuietTest(self.sc):
            it = None  # drop the only reference so the iterator is cleaned up early

            # a fresh, partially consumed iterator should still work afterwards
            result = []
            for i, row in enumerate(df.toLocalIterator()):
                result.append(row)
                if i == 7:
                    break
            self.assertEqual(df.take(8), result)


class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
    # These tests are kept in a separate class because they use
    # 'spark.sql.queryExecutionListeners', which is a static and immutable
    # configuration that cannot be set or unset at runtime, e.g. via `spark.conf`.

    @classmethod
    def setUpClass(cls):
        import glob
        from pyspark.find_spark_home import _find_spark_home

        SPARK_HOME = _find_spark_home()
        filename_pattern = (
            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
            "TestQueryExecutionListener.class")
        cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern)))

        if cls.has_listener:
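            # 'spark.sql.queryExecutionListeners' is a static configuration, so it
            # must be set when the session is first created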
            cls.spark = SparkSession.builder \
                .master("local[4]") \
                .appName(cls.__name__) \
                .config(
                    "spark.sql.queryExecutionListeners",
                    "org.apache.spark.sql.TestQueryExecutionListener") \
                .getOrCreate()

    def setUp(self):
        if not self.has_listener:
            raise self.skipTest(
                "'org.apache.spark.sql.TestQueryExecutionListener' is not "
                "available. Will skip the related tests.")

    @classmethod
    def tearDownClass(cls):
        if hasattr(cls, "spark"):
            cls.spark.stop()

    def tearDown(self):
        self.spark._jvm.OnSuccessCall.clear()

    def test_query_execution_listener_on_collect(self):
        self.assertFalse(
            self.spark._jvm.OnSuccessCall.isCalled(),
            "The callback from the query execution listener should not be called before 'collect'")
        self.spark.sql("SELECT * FROM range(1)").collect()
        self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty(10000)
        self.assertTrue(
            self.spark._jvm.OnSuccessCall.isCalled(),
            "The callback from the query execution listener should be called after 'collect'")

    @unittest.skipIf(
        not have_pandas or not have_pyarrow,
        pandas_requirement_message or pyarrow_requirement_message)
    def test_query_execution_listener_on_collect_with_arrow(self):
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
            self.assertFalse(
                self.spark._jvm.OnSuccessCall.isCalled(),
                "The callback from the query execution listener should not be "
                "called before 'toPandas'")
            self.spark.sql("SELECT * FROM range(1)").toPandas()
            self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty(10000)
            self.assertTrue(
                self.spark._jvm.OnSuccessCall.isCalled(),
                "The callback from the query execution listener should be called after 'toPandas'")


if __name__ == "__main__":
    from pyspark.sql.tests.test_dataframe import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)