#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import shutil
import tempfile

from pyspark.sql import Row
from pyspark.sql.types import IntegerType, LongType, StringType, StructField, StructType
from pyspark.testing.sqlutils import ReusedSQLTestCase


class DataSourcesTests(ReusedSQLTestCase):
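    """Tests for reading and writing with the built-in DataFrame data sources
    (text, JSON, CSV, ORC), covering options such as lineSep, multiLine,
    encoding, samplingRatio, and CSV header validation."""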

    def test_linesep_text(self):
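        """Text source: a custom lineSep splits records on read and joins
        them on write."""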
        df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",")
        expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'),
                    Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'),
                    Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'),
                    Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')]
        self.assertEqual(df.collect(), expected)

        # mkdtemp() creates the directory, but the writer's default save mode
        # requires a non-existent target path, so remove it before writing.
        tpath = tempfile.mkdtemp()
        shutil.rmtree(tpath)
        try:
            df.write.text(tpath, lineSep="!")
            expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'),
                        Row(value=u'Tom!30!"My name is Tom"'),
                        Row(value=u'Hyukjin!25!"I am Hyukjin'),
                        Row(value=u''), Row(value=u'I love Spark!"'),
                        Row(value=u'!')]
            readback = self.spark.read.text(tpath)
            self.assertEqual(readback.collect(), expected)
        finally:
            shutil.rmtree(tpath)

    def test_multiline_json(self):
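        """JSON source: multiLine=True parses a multi-line JSON array into
        the same rows as the equivalent line-delimited file."""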
        people1 = self.spark.read.json("python/test_support/sql/people.json")
        people_array = self.spark.read.json("python/test_support/sql/people_array.json",
                                            multiLine=True)
        self.assertEqual(people1.collect(), people_array.collect())

    def test_encoding_json(self):
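        """JSON source: the encoding option decodes a UTF-16LE input file."""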
        people_array = self.spark.read \
            .json("python/test_support/sql/people_array_utf16le.json",
                  multiLine=True, encoding="UTF-16LE")
        expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')]
        self.assertEqual(people_array.collect(), expected)

    def test_linesep_json(self):
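        """JSON source: a custom lineSep changes record boundaries on read,
        surfacing misaligned splits in _corrupt_record, and is honored on
        write as well."""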
        df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",")
        expected = [Row(_corrupt_record=None, name=u'Michael'),
                    Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None),
                    Row(_corrupt_record=u' "age":19}\n', name=None)]
        self.assertEqual(df.collect(), expected)

        tpath = tempfile.mkdtemp()
        shutil.rmtree(tpath)
        try:
            df = self.spark.read.json("python/test_support/sql/people.json")
            df.write.json(tpath, lineSep="!!")
            readback = self.spark.read.json(tpath, lineSep="!!")
            self.assertEqual(readback.collect(), df.collect())
        finally:
            shutil.rmtree(tpath)

    def test_multiline_csv(self):
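        """CSV source: multiLine=True lets quoted fields span newlines."""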
        ages_newlines = self.spark.read.csv(
            "python/test_support/sql/ages_newlines.csv", multiLine=True)
        expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'),
                    Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'),
                    Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')]
        self.assertEqual(ages_newlines.collect(), expected)

    def test_ignorewhitespace_csv(self):
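        """CSV source: disabling ignoreLeadingWhiteSpace and
        ignoreTrailingWhiteSpace preserves surrounding spaces on write."""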
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        try:
            self.spark.createDataFrame([[" a", "b  ", " c "]]).write.csv(
                tmpPath,
                ignoreLeadingWhiteSpace=False,
                ignoreTrailingWhiteSpace=False)

            expected = [Row(value=u' a,b  , c ')]
            readback = self.spark.read.text(tmpPath)
            self.assertEqual(readback.collect(), expected)
        finally:
            shutil.rmtree(tmpPath)

    def test_read_multiple_orc_file(self):
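        """ORC source: a list of paths is read into a single DataFrame."""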
        df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0",
                                  "python/test_support/sql/orc_partitioned/b=1/c=1"])
        self.assertEqual(2, df.count())

    def test_read_text_file_list(self):
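        """Text source: a list of paths is read in full, duplicates included."""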
        df = self.spark.read.text(['python/test_support/sql/text-test.txt',
                                   'python/test_support/sql/text-test.txt'])
        count = df.count()
        self.assertEqual(count, 4)

    def test_json_sampling_ratio(self):
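        """JSON source: with samplingRatio=0.5, schema inference samples past
        the single floating-point record, so the field stays LongType."""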
        rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
            .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x))
        schema = self.spark.read.option('inferSchema', True) \
            .option('samplingRatio', 0.5) \
            .json(rdd).schema
        self.assertEqual(schema, StructType([StructField("a", LongType(), True)]))

    def test_csv_sampling_ratio(self):
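        """CSV source: samplingRatio works for CSV schema inference as well,
        yielding IntegerType despite one '0.1' record."""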
        rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
            .map(lambda x: '0.1' if x == 1 else str(x))
        schema = self.spark.read.option('inferSchema', True) \
            .csv(rdd, samplingRatio=0.5).schema
        self.assertEqual(schema, StructType([StructField("_c0", IntegerType(), True)]))

    def test_checking_csv_header(self):
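        """CSV source: with enforceSchema=False, a header that does not match
        the user-specified schema makes the read fail."""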
        path = tempfile.mkdtemp()
        shutil.rmtree(path)
        try:
            self.spark.createDataFrame([[1, 1000], [2000, 2]]) \
                .toDF('f1', 'f2').write.option("header", "true").csv(path)
            schema = StructType([
                StructField('f2', IntegerType(), nullable=True),
                StructField('f1', IntegerType(), nullable=True)])
            df = self.spark.read.option('header', 'true').schema(schema) \
                .csv(path, enforceSchema=False)
            self.assertRaisesRegex(
                Exception,
                "CSV header does not conform to the schema",
                lambda: df.collect())
        finally:
            shutil.rmtree(path)

    def test_ignore_column_of_all_nulls(self):
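        """JSON source: dropFieldIfAllNull=True drops fields that are null in
        every record from the inferred schema."""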
        path = tempfile.mkdtemp()
        shutil.rmtree(path)
        try:
            df = self.spark.createDataFrame([["""{"a":null, "b":1, "c":3.0}"""],
                                             ["""{"a":null, "b":null, "c":"string"}"""],
                                             ["""{"a":null, "b":null, "c":null}"""]])
            df.write.text(path)
            schema = StructType([
                StructField('b', LongType(), nullable=True),
                StructField('c', StringType(), nullable=True)])
            readback = self.spark.read.json(path, dropFieldIfAllNull=True)
            self.assertEqual(readback.schema, schema)
        finally:
            shutil.rmtree(path)


if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_datasources import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)