#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import shutil
import tempfile

from pyspark.sql.types import StructType, StructField, StringType
from pyspark.testing.sqlutils import ReusedSQLTestCase


class ReadwriterTests(ReusedSQLTestCase):
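    """Round-trip tests for DataFrameWriter (``df.write``) and DataFrameReader
    (``spark.read``): save/load, bucketed writes, and insertInto behaviour.
    """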

    def test_save_and_load(self):
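        """JSON round trip through the path-based and generic save/load APIs,
        including unknown options (which should be ignored) and the
        spark.sql.sources.default fallback."""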
        df = self.df
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        df.write.json(tmpPath)
        actual = self.spark.read.json(tmpPath)
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

        schema = StructType([StructField("value", StringType(), True)])
        actual = self.spark.read.json(tmpPath, schema)
        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))

        df.write.json(tmpPath, "overwrite")
        actual = self.spark.read.json(tmpPath)
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

        # Unknown options should be ignored by both save() and load().
        df.write.save(format="json", mode="overwrite", path=tmpPath,
                      noUse="this option will not be used in save.")
        actual = self.spark.read.load(format="json", path=tmpPath,
                                      noUse="this option will not be used in load.")
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

        # load() without an explicit format falls back to spark.sql.sources.default.
        defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
                                                    "org.apache.spark.sql.parquet")
        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
        actual = self.spark.read.load(path=tmpPath)
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

        # Passing None as an option value should be accepted.
        csvpath = os.path.join(tempfile.mkdtemp(), 'data')
        df.write.option('quote', None).format('csv').save(csvpath)

        shutil.rmtree(tmpPath)

    def test_save_and_load_builder(self):
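        """Same coverage as test_save_and_load, but through the builder-style
        mode()/options()/format() chain instead of keyword arguments."""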
        df = self.df
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        df.write.json(tmpPath)
        actual = self.spark.read.json(tmpPath)
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

        schema = StructType([StructField("value", StringType(), True)])
        actual = self.spark.read.json(tmpPath, schema)
        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))

        df.write.mode("overwrite").json(tmpPath)
        actual = self.spark.read.json(tmpPath)
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

        # Unknown options should be ignored whether set via options() or option().
        df.write.mode("overwrite").options(noUse="this option will not be used in save.")\
            .option("noUse", "this option will not be used in save.")\
            .format("json").save(path=tmpPath)
        actual = \
            self.spark.read.format("json")\
            .load(path=tmpPath, noUse="this option will not be used in load.")
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

        # load() without an explicit format falls back to spark.sql.sources.default.
        defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
                                                    "org.apache.spark.sql.parquet")
        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
        actual = self.spark.read.load(path=tmpPath)
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

        shutil.rmtree(tmpPath)

    def test_bucketed_write(self):
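        """bucketBy()/sortBy() with saveAsTable(): the bucketing columns
        recorded in the catalog must match what was requested, and the table
        contents must round-trip."""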
        data = [
            (1, "foo", 3.0), (2, "foo", 5.0),
            (3, "bar", -1.0), (4, "bar", 6.0),
        ]
        df = self.spark.createDataFrame(data, ["x", "y", "z"])

        def count_bucketed_cols(names, table="pyspark_bucket"):
            """Given a sequence of column names and a table name,
            query the catalog and return the number of columns which
            are used for bucketing.
            """
            cols = self.spark.catalog.listColumns(table)
            num = len([c for c in cols if c.name in names and c.isBucket])
            return num

        with self.table("pyspark_bucket"):
            # Write with a single bucketing column.
            df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x"]), 1)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Write with two bucketing columns.
            df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Write with bucketing and sorting.
            df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x"]), 1)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Write with the bucketing columns given as a list.
            df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Write with bucketing and the sort columns given as a list.
            (df.write.bucketBy(2, "x")
                .sortBy(["y", "z"])
                .mode("overwrite").saveAsTable("pyspark_bucket"))
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Write with bucketing and multiple sort columns as varargs.
            (df.write.bucketBy(2, "x")
                .sortBy("y", "z")
                .mode("overwrite").saveAsTable("pyspark_bucket"))
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
    def test_insert_into(self):
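        """insertInto() appends by default; an explicit overwrite argument
        takes precedence over the writer's mode() setting."""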
        df = self.spark.createDataFrame([("a", 1), ("b", 2)], ["C1", "C2"])
        with self.table("test_table"):
            df.write.saveAsTable("test_table")
            self.assertEqual(2, self.spark.sql("select * from test_table").count())

            # insertInto() appends by default.
            df.write.insertInto("test_table")
            self.assertEqual(4, self.spark.sql("select * from test_table").count())

            # mode("overwrite") makes insertInto() overwrite.
            df.write.mode("overwrite").insertInto("test_table")
            self.assertEqual(2, self.spark.sql("select * from test_table").count())

            # So does an explicit overwrite argument ...
            df.write.insertInto("test_table", True)
            self.assertEqual(2, self.spark.sql("select * from test_table").count())

            df.write.insertInto("test_table", False)
            self.assertEqual(4, self.spark.sql("select * from test_table").count())

            # ... and the explicit argument takes precedence over mode().
            df.write.mode("overwrite").insertInto("test_table", False)
            self.assertEqual(6, self.spark.sql("select * from test_table").count())


if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_readwriter import *

    try:
        # Emit JUnit-style XML reports when xmlrunner is available; otherwise
        # fall back to the default text test runner.
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)