Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 import os
0019 import pydoc
0020 import time
0021 import unittest
0022 
0023 from pyspark.sql import SparkSession, Row
0024 from pyspark.sql.types import *
0025 from pyspark.sql.utils import AnalysisException, IllegalArgumentException
0026 from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils, have_pyarrow, have_pandas, \
0027     pandas_requirement_message, pyarrow_requirement_message
0028 from pyspark.testing.utils import QuietTest
0029 
0030 
0031 class DataFrameTests(ReusedSQLTestCase):
0032 
0033     def test_range(self):
0034         self.assertEqual(self.spark.range(1, 1).count(), 0)
0035         self.assertEqual(self.spark.range(1, 0, -1).count(), 1)
0036         self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2)
0037         self.assertEqual(self.spark.range(-2).count(), 0)
0038         self.assertEqual(self.spark.range(3).count(), 3)
0039 
0040     def test_duplicated_column_names(self):
0041         df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
0042         row = df.select('*').first()
0043         self.assertEqual(1, row[0])
0044         self.assertEqual(2, row[1])
0045         self.assertEqual("Row(c=1, c=2)", str(row))
0046         # Cannot access columns
0047         self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
0048         self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
0049         self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())
0050 
0051     def test_freqItems(self):
0052         vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
0053         df = self.sc.parallelize(vals).toDF()
0054         items = df.stat.freqItems(("a", "b"), 0.4).collect()[0]
0055         self.assertTrue(1 in items[0])
0056         self.assertTrue(-2.0 in items[1])
0057 
0058     def test_help_command(self):
0059         # Regression test for SPARK-5464
0060         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
0061         df = self.spark.read.json(rdd)
0062         # render_doc() reproduces the help() exception without printing output
0063         pydoc.render_doc(df)
0064         pydoc.render_doc(df.foo)
0065         pydoc.render_doc(df.take(1))
0066 
0067     def test_dropna(self):
0068         schema = StructType([
0069             StructField("name", StringType(), True),
0070             StructField("age", IntegerType(), True),
0071             StructField("height", DoubleType(), True)])
0072 
0073         # shouldn't drop a non-null row
0074         self.assertEqual(self.spark.createDataFrame(
0075             [(u'Alice', 50, 80.1)], schema).dropna().count(),
0076             1)
0077 
0078         # dropping rows with a single null value
0079         self.assertEqual(self.spark.createDataFrame(
0080             [(u'Alice', None, 80.1)], schema).dropna().count(),
0081             0)
0082         self.assertEqual(self.spark.createDataFrame(
0083             [(u'Alice', None, 80.1)], schema).dropna(how='any').count(),
0084             0)
0085 
0086         # if how = 'all', only drop rows if all values are null
0087         self.assertEqual(self.spark.createDataFrame(
0088             [(u'Alice', None, 80.1)], schema).dropna(how='all').count(),
0089             1)
0090         self.assertEqual(self.spark.createDataFrame(
0091             [(None, None, None)], schema).dropna(how='all').count(),
0092             0)
0093 
0094         # how and subset
0095         self.assertEqual(self.spark.createDataFrame(
0096             [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
0097             1)
0098         self.assertEqual(self.spark.createDataFrame(
0099             [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
0100             0)
0101 
0102         # threshold
0103         self.assertEqual(self.spark.createDataFrame(
0104             [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(),
0105             1)
0106         self.assertEqual(self.spark.createDataFrame(
0107             [(u'Alice', None, None)], schema).dropna(thresh=2).count(),
0108             0)
0109 
0110         # threshold and subset
0111         self.assertEqual(self.spark.createDataFrame(
0112             [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
0113             1)
0114         self.assertEqual(self.spark.createDataFrame(
0115             [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
0116             0)
0117 
0118         # thresh should take precedence over how
0119         self.assertEqual(self.spark.createDataFrame(
0120             [(u'Alice', 50, None)], schema).dropna(
0121                 how='any', thresh=2, subset=['name', 'age']).count(),
0122             1)
0123 
0124     def test_fillna(self):
0125         schema = StructType([
0126             StructField("name", StringType(), True),
0127             StructField("age", IntegerType(), True),
0128             StructField("height", DoubleType(), True),
0129             StructField("spy", BooleanType(), True)])
0130 
0131         # fillna shouldn't change non-null values
0132         row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], schema).fillna(50).first()
0133         self.assertEqual(row.age, 10)
0134 
0135         # fillna with int
0136         row = self.spark.createDataFrame([(u'Alice', None, None, None)], schema).fillna(50).first()
0137         self.assertEqual(row.age, 50)
0138         self.assertEqual(row.height, 50.0)
0139 
0140         # fillna with double
0141         row = self.spark.createDataFrame(
0142             [(u'Alice', None, None, None)], schema).fillna(50.1).first()
0143         self.assertEqual(row.age, 50)
0144         self.assertEqual(row.height, 50.1)
0145 
0146         # fillna with bool
0147         row = self.spark.createDataFrame(
0148             [(u'Alice', None, None, None)], schema).fillna(True).first()
0149         self.assertEqual(row.age, None)
0150         self.assertEqual(row.spy, True)
0151 
0152         # fillna with string
0153         row = self.spark.createDataFrame([(None, None, None, None)], schema).fillna("hello").first()
0154         self.assertEqual(row.name, u"hello")
0155         self.assertEqual(row.age, None)
0156 
0157         # fillna with subset specified for numeric cols
0158         row = self.spark.createDataFrame(
0159             [(None, None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
0160         self.assertEqual(row.name, None)
0161         self.assertEqual(row.age, 50)
0162         self.assertEqual(row.height, None)
0163         self.assertEqual(row.spy, None)
0164 
0165         # fillna with subset specified for string cols
0166         row = self.spark.createDataFrame(
0167             [(None, None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
0168         self.assertEqual(row.name, "haha")
0169         self.assertEqual(row.age, None)
0170         self.assertEqual(row.height, None)
0171         self.assertEqual(row.spy, None)
0172 
0173         # fillna with subset specified for bool cols
0174         row = self.spark.createDataFrame(
0175             [(None, None, None, None)], schema).fillna(True, subset=['name', 'spy']).first()
0176         self.assertEqual(row.name, None)
0177         self.assertEqual(row.age, None)
0178         self.assertEqual(row.height, None)
0179         self.assertEqual(row.spy, True)
0180 
0181         # fillna with dictionary for boolean types
0182         row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first()
0183         self.assertEqual(row.a, True)
0184 
0185     def test_repartitionByRange_dataframe(self):
0186         schema = StructType([
0187             StructField("name", StringType(), True),
0188             StructField("age", IntegerType(), True),
0189             StructField("height", DoubleType(), True)])
0190 
0191         df1 = self.spark.createDataFrame(
0192             [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema)
0193         df2 = self.spark.createDataFrame(
0194             [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema)
0195 
0196         # test repartitionByRange(numPartitions, *cols)
0197         df3 = df1.repartitionByRange(2, "name", "age")
0198         self.assertEqual(df3.rdd.getNumPartitions(), 2)
0199         self.assertEqual(df3.rdd.first(), df2.rdd.first())
0200         self.assertEqual(df3.rdd.take(3), df2.rdd.take(3))
0201 
0202         # test repartitionByRange(numPartitions, *cols)
0203         df4 = df1.repartitionByRange(3, "name", "age")
0204         self.assertEqual(df4.rdd.getNumPartitions(), 3)
0205         self.assertEqual(df4.rdd.first(), df2.rdd.first())
0206         self.assertEqual(df4.rdd.take(3), df2.rdd.take(3))
0207 
0208         # test repartitionByRange(*cols)
0209         df5 = df1.repartitionByRange("name", "age")
0210         self.assertEqual(df5.rdd.first(), df2.rdd.first())
0211         self.assertEqual(df5.rdd.take(3), df2.rdd.take(3))
0212 
    def test_replace(self):
        """Exercise DataFrame.replace() with scalar, list, tuple and dict arguments,
        with and without ``subset``, and check the invalid-argument errors."""
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True)])

        # replace with int: the replacement applies to both the int and double columns
        row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
        self.assertEqual(row.age, 20)
        self.assertEqual(row.height, 20.0)

        # replace with double: 80.0 also matches the int column, where 82.1 lands as 82
        row = self.spark.createDataFrame(
            [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first()
        self.assertEqual(row.age, 82)
        self.assertEqual(row.height, 82.1)

        # replace with string
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first()
        self.assertEqual(row.name, u"Ann")
        self.assertEqual(row.age, 10)

        # replace with subset specified by a string of a column name w/ actual change
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first()
        self.assertEqual(row.age, 20)

        # replace with subset specified by a string of a column name w/o actual change
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first()
        self.assertEqual(row.age, 10)

        # replace with subset specified with one column replaced, another column not in subset
        # stays unchanged.
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first()
        self.assertEqual(row.name, u'Alice')
        self.assertEqual(row.age, 20)
        self.assertEqual(row.height, 10.0)

        # replace with subset specified but no column will be replaced
        row = self.spark.createDataFrame(
            [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first()
        self.assertEqual(row.name, u'Alice')
        self.assertEqual(row.age, 10)
        self.assertEqual(row.height, None)

        # replace with lists
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first()
        self.assertTupleEqual(row, (u'Ann', 10, 80.1))

        # replace with dict
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first()
        self.assertTupleEqual(row, (u'Alice', 11, 80.1))

        # test backward compatibility with dummy value: when to_replace is a dict,
        # the value argument is ignored (Alice still becomes Bob, not 1)
        dummy_value = 1
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first()
        self.assertTupleEqual(row, (u'Bob', 10, 80.1))

        # test dict with mixed numerics
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first()
        self.assertTupleEqual(row, (u'Alice', -10, 90.5))

        # replace with tuples
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first()
        self.assertTupleEqual(row, (u'Bob', 10, 80.1))

        # replace multiple columns
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first()
        self.assertTupleEqual(row, (u'Alice', 20, 90.0))

        # test for mixed numerics
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first()
        self.assertTupleEqual(row, (u'Alice', 20, 90.5))

        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first()
        self.assertTupleEqual(row, (u'Alice', 20, 90.5))

        # replace with boolean (on two boolean expression columns)
        row = (self
               .spark.createDataFrame([(u'Alice', 10, 80.0)], schema)
               .selectExpr("name = 'Bob'", 'age <= 15')
               .replace(False, True).first())
        self.assertTupleEqual(row, (True, True))

        # replace string with None and then drop None rows
        # NOTE(review): unlike the other cases, `row` is a DataFrame here, not a Row.
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna()
        self.assertEqual(row.count(), 0)

        # replace with number and None
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first()
        self.assertTupleEqual(row, (u'Alice', 20, None))

        # should fail if subset is not list, tuple or None
        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first()

        # should fail if to_replace and value have different length
        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first()

        # should fail when given a value of an unsupported type (datetime)
        with self.assertRaises(ValueError):
            from datetime import datetime
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first()

        # should fail if provided mixed type replacements
        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first()

        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first()

        # should fail if value is omitted while to_replace is not a dict.
        # NOTE: assertRaisesRegexp is the Python 2.7-compatible spelling; on Python 3 it
        # is a deprecated alias of assertRaisesRegex and was removed in Python 3.12.
        with self.assertRaisesRegexp(
                TypeError,
                'value argument is required when to_replace is not a dictionary.'):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first()
0348 
0349     def test_with_column_with_existing_name(self):
0350         keys = self.df.withColumn("key", self.df.key).select("key").collect()
0351         self.assertEqual([r.key for r in keys], list(range(100)))
0352 
0353     # regression test for SPARK-10417
0354     def test_column_iterator(self):
0355 
0356         def foo():
0357             for x in self.df.key:
0358                 break
0359 
0360         self.assertRaises(TypeError, foo)
0361 
0362     def test_generic_hints(self):
0363         from pyspark.sql import DataFrame
0364 
0365         df1 = self.spark.range(10e10).toDF("id")
0366         df2 = self.spark.range(10e10).toDF("id")
0367 
0368         self.assertIsInstance(df1.hint("broadcast"), DataFrame)
0369         self.assertIsInstance(df1.hint("broadcast", []), DataFrame)
0370 
0371         # Dummy rules
0372         self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame)
0373         self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame)
0374 
0375         plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
0376         self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))
0377 
0378     # add tests for SPARK-23647 (test more types for hint)
0379     def test_extended_hint_types(self):
0380         from pyspark.sql import DataFrame
0381 
0382         df = self.spark.range(10e10).toDF("id")
0383         such_a_nice_list = ["itworks1", "itworks2", "itworks3"]
0384         hinted_df = df.hint("my awesome hint", 1.2345, "what", such_a_nice_list)
0385         logical_plan = hinted_df._jdf.queryExecution().logical()
0386 
0387         self.assertEqual(1, logical_plan.toString().count("1.2345"))
0388         self.assertEqual(1, logical_plan.toString().count("what"))
0389         self.assertEqual(3, logical_plan.toString().count("itworks"))
0390 
0391     def test_sample(self):
0392         self.assertRaisesRegexp(
0393             TypeError,
0394             "should be a bool, float and number",
0395             lambda: self.spark.range(1).sample())
0396 
0397         self.assertRaises(
0398             TypeError,
0399             lambda: self.spark.range(1).sample("a"))
0400 
0401         self.assertRaises(
0402             TypeError,
0403             lambda: self.spark.range(1).sample(seed="abc"))
0404 
0405         self.assertRaises(
0406             IllegalArgumentException,
0407             lambda: self.spark.range(1).sample(-1.0))
0408 
0409     def test_toDF_with_schema_string(self):
0410         data = [Row(key=i, value=str(i)) for i in range(100)]
0411         rdd = self.sc.parallelize(data, 5)
0412 
0413         df = rdd.toDF("key: int, value: string")
0414         self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
0415         self.assertEqual(df.collect(), data)
0416 
0417         # different but compatible field types can be used.
0418         df = rdd.toDF("key: string, value: string")
0419         self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
0420         self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])
0421 
0422         # field names can differ.
0423         df = rdd.toDF(" a: int, b: string ")
0424         self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
0425         self.assertEqual(df.collect(), data)
0426 
0427         # number of fields must match.
0428         self.assertRaisesRegexp(Exception, "Length of object",
0429                                 lambda: rdd.toDF("key: int").collect())
0430 
0431         # field types mismatch will cause exception at runtime.
0432         self.assertRaisesRegexp(Exception, "FloatType can not accept",
0433                                 lambda: rdd.toDF("key: float, value: string").collect())
0434 
0435         # flat schema values will be wrapped into row.
0436         df = rdd.map(lambda row: row.key).toDF("int")
0437         self.assertEqual(df.schema.simpleString(), "struct<value:int>")
0438         self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
0439 
0440         # users can use DataType directly instead of data type string.
0441         df = rdd.map(lambda row: row.key).toDF(IntegerType())
0442         self.assertEqual(df.schema.simpleString(), "struct<value:int>")
0443         self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
0444 
0445     def test_join_without_on(self):
0446         df1 = self.spark.range(1).toDF("a")
0447         df2 = self.spark.range(1).toDF("b")
0448 
0449         with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
0450             self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())
0451 
0452         with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
0453             actual = df1.join(df2, how="inner").collect()
0454             expected = [Row(a=0, b=0)]
0455             self.assertEqual(actual, expected)
0456 
0457     # Regression test for invalid join methods when on is None, Spark-14761
0458     def test_invalid_join_method(self):
0459         df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"])
0460         df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"])
0461         self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type"))
0462 
0463     # Cartesian products require cross join syntax
0464     def test_require_cross(self):
0465 
0466         df1 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
0467         df2 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
0468 
0469         with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
0470             # joins without conditions require cross join syntax
0471             self.assertRaises(AnalysisException, lambda: df1.join(df2).collect())
0472 
0473             # works with crossJoin
0474             self.assertEqual(1, df1.crossJoin(df2).count())
0475 
0476     def test_cache(self):
0477         spark = self.spark
0478         with self.tempView("tab1", "tab2"):
0479             spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab1")
0480             spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab2")
0481             self.assertFalse(spark.catalog.isCached("tab1"))
0482             self.assertFalse(spark.catalog.isCached("tab2"))
0483             spark.catalog.cacheTable("tab1")
0484             self.assertTrue(spark.catalog.isCached("tab1"))
0485             self.assertFalse(spark.catalog.isCached("tab2"))
0486             spark.catalog.cacheTable("tab2")
0487             spark.catalog.uncacheTable("tab1")
0488             self.assertFalse(spark.catalog.isCached("tab1"))
0489             self.assertTrue(spark.catalog.isCached("tab2"))
0490             spark.catalog.clearCache()
0491             self.assertFalse(spark.catalog.isCached("tab1"))
0492             self.assertFalse(spark.catalog.isCached("tab2"))
0493             self.assertRaisesRegexp(
0494                 AnalysisException,
0495                 "does_not_exist",
0496                 lambda: spark.catalog.isCached("does_not_exist"))
0497             self.assertRaisesRegexp(
0498                 AnalysisException,
0499                 "does_not_exist",
0500                 lambda: spark.catalog.cacheTable("does_not_exist"))
0501             self.assertRaisesRegexp(
0502                 AnalysisException,
0503                 "does_not_exist",
0504                 lambda: spark.catalog.uncacheTable("does_not_exist"))
0505 
0506     def _to_pandas(self):
0507         from datetime import datetime, date
0508         schema = StructType().add("a", IntegerType()).add("b", StringType())\
0509                              .add("c", BooleanType()).add("d", FloatType())\
0510                              .add("dt", DateType()).add("ts", TimestampType())
0511         data = [
0512             (1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
0513             (2, "foo", True, 5.0, None, None),
0514             (3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)),
0515             (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)),
0516         ]
0517         df = self.spark.createDataFrame(data, schema)
0518         return df.toPandas()
0519 
0520     @unittest.skipIf(not have_pandas, pandas_requirement_message)
0521     def test_to_pandas(self):
0522         import numpy as np
0523         pdf = self._to_pandas()
0524         types = pdf.dtypes
0525         self.assertEquals(types[0], np.int32)
0526         self.assertEquals(types[1], np.object)
0527         self.assertEquals(types[2], np.bool)
0528         self.assertEquals(types[3], np.float32)
0529         self.assertEquals(types[4], np.object)  # datetime.date
0530         self.assertEquals(types[5], 'datetime64[ns]')
0531 
0532     @unittest.skipIf(not have_pandas, pandas_requirement_message)
0533     def test_to_pandas_with_duplicated_column_names(self):
0534         import numpy as np
0535 
0536         sql = "select 1 v, 1 v"
0537         for arrowEnabled in [False, True]:
0538             with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrowEnabled}):
0539                 df = self.spark.sql(sql)
0540                 pdf = df.toPandas()
0541                 types = pdf.dtypes
0542                 self.assertEquals(types.iloc[0], np.int32)
0543                 self.assertEquals(types.iloc[1], np.int32)
0544 
0545     @unittest.skipIf(not have_pandas, pandas_requirement_message)
0546     def test_to_pandas_on_cross_join(self):
0547         import numpy as np
0548 
0549         sql = """
0550         select t1.*, t2.* from (
0551           select explode(sequence(1, 3)) v
0552         ) t1 left join (
0553           select explode(sequence(1, 3)) v
0554         ) t2
0555         """
0556         for arrowEnabled in [False, True]:
0557             with self.sql_conf({"spark.sql.crossJoin.enabled": True,
0558                                 "spark.sql.execution.arrow.pyspark.enabled": arrowEnabled}):
0559                 df = self.spark.sql(sql)
0560                 pdf = df.toPandas()
0561                 types = pdf.dtypes
0562                 self.assertEquals(types.iloc[0], np.int32)
0563                 self.assertEquals(types.iloc[1], np.int32)
0564 
0565     @unittest.skipIf(have_pandas, "Required Pandas was found.")
0566     def test_to_pandas_required_pandas_not_found(self):
0567         with QuietTest(self.sc):
0568             with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
0569                 self._to_pandas()
0570 
0571     @unittest.skipIf(not have_pandas, pandas_requirement_message)
0572     def test_to_pandas_avoid_astype(self):
0573         import numpy as np
0574         schema = StructType().add("a", IntegerType()).add("b", StringType())\
0575                              .add("c", IntegerType())
0576         data = [(1, "foo", 16777220), (None, "bar", None)]
0577         df = self.spark.createDataFrame(data, schema)
0578         types = df.toPandas().dtypes
0579         self.assertEquals(types[0], np.float64)  # doesn't convert to np.int32 due to NaN value.
0580         self.assertEquals(types[1], np.object)
0581         self.assertEquals(types[2], np.float64)
0582 
0583     @unittest.skipIf(not have_pandas, pandas_requirement_message)
0584     def test_to_pandas_from_empty_dataframe(self):
0585         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
0586             # SPARK-29188 test that toPandas() on an empty dataframe has the correct dtypes
0587             import numpy as np
0588             sql = """
0589             SELECT CAST(1 AS TINYINT) AS tinyint,
0590             CAST(1 AS SMALLINT) AS smallint,
0591             CAST(1 AS INT) AS int,
0592             CAST(1 AS BIGINT) AS bigint,
0593             CAST(0 AS FLOAT) AS float,
0594             CAST(0 AS DOUBLE) AS double,
0595             CAST(1 AS BOOLEAN) AS boolean,
0596             CAST('foo' AS STRING) AS string,
0597             CAST('2019-01-01' AS TIMESTAMP) AS timestamp
0598             """
0599             dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes
0600             dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes
0601             self.assertTrue(np.all(dtypes_when_empty_df == dtypes_when_nonempty_df))
0602 
0603     @unittest.skipIf(not have_pandas, pandas_requirement_message)
0604     def test_to_pandas_from_null_dataframe(self):
0605         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
0606             # SPARK-29188 test that toPandas() on a dataframe with only nulls has correct dtypes
0607             import numpy as np
0608             sql = """
0609             SELECT CAST(NULL AS TINYINT) AS tinyint,
0610             CAST(NULL AS SMALLINT) AS smallint,
0611             CAST(NULL AS INT) AS int,
0612             CAST(NULL AS BIGINT) AS bigint,
0613             CAST(NULL AS FLOAT) AS float,
0614             CAST(NULL AS DOUBLE) AS double,
0615             CAST(NULL AS BOOLEAN) AS boolean,
0616             CAST(NULL AS STRING) AS string,
0617             CAST(NULL AS TIMESTAMP) AS timestamp
0618             """
0619             pdf = self.spark.sql(sql).toPandas()
0620             types = pdf.dtypes
0621             self.assertEqual(types[0], np.float64)
0622             self.assertEqual(types[1], np.float64)
0623             self.assertEqual(types[2], np.float64)
0624             self.assertEqual(types[3], np.float64)
0625             self.assertEqual(types[4], np.float32)
0626             self.assertEqual(types[5], np.float64)
0627             self.assertEqual(types[6], np.object)
0628             self.assertEqual(types[7], np.object)
0629             self.assertTrue(np.can_cast(np.datetime64, types[8]))
0630 
0631     @unittest.skipIf(not have_pandas, pandas_requirement_message)
0632     def test_to_pandas_from_mixed_dataframe(self):
0633         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
0634             # SPARK-29188 test that toPandas() on a dataframe with some nulls has correct dtypes
0635             import numpy as np
0636             sql = """
0637             SELECT CAST(col1 AS TINYINT) AS tinyint,
0638             CAST(col2 AS SMALLINT) AS smallint,
0639             CAST(col3 AS INT) AS int,
0640             CAST(col4 AS BIGINT) AS bigint,
0641             CAST(col5 AS FLOAT) AS float,
0642             CAST(col6 AS DOUBLE) AS double,
0643             CAST(col7 AS BOOLEAN) AS boolean,
0644             CAST(col8 AS STRING) AS string,
0645             CAST(col9 AS TIMESTAMP) AS timestamp
0646             FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1),
0647                         (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)
0648             """
0649             pdf_with_some_nulls = self.spark.sql(sql).toPandas()
0650             pdf_with_only_nulls = self.spark.sql(sql).filter('tinyint is null').toPandas()
0651             self.assertTrue(np.all(pdf_with_only_nulls.dtypes == pdf_with_some_nulls.dtypes))
0652 
0653     def test_create_dataframe_from_array_of_long(self):
0654         import array
0655         data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))]
0656         df = self.spark.createDataFrame(data)
0657         self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))
0658 
0659     @unittest.skipIf(not have_pandas, pandas_requirement_message)
0660     def test_create_dataframe_from_pandas_with_timestamp(self):
0661         import pandas as pd
0662         from datetime import datetime
0663         pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
0664                             "d": [pd.Timestamp.now().date()]}, columns=["d", "ts"])
0665         # test types are inferred correctly without specifying schema
0666         df = self.spark.createDataFrame(pdf)
0667         self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
0668         self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
0669         # test with schema will accept pdf as input
0670         df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp")
0671         self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
0672         self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
0673 
0674     @unittest.skipIf(have_pandas, "Required Pandas was found.")
0675     def test_create_dataframe_required_pandas_not_found(self):
0676         with QuietTest(self.sc):
0677             with self.assertRaisesRegexp(
0678                     ImportError,
0679                     "(Pandas >= .* must be installed|No module named '?pandas'?)"):
0680                 import pandas as pd
0681                 from datetime import datetime
0682                 pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
0683                                     "d": [pd.Timestamp.now().date()]})
0684                 self.spark.createDataFrame(pdf)
0685 
0686     # Regression test for SPARK-23360
0687     @unittest.skipIf(not have_pandas, pandas_requirement_message)
0688     def test_create_dataframe_from_pandas_with_dst(self):
0689         import pandas as pd
0690         from pandas.util.testing import assert_frame_equal
0691         from datetime import datetime
0692 
0693         pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]})
0694 
0695         df = self.spark.createDataFrame(pdf)
0696         assert_frame_equal(pdf, df.toPandas())
0697 
0698         orig_env_tz = os.environ.get('TZ', None)
0699         try:
0700             tz = 'America/Los_Angeles'
0701             os.environ['TZ'] = tz
0702             time.tzset()
0703             with self.sql_conf({'spark.sql.session.timeZone': tz}):
0704                 df = self.spark.createDataFrame(pdf)
0705                 assert_frame_equal(pdf, df.toPandas())
0706         finally:
0707             del os.environ['TZ']
0708             if orig_env_tz is not None:
0709                 os.environ['TZ'] = orig_env_tz
0710             time.tzset()
0711 
0712     def test_repr_behaviors(self):
0713         import re
0714         pattern = re.compile(r'^ *\|', re.MULTILINE)
0715         df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))
0716 
0717         # test when eager evaluation is enabled and _repr_html_ will not be called
0718         with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
0719             expected1 = """+-----+-----+
0720                 ||  key|value|
0721                 |+-----+-----+
0722                 ||    1|    1|
0723                 ||22222|22222|
0724                 |+-----+-----+
0725                 |"""
0726             self.assertEquals(re.sub(pattern, '', expected1), df.__repr__())
0727             with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
0728                 expected2 = """+---+-----+
0729                 ||key|value|
0730                 |+---+-----+
0731                 ||  1|    1|
0732                 ||222|  222|
0733                 |+---+-----+
0734                 |"""
0735                 self.assertEquals(re.sub(pattern, '', expected2), df.__repr__())
0736                 with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
0737                     expected3 = """+---+-----+
0738                     ||key|value|
0739                     |+---+-----+
0740                     ||  1|    1|
0741                     |+---+-----+
0742                     |only showing top 1 row
0743                     |"""
0744                     self.assertEquals(re.sub(pattern, '', expected3), df.__repr__())
0745 
0746         # test when eager evaluation is enabled and _repr_html_ will be called
0747         with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
0748             expected1 = """<table border='1'>
0749                 |<tr><th>key</th><th>value</th></tr>
0750                 |<tr><td>1</td><td>1</td></tr>
0751                 |<tr><td>22222</td><td>22222</td></tr>
0752                 |</table>
0753                 |"""
0754             self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_())
0755             with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
0756                 expected2 = """<table border='1'>
0757                     |<tr><th>key</th><th>value</th></tr>
0758                     |<tr><td>1</td><td>1</td></tr>
0759                     |<tr><td>222</td><td>222</td></tr>
0760                     |</table>
0761                     |"""
0762                 self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_())
0763                 with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
0764                     expected3 = """<table border='1'>
0765                         |<tr><th>key</th><th>value</th></tr>
0766                         |<tr><td>1</td><td>1</td></tr>
0767                         |</table>
0768                         |only showing top 1 row
0769                         |"""
0770                     self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_())
0771 
0772         # test when eager evaluation is disabled and _repr_html_ will be called
0773         with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}):
0774             expected = "DataFrame[key: bigint, value: string]"
0775             self.assertEquals(None, df._repr_html_())
0776             self.assertEquals(expected, df.__repr__())
0777             with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
0778                 self.assertEquals(None, df._repr_html_())
0779                 self.assertEquals(expected, df.__repr__())
0780                 with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
0781                     self.assertEquals(None, df._repr_html_())
0782                     self.assertEquals(expected, df.__repr__())
0783 
0784     def test_to_local_iterator(self):
0785         df = self.spark.range(8, numPartitions=4)
0786         expected = df.collect()
0787         it = df.toLocalIterator()
0788         self.assertEqual(expected, list(it))
0789 
0790         # Test DataFrame with empty partition
0791         df = self.spark.range(3, numPartitions=4)
0792         it = df.toLocalIterator()
0793         expected = df.collect()
0794         self.assertEqual(expected, list(it))
0795 
0796     def test_to_local_iterator_prefetch(self):
0797         df = self.spark.range(8, numPartitions=4)
0798         expected = df.collect()
0799         it = df.toLocalIterator(prefetchPartitions=True)
0800         self.assertEqual(expected, list(it))
0801 
    def test_to_local_iterator_not_fully_consumed(self):
        # SPARK-23961: toLocalIterator throws exception when not fully consumed
        # Create a DataFrame large enough so that write to socket will eventually block
        df = self.spark.range(1 << 20, numPartitions=2)
        it = df.toLocalIterator()
        # Read just the first row; the rest of this iterator is deliberately left unread.
        self.assertEqual(df.take(1)[0], next(it))
        # NOTE(review): QuietTest presumably silences log output emitted while the
        # abandoned iterator is torn down — confirm.
        with QuietTest(self.sc):
            it = None  # remove iterator from scope, socket is closed when cleaned up
            # Make sure normal df operations still work
            # A fresh local iterator is also stopped early (after 8 rows); its rows
            # must still match take(8).
            result = []
            for i, row in enumerate(df.toLocalIterator()):
                result.append(row)
                if i == 7:
                    break
            self.assertEqual(df.take(8), result)
