Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 import unittest
0018 
0019 import py4j
0020 
0021 from pyspark.ml.image import ImageSchema
0022 from pyspark.testing.mlutils import PySparkTestCase, SparkSessionTestCase
0023 from pyspark.sql import Row
0024 from pyspark.testing.utils import QuietTest
0025 
0026 
class ImageFileFormatTest(SparkSessionTestCase):
    """Tests for the ``image`` data source and the ``ImageSchema`` helpers."""

    def test_read_images(self):
        """Load the kitten fixtures and check schema, round-tripping and input validation."""
        # NOTE(review): presumably resolved relative to the Spark source root, which is
        # the tests' working directory -- confirm.
        data_path = 'data/mllib/images/origin/kittens'
        df = self.spark.read.format("image") \
            .option("dropInvalid", True) \
            .option("recursiveFileLookup", True) \
            .load(data_path)
        self.assertEqual(df.count(), 4)
        # The first column of the first row holds the image struct (not the row itself).
        first_image = df.take(1)[0][0]
        # Compare `simpleString()` forms rather than the schemas themselves, because a
        # DataFrame loaded from a data source may differ in column nullability.
        self.assertEqual(df.schema.simpleString(), ImageSchema.imageSchema.simpleString())
        self.assertEqual(df.schema["image"].dataType.simpleString(),
                         ImageSchema.columnSchema.simpleString())

        # Round trip: struct -> ndarray -> struct. Per ImageSchema.imageFields (asserted
        # below), index 0 of the struct is `origin` and index 1 is `height`.
        array = ImageSchema.toNDArray(first_image)
        self.assertEqual(len(array), first_image[1])
        self.assertEqual(ImageSchema.toImage(array, origin=first_image[0]), first_image)

        expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24}
        self.assertEqual(ImageSchema.ocvTypes, expected)
        expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data']
        self.assertEqual(ImageSchema.imageFields, expected)
        self.assertEqual(ImageSchema.undefinedImageType, "Undefined")

        # Invalid inputs must be rejected with descriptive errors. `assertRaises` plus a
        # substring check replaces the deprecated `assertRaisesRegexp` alias (removed in
        # Python 3.12); the context-manager form works on both Python 2.7 and 3.x.
        invalid_calls = [
            (TypeError,
             "image argument should be pyspark.sql.types.Row; however",
             lambda: ImageSchema.toNDArray("a")),
            (ValueError,
             "image argument should have attributes specified in",
             lambda: ImageSchema.toNDArray(Row(a=1))),
            (TypeError,
             "array argument should be numpy.ndarray; however, it got",
             lambda: ImageSchema.toImage("a")),
        ]
        with QuietTest(self.sc):
            for exc_type, message, call in invalid_calls:
                with self.assertRaises(exc_type) as ctx:
                    call()
                self.assertIn(message, str(ctx.exception))
0068 
0069 
if __name__ == "__main__":
    # Re-import under the canonical module name so unittest.main() discovers the tests.
    from pyspark.ml.tests.test_image import *  # noqa: F401,F403

    # Write XML reports to target/test-reports when the optional `xmlrunner` package is
    # importable; otherwise leave the runner as None so unittest uses its default.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)