0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 """
0019 .. attribute:: ImageSchema
0020
0021 An attribute of this module that contains the instance of :class:`_ImageSchema`.
0022
0023 .. autoclass:: _ImageSchema
0024 :members:
0025 """
0026
0027 import sys
0028 import warnings
0029
0030 import numpy as np
0031 from distutils.version import LooseVersion
0032
0033 from pyspark import SparkContext
0034 from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string
0035 from pyspark.sql import DataFrame, SparkSession
0036
0037 __all__ = ["ImageSchema"]
0038
0039
class _ImageSchema(object):
    """
    Internal class for `pyspark.ml.image.ImageSchema` attribute. Meant to be private and
    not to be instantiated. Use `pyspark.ml.image.ImageSchema` attribute to access the
    APIs of this class.
    """

    def __init__(self):
        # Lazily-populated caches for values fetched from the JVM side; each
        # value is fetched at most once per Python process and reused after.
        self._imageSchema = None
        self._ocvTypes = None
        self._columnSchema = None
        self._imageFields = None
        self._undefinedImageType = None

    @property
    def imageSchema(self):
        """
        Returns the image schema.

        :return: a :class:`StructType` with a single column of images
               named "image" (nullable) and having the same type returned by :meth:`columnSchema`.

        .. versionadded:: 2.3.0
        """
        if self._imageSchema is None:
            # Fetch the schema as JSON from the JVM and parse it into a Python StructType.
            ctx = SparkContext._active_spark_context
            jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema()
            self._imageSchema = _parse_datatype_json_string(jschema.json())
        return self._imageSchema

    @property
    def ocvTypes(self):
        """
        Returns the OpenCV type mapping supported.

        :return: a dictionary containing the OpenCV type mapping supported.

        .. versionadded:: 2.3.0
        """
        if self._ocvTypes is None:
            # The JVM returns a java.util.Map; dict() converts it to a plain Python dict.
            ctx = SparkContext._active_spark_context
            self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes())
        return self._ocvTypes

    @property
    def columnSchema(self):
        """
        Returns the schema for the image column.

        :return: a :class:`StructType` for image column,
            ``struct<origin:string, height:int, width:int, nChannels:int, mode:int, data:binary>``.

        .. versionadded:: 2.4.0
        """
        if self._columnSchema is None:
            ctx = SparkContext._active_spark_context
            jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.columnSchema()
            self._columnSchema = _parse_datatype_json_string(jschema.json())
        return self._columnSchema

    @property
    def imageFields(self):
        """
        Returns field names of image columns.

        :return: a list of field names.

        .. versionadded:: 2.3.0
        """
        if self._imageFields is None:
            ctx = SparkContext._active_spark_context
            self._imageFields = list(ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageFields())
        return self._imageFields

    @property
    def undefinedImageType(self):
        """
        Returns the name of undefined image type for the invalid image.

        .. versionadded:: 2.3.0
        """
        if self._undefinedImageType is None:
            self._undefinedImageType = \
                ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType() \
                if False else None  # placeholder removed below
        return self._undefinedImageType

    def toNDArray(self, image):
        """
        Converts an image to an array with metadata.

        :param `Row` image: A row that contains the image to be converted. It should
            have the attributes specified in `ImageSchema.imageSchema`.
        :return: a `numpy.ndarray` that is an image.
        :raise TypeError: if ``image`` is not a :class:`Row`.
        :raise ValueError: if ``image`` lacks any of the image schema fields.

        .. versionadded:: 2.3.0
        """
        if not isinstance(image, Row):
            raise TypeError(
                "image argument should be pyspark.sql.types.Row; however, "
                "it got [%s]." % type(image))

        if any(not hasattr(image, f) for f in self.imageFields):
            raise ValueError(
                "image argument should have attributes specified in "
                "ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields))

        height = image.height
        width = image.width
        nChannels = image.nChannels
        # Wrap the raw bytes without copying: row-major (height, width, channel)
        # layout with one byte per channel value.
        return np.ndarray(
            shape=(height, width, nChannels),
            dtype=np.uint8,
            buffer=image.data,
            strides=(width * nChannels, nChannels, 1))

    def toImage(self, array, origin=""):
        """
        Converts an array with metadata to a two-dimensional image.

        :param `numpy.ndarray` array: The array to convert to image.
        :param str origin: Path to the image, optional.
        :return: a :class:`Row` that is a two dimensional image.
        :raise TypeError: if ``array`` is not a `numpy.ndarray`.
        :raise ValueError: if ``array`` is not 3-dimensional or has an
            unsupported number of channels (supported: 1, 3, 4).

        .. versionadded:: 2.3.0
        """
        if not isinstance(array, np.ndarray):
            raise TypeError(
                "array argument should be numpy.ndarray; however, it got [%s]." % type(array))

        if array.ndim != 3:
            raise ValueError("Invalid array shape")

        height, width, nChannels = array.shape
        # Map the channel count to the corresponding OpenCV 8-bit unsigned type.
        # NOTE: use self.ocvTypes (not the module-level ImageSchema singleton)
        # for consistency with the other methods of this class.
        ocvTypes = self.ocvTypes
        if nChannels == 1:
            mode = ocvTypes["CV_8UC1"]
        elif nChannels == 3:
            mode = ocvTypes["CV_8UC3"]
        elif nChannels == 4:
            mode = ocvTypes["CV_8UC4"]
        else:
            raise ValueError("Invalid number of channels")

        if LooseVersion(np.__version__) >= LooseVersion('1.9'):
            # tobytes() yields the flattened pixel data as a bytes object.
            data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
        else:
            # NumPy < 1.9 lacks tobytes(); bytearray over the raveled array
            # produces the same byte sequence.
            data = bytearray(array.astype(dtype=np.uint8).ravel())

        return _create_row(self.imageFields,
                           [origin, height, width, nChannels, mode, data])
0205
0206
# The public singleton. This MUST be created before __init__ is replaced
# below, since the replacement makes any further instantiation raise.
ImageSchema = _ImageSchema()


# Monkey-patch _ImageSchema.__init__ so that user code cannot create extra
# instances; the class is an attribute-access facade around JVM state and is
# meant to be used only through the ImageSchema singleton above.
def _disallow_instance(_):
    raise RuntimeError("Creating instance of _ImageSchema class is disallowed.")
_ImageSchema.__init__ = _disallow_instance
0214
0215
def _test():
    """Run this module's doctests against a local two-core SparkSession."""
    import doctest
    import pyspark.ml.image

    globs = pyspark.ml.image.__dict__.copy()
    session = (SparkSession.builder
               .master("local[2]")
               .appName("ml.image tests")
               .getOrCreate())
    globs['spark'] = session

    failure_count, test_count = doctest.testmod(
        pyspark.ml.image,
        globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
    session.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()