0817 
0818 
class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
    # These tests are kept separate because they rely on 'spark.sql.queryExecutionListeners',
    # which is a static, immutable configuration: it can't be set or unset at runtime (for
    # example via `spark.conf`), so this class builds its own SparkSession.

    @classmethod
    def setUpClass(cls):
        import glob
        from pyspark.find_spark_home import _find_spark_home

        SPARK_HOME = _find_spark_home()
        # Enable these tests only when the compiled TestQueryExecutionListener test
        # class is present under SPARK_HOME.
        filename_pattern = (
            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
            "TestQueryExecutionListener.class")
        cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern)))

        if cls.has_listener:
            # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
            cls.spark = SparkSession.builder \
                .master("local[4]") \
                .appName(cls.__name__) \
                .config(
                    "spark.sql.queryExecutionListeners",
                    "org.apache.spark.sql.TestQueryExecutionListener") \
                .getOrCreate()

    def setUp(self):
        # skipTest() raises unittest.SkipTest itself, so no explicit `raise` is needed.
        if not self.has_listener:
            self.skipTest(
                "'org.apache.spark.sql.TestQueryExecutionListener' is not "
                "available. Will skip the related tests.")

    @classmethod
    def tearDownClass(cls):
        # cls.spark only exists when setUpClass found the listener class.
        if hasattr(cls, "spark"):
            cls.spark.stop()

    def tearDown(self):
        # Presumably resets the state reported by OnSuccessCall.isCalled() so each
        # test starts from "not called" — confirm against the Scala helper.
        self.spark._jvm.OnSuccessCall.clear()

    def test_query_execution_listener_on_collect(self):
        self.assertFalse(
            self.spark._jvm.OnSuccessCall.isCalled(),
            "The callback from the query execution listener should not be called before 'collect'")
        self.spark.sql("SELECT * FROM range(1)").collect()
        # Wait for the listener bus to drain before checking (10000 is presumably a
        # timeout in milliseconds — confirm).
        self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty(10000)
        self.assertTrue(
            self.spark._jvm.OnSuccessCall.isCalled(),
            "The callback from the query execution listener should be called after 'collect'")

    @unittest.skipIf(
        not have_pandas or not have_pyarrow,
        pandas_requirement_message or pyarrow_requirement_message)
    def test_query_execution_listener_on_collect_with_arrow(self):
        # Same check as above, but through the Arrow-enabled toPandas() path.
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
            self.assertFalse(
                self.spark._jvm.OnSuccessCall.isCalled(),
                "The callback from the query execution listener should not be "
                "called before 'toPandas'")
            self.spark.sql("SELECT * FROM range(1)").toPandas()
            self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty(10000)
            self.assertTrue(
                self.spark._jvm.OnSuccessCall.isCalled(),
                "The callback from the query execution listener should be called after 'toPandas'")
0872         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
0873             self.assertFalse(
0874                 self.spark._jvm.OnSuccessCall.isCalled(),
0875                 "The callback from the query execution listener should not be "
0876                 "called before 'toPandas'")
0877             self.spark.sql("SELECT * FROM range(1)").toPandas()
0878             self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty(10000)
0879             self.assertTrue(
0880                 self.spark._jvm.OnSuccessCall.isCalled(),
0881                 "The callback from the query execution listener should be called after 'toPandas'")
0882 
0883 
if __name__ == "__main__":
    from pyspark.sql.tests.test_dataframe import *

    # Write XML test reports to target/test-reports when xmlrunner is installed,
    # otherwise fall back to unittest's default runner. Only the import is guarded, so
    # an ImportError raised while building the runner is not mistaken for a missing package.
    try:
        import xmlrunner
    except ImportError:
        testRunner = None
    else:
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=testRunner, verbosity=2)
0892     unittest.main(testRunner=testRunner, verbosity=2)