0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import shutil
0019 import tempfile
0020
0021 from pyspark.sql import Row
0022 from pyspark.sql.types import *
0023 from pyspark.testing.sqlutils import ReusedSQLTestCase
0024
0025
class DataSourcesTests(ReusedSQLTestCase):
    """End-to-end tests for reader/writer options of the built-in data
    sources (text, JSON, CSV, ORC) exposed via ``spark.read`` and
    ``DataFrame.write``.

    Several tests round-trip data through a temporary directory; Spark's
    writers require the target directory not to exist, so tests obtain a
    fresh *path* (not an existing directory) from :meth:`_temp_dir_path`
    and always clean it up in a ``finally`` block.
    """

    def _temp_dir_path(self):
        # mkdtemp() guarantees a unique name, but Spark's writers refuse to
        # write into an existing directory -- so delete it right away and
        # hand back only the (now nonexistent) path.
        path = tempfile.mkdtemp()
        shutil.rmtree(path)
        return path

    def test_linesep_text(self):
        """``lineSep`` is honored by both the text reader and the text writer."""
        df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",")
        expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'),
                    Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'),
                    Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'),
                    Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')]
        self.assertEqual(df.collect(), expected)

        tpath = self._temp_dir_path()
        try:
            df.write.text(tpath, lineSep="!")
            expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'),
                        Row(value=u'Tom!30!"My name is Tom"'),
                        Row(value=u'Hyukjin!25!"I am Hyukjin'),
                        Row(value=u''), Row(value=u'I love Spark!"'),
                        Row(value=u'!')]
            # Read back with the default line separator: every "!" written
            # above becomes part of the value, splitting only on newlines.
            readback = self.spark.read.text(tpath)
            self.assertEqual(readback.collect(), expected)
        finally:
            shutil.rmtree(tpath)

    def test_multiline_json(self):
        """``multiLine=True`` parses a whole-file JSON array like line-delimited JSON."""
        people1 = self.spark.read.json("python/test_support/sql/people.json")
        people_array = self.spark.read.json("python/test_support/sql/people_array.json",
                                            multiLine=True)
        self.assertEqual(people1.collect(), people_array.collect())

    def test_encoding_json(self):
        """The JSON reader decodes files with an explicit non-UTF-8 ``encoding``."""
        people_array = self.spark.read\
            .json("python/test_support/sql/people_array_utf16le.json",
                  multiLine=True, encoding="UTF-16LE")
        expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')]
        self.assertEqual(people_array.collect(), expected)

    def test_linesep_json(self):
        """``lineSep`` is honored by both the JSON reader and the JSON writer."""
        # Splitting on "," produces mostly malformed records, surfaced via
        # the _corrupt_record column.
        df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",")
        expected = [Row(_corrupt_record=None, name=u'Michael'),
                    Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None),
                    Row(_corrupt_record=u' "age":19}\n', name=None)]
        self.assertEqual(df.collect(), expected)

        tpath = self._temp_dir_path()
        try:
            # Round-trip with a custom separator and check the data survives.
            df = self.spark.read.json("python/test_support/sql/people.json")
            df.write.json(tpath, lineSep="!!")
            readback = self.spark.read.json(tpath, lineSep="!!")
            self.assertEqual(readback.collect(), df.collect())
        finally:
            shutil.rmtree(tpath)

    def test_multiline_csv(self):
        """``multiLine=True`` lets quoted CSV fields contain embedded newlines."""
        ages_newlines = self.spark.read.csv(
            "python/test_support/sql/ages_newlines.csv", multiLine=True)
        expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'),
                    Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'),
                    Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')]
        self.assertEqual(ages_newlines.collect(), expected)

    def test_ignorewhitespace_csv(self):
        """Disabling whitespace trimming preserves leading/trailing spaces on write."""
        tmpPath = self._temp_dir_path()
        try:
            self.spark.createDataFrame([[" a", "b ", " c "]]).write.csv(
                tmpPath,
                ignoreLeadingWhiteSpace=False,
                ignoreTrailingWhiteSpace=False)

            expected = [Row(value=u' a,b , c ')]
            # Read raw text (not CSV) so the written bytes are checked verbatim.
            readback = self.spark.read.text(tmpPath)
            self.assertEqual(readback.collect(), expected)
        finally:
            shutil.rmtree(tmpPath)

    def test_read_multiple_orc_file(self):
        """The ORC reader accepts a list of paths and unions their rows."""
        df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0",
                                  "python/test_support/sql/orc_partitioned/b=1/c=1"])
        self.assertEqual(2, df.count())

    def test_read_text_file_list(self):
        """The text reader accepts a list of paths; duplicates are read twice."""
        df = self.spark.read.text(['python/test_support/sql/text-test.txt',
                                   'python/test_support/sql/text-test.txt'])
        count = df.count()
        self.assertEqual(count, 4)

    def test_json_sampling_ratio(self):
        """``samplingRatio`` limits the rows used for JSON schema inference."""
        # Only row index 1 holds a double; sampling at 0.5 is expected to
        # miss it, so the inferred type stays LongType.
        rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
            .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x))
        schema = self.spark.read.option('inferSchema', True) \
            .option('samplingRatio', 0.5) \
            .json(rdd).schema
        self.assertEqual(schema, StructType([StructField("a", LongType(), True)]))

    def test_csv_sampling_ratio(self):
        """``samplingRatio`` limits the rows used for CSV schema inference."""
        # Same idea as the JSON case: the lone '0.1' should be skipped by
        # sampling, keeping the inferred column type IntegerType.
        rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
            .map(lambda x: '0.1' if x == 1 else str(x))
        schema = self.spark.read.option('inferSchema', True)\
            .csv(rdd, samplingRatio=0.5).schema
        self.assertEqual(schema, StructType([StructField("_c0", IntegerType(), True)]))

    def test_checking_csv_header(self):
        """``enforceSchema=False`` raises when the CSV header disagrees with the schema."""
        path = self._temp_dir_path()
        try:
            self.spark.createDataFrame([[1, 1000], [2000, 2]])\
                .toDF('f1', 'f2').write.option("header", "true").csv(path)
            # Deliberately swap the column order relative to the written header.
            schema = StructType([
                StructField('f2', IntegerType(), nullable=True),
                StructField('f1', IntegerType(), nullable=True)])
            df = self.spark.read.option('header', 'true').schema(schema)\
                .csv(path, enforceSchema=False)
            # NOTE(review): assertRaisesRegexp is the Python-2-compatible
            # spelling; rename to assertRaisesRegex once Py2 support is dropped.
            self.assertRaisesRegexp(
                Exception,
                "CSV header does not conform to the schema",
                lambda: df.collect())
        finally:
            shutil.rmtree(path)

    def test_ignore_column_of_all_nulls(self):
        """``dropFieldIfAllNull=True`` removes JSON fields that are null in every row."""
        path = self._temp_dir_path()
        try:
            df = self.spark.createDataFrame([["""{"a":null, "b":1, "c":3.0}"""],
                                             ["""{"a":null, "b":null, "c":"string"}"""],
                                             ["""{"a":null, "b":null, "c":null}"""]])
            df.write.text(path)
            # Column "a" is null everywhere and must be dropped from the schema.
            schema = StructType([
                StructField('b', LongType(), nullable=True),
                StructField('c', StringType(), nullable=True)])
            readback = self.spark.read.json(path, dropFieldIfAllNull=True)
            self.assertEqual(readback.schema, schema)
        finally:
            shutil.rmtree(path)
0160
0161
if __name__ == "__main__":
    # Re-export the tests so unittest discovers them under this module name.
    import unittest
    from pyspark.sql.tests.test_datasources import *

    # Prefer the XML-emitting runner (for CI report collection) when it is
    # installed; otherwise fall back to unittest's default text runner.
    try:
        import xmlrunner
    except ImportError:
        runner = None
    else:
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=runner, verbosity=2)