Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 import os
0019 import shutil
0020 import tempfile
0021 
0022 from pyspark.sql.types import *
0023 from pyspark.testing.sqlutils import ReusedSQLTestCase
0024 
0025 
class ReadwriterTests(ReusedSQLTestCase):
    """Round-trip tests for the DataFrame reader/writer APIs.

    ``self.spark`` and ``self.df`` come from ``ReusedSQLTestCase``; ``self.df``
    is presumably a small DataFrame with a string ``value`` column (it is
    selected by that name below) -- TODO confirm against the base class.
    """

    def test_save_and_load(self):
        """Write ``self.df`` as JSON via argument-style APIs and read it back.

        Covers ``json(path)``, ``json(path, schema)``, ``json(path, mode)``,
        ``save``/``load`` with an extra unused option, loading through the
        ``spark.sql.sources.default`` setting, and a CSV write that passes
        ``None`` as an option value.
        """
        df = self.df
        # Reserve a unique path, then remove it: the first write below uses the
        # writer's default save mode, which refuses to write to an existing path.
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        csvDir = tempfile.mkdtemp()
        try:
            df.write.json(tmpPath)
            actual = self.spark.read.json(tmpPath)
            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

            schema = StructType([StructField("value", StringType(), True)])
            actual = self.spark.read.json(tmpPath, schema)
            self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))

            df.write.json(tmpPath, "overwrite")
            actual = self.spark.read.json(tmpPath)
            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

            # An unrecognised option ("noUse") must not break save or load.
            df.write.save(format="json", mode="overwrite", path=tmpPath,
                          noUse="this options will not be used in save.")
            actual = self.spark.read.load(format="json", path=tmpPath,
                                          noUse="this options will not be used in load.")
            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

            defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
                                                        "org.apache.spark.sql.parquet")
            self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
            try:
                # load() without a format falls back to spark.sql.sources.default.
                actual = self.spark.read.load(path=tmpPath)
                self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
            finally:
                # Restore the session-wide setting even when the assertion fails,
                # so other tests sharing this session are unaffected.
                self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

            # Passing None as an option value must be accepted by the writer.
            csvpath = os.path.join(csvDir, 'data')
            df.write.option('quote', None).format('csv').save(csvpath)
        finally:
            # ignore_errors: tmpPath may not exist if the first write failed, and
            # a cleanup error must not mask the original test failure.
            shutil.rmtree(tmpPath, ignore_errors=True)
            shutil.rmtree(csvDir, ignore_errors=True)

    def test_save_and_load_builder(self):
        """Same JSON round trips as ``test_save_and_load``, via the builder API.

        Uses chained ``mode(...)``, ``options(...)``/``option(...)`` and
        ``format(...)`` calls instead of positional/keyword arguments.
        """
        df = self.df
        # Unique, non-existent target path (see test_save_and_load).
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        try:
            df.write.json(tmpPath)
            actual = self.spark.read.json(tmpPath)
            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

            schema = StructType([StructField("value", StringType(), True)])
            actual = self.spark.read.json(tmpPath, schema)
            self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))

            df.write.mode("overwrite").json(tmpPath)
            actual = self.spark.read.json(tmpPath)
            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

            # An unrecognised option ("noUse"), set both via options() and
            # option(), must not break save or load.
            (df.write.mode("overwrite")
                .options(noUse="this options will not be used in save.")
                .option("noUse", "this option will not be used in save.")
                .format("json")
                .save(path=tmpPath))
            actual = (self.spark.read.format("json")
                      .load(path=tmpPath, noUse="this options will not be used in load."))
            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

            defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
                                                        "org.apache.spark.sql.parquet")
            self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
            try:
                actual = self.spark.read.load(path=tmpPath)
                self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
            finally:
                # Always restore the shared session's default data source.
                self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
        finally:
            shutil.rmtree(tmpPath, ignore_errors=True)

    def test_bucketed_write(self):
        """``bucketBy``/``sortBy`` with varargs and list column specs.

        Each variant overwrites the ``pyspark_bucket`` table, then checks the
        catalog's bucket-column count and that the data round-trips unchanged.
        """
        data = [
            (1, "foo", 3.0), (2, "foo", 5.0),
            (3, "bar", -1.0), (4, "bar", 6.0),
        ]
        df = self.spark.createDataFrame(data, ["x", "y", "z"])

        def count_bucketed_cols(names, table="pyspark_bucket"):
            """Given a sequence of column names and a table name
            query the catalog and return the number of columns which are
            used for bucketing
            """
            cols = self.spark.catalog.listColumns(table)
            num = len([c for c in cols if c.name in names and c.isBucket])
            return num

        # self.table(...) is a base-class helper; presumably it drops the
        # table on exit -- TODO confirm.
        with self.table("pyspark_bucket"):
            # Test write with one bucketing column
            df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x"]), 1)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Test write two bucketing columns
            df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Test write with bucket and sort
            df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x"]), 1)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Test write with a list of columns
            df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
            self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Test write with bucket and sort with a list of columns
            (df.write.bucketBy(2, "x")
                .sortBy(["y", "z"])
                .mode("overwrite").saveAsTable("pyspark_bucket"))
            self.assertEqual(count_bucketed_cols(["x"]), 1)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

            # Test write with bucket and sort with multiple columns
            (df.write.bucketBy(2, "x")
                .sortBy("y", "z")
                .mode("overwrite").saveAsTable("pyspark_bucket"))
            self.assertEqual(count_bucketed_cols(["x"]), 1)
            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

    def test_insert_into(self):
        """``insertInto`` append/overwrite semantics, via ``mode`` and the flag.

        The expected counts show that an explicit ``overwrite`` argument takes
        precedence over the writer's ``mode(...)``.
        """
        df = self.spark.createDataFrame([("a", 1), ("b", 2)], ["C1", "C2"])
        with self.table("test_table"):
            df.write.saveAsTable("test_table")
            self.assertEqual(2, self.spark.sql("select * from test_table").count())

            # Default: append.
            df.write.insertInto("test_table")
            self.assertEqual(4, self.spark.sql("select * from test_table").count())

            df.write.mode("overwrite").insertInto("test_table")
            self.assertEqual(2, self.spark.sql("select * from test_table").count())

            df.write.insertInto("test_table", True)
            self.assertEqual(2, self.spark.sql("select * from test_table").count())

            df.write.insertInto("test_table", False)
            self.assertEqual(4, self.spark.sql("select * from test_table").count())

            # overwrite=False wins over mode("overwrite"): rows are appended.
            df.write.mode("overwrite").insertInto("test_table", False)
            self.assertEqual(6, self.spark.sql("select * from test_table").count())
0164 
0165 
if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_readwriter import *

    # Use JUnit-style XML reports when xmlrunner is installed; otherwise
    # unittest falls back to its default text runner (testRunner=None).
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)