0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 import unittest
0018
0019 import py4j
0020
0021 from pyspark.ml.image import ImageSchema
0022 from pyspark.testing.mlutils import PySparkTestCase, SparkSessionTestCase
0023 from pyspark.sql import Row
0024 from pyspark.testing.utils import QuietTest
0025
0026
class ImageFileFormatTest(SparkSessionTestCase):
    """Tests for the ``image`` data source reader and the ``ImageSchema`` helpers."""

    def _assert_raises_regex(self, expected_exception, expected_regex, func):
        """Assert that calling ``func`` raises ``expected_exception`` matching ``expected_regex``.

        ``assertRaisesRegexp`` is a deprecated alias that was removed from
        ``unittest.TestCase`` in Python 3.12. Prefer ``assertRaisesRegex`` and
        fall back to the old name only on interpreters that do not provide it.

        :param expected_exception: exception class the call must raise.
        :param expected_regex: pattern searched for in the exception message.
        :param func: zero-argument callable expected to fail.
        """
        check = getattr(self, "assertRaisesRegex", None)
        if check is None:
            check = self.assertRaisesRegexp
        # The failing call is executed inside QuietTest bound to this test's SparkContext.
        with QuietTest(self.sc):
            check(expected_exception, expected_regex, func)

    def test_read_images(self):
        """Read the sample images and verify schema, array round-trip and argument validation."""
        data_path = 'data/mllib/images/origin/kittens'
        df = self.spark.read.format("image") \
            .option("dropInvalid", True) \
            .option("recursiveFileLookup", True) \
            .load(data_path)
        self.assertEqual(df.count(), 4)
        # First column of the first row, i.e. the image struct itself.
        first_row = df.take(1)[0][0]

        # The DataFrame schema and its "image" column must match what ImageSchema advertises.
        self.assertEqual(df.schema.simpleString(), ImageSchema.imageSchema.simpleString())
        self.assertEqual(df.schema["image"].dataType.simpleString(),
                         ImageSchema.columnSchema.simpleString())

        # Round-trip: image row -> ndarray -> image row must reproduce the original row.
        array = ImageSchema.toNDArray(first_row)
        # Index 1 is compared with the array's first dimension; per the imageFields
        # list asserted below, position 1 is 'height'.
        self.assertEqual(len(array), first_row[1])
        self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row)

        # Constants exposed by ImageSchema.
        expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24}
        self.assertEqual(ImageSchema.ocvTypes, expected)
        expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data']
        self.assertEqual(ImageSchema.imageFields, expected)
        self.assertEqual(ImageSchema.undefinedImageType, "Undefined")

        # Invalid arguments must be rejected with descriptive errors.
        self._assert_raises_regex(
            TypeError,
            "image argument should be pyspark.sql.types.Row; however",
            lambda: ImageSchema.toNDArray("a"))

        self._assert_raises_regex(
            ValueError,
            "image argument should have attributes specified in",
            lambda: ImageSchema.toNDArray(Row(a=1)))

        self._assert_raises_regex(
            TypeError,
            "array argument should be numpy.ndarray; however, it got",
            lambda: ImageSchema.toImage("a"))
0068
0069
if __name__ == "__main__":
    # Re-import the tests under their package-qualified module name before running them.
    from pyspark.ml.tests.test_image import *

    # Default to unittest's built-in runner; switch to XML reports only when
    # the optional xmlrunner package can be imported.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)