import json
import sys
import os
import time
import uuid
import warnings

# Python 2/3 compatibility aliases.
if sys.version > '3':
    basestring = str
    unicode = str
    long = int

from pyspark import SparkContext, since
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession
from pyspark.util import VersionUtils


def _jvm():
    """
    Returns the JVM view associated with SparkContext. Must be called
    after SparkContext is initialized.
    """
    jvm = SparkContext._jvm
    if jvm:
        return jvm
    else:
        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
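
# A minimal usage sketch of _jvm() (assumes an active SparkContext): the returned
# gateway view resolves Java classes by attribute access, e.g.
#
#     jvm = _jvm()
#     java_cls = jvm.org.apache.spark.ml.util.DefaultParamsReader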


class Identifiable(object):
    """
    Object with a unique ID.
    """

    def __init__(self):
        #: A unique id for the object.
        self.uid = self._randomUID()

    def __repr__(self):
        return self.uid

    @classmethod
    def _randomUID(cls):
        """
        Generate a unique unicode id for the object. The default implementation
        concatenates the class name, "_", and 12 random hex chars.
        """
        return unicode(cls.__name__ + "_" + uuid.uuid4().hex[-12:])
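
    # Resulting uid format, sketched with a hypothetical subclass:
    #
    #     class Foo(Identifiable):
    #         pass
    #
    #     Foo().uid    # e.g. 'Foo_4c8e9d2a1b3f'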


@inherit_doc
class BaseReadWrite(object):
    """
    Base class for MLWriter and MLReader. Stores information about the SparkContext
    and SparkSession.

    .. versionadded:: 2.3.0
    """

    def __init__(self):
        self._sparkSession = None

    def session(self, sparkSession):
        """
        Sets the Spark Session to use for saving/loading.
        """
        self._sparkSession = sparkSession
        return self

    @property
    def sparkSession(self):
        """
        Returns the user-specified Spark Session or the default.
        """
        if self._sparkSession is None:
            self._sparkSession = SparkSession.builder.getOrCreate()
        return self._sparkSession

    @property
    def sc(self):
        """
        Returns the underlying `SparkContext`.
        """
        return self.sparkSession.sparkContext
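
    # Usage sketch: an explicitly set session takes precedence over the default
    # (writer being any BaseReadWrite subclass instance):
    #
    #     writer.session(spark)
    #     writer.sparkSession is spark    # True; otherwise getOrCreate() is used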


@inherit_doc
class MLWriter(BaseReadWrite):
    """
    Utility class that can save ML instances.

    .. versionadded:: 2.0.0
    """

    def __init__(self):
        super(MLWriter, self).__init__()
        self.shouldOverwrite = False

    def _handleOverwrite(self, path):
        from pyspark.ml.wrapper import JavaWrapper

        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.util.FileSystemOverwrite")
        wrapper = JavaWrapper(_java_obj)
        wrapper._call_java("handleOverwrite", path, True, self.sparkSession._jsparkSession)

    def save(self, path):
        """Save the ML instance to the input path."""
        if self.shouldOverwrite:
            self._handleOverwrite(path)
        self.saveImpl(path)

    def saveImpl(self, path):
        """
        save() handles overwriting and then calls this method. Subclasses should override this
        method to implement the actual saving of the instance.
        """
        raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))

    def overwrite(self):
        """Overwrites if the output path already exists."""
        self.shouldOverwrite = True
        return self
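
    # Typical call chain (model being any MLWritable instance):
    #
    #     model.write().overwrite().save("/tmp/model")    # replace existing output
    #     model.save("/tmp/model")                        # shortcut, fails if path exists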


@inherit_doc
class GeneralMLWriter(MLWriter):
    """
    Utility class that can save ML instances in different formats.

    .. versionadded:: 2.4.0
    """

    def format(self, source):
        """
        Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class
        name for export).
        """
        self.source = source
        return self
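
    # Export sketch; "pmml" is only honored by models that implement a PMML
    # exporter (e.g. LinearRegressionModel, assuming Spark 2.4+):
    #
    #     lr_model.write().format("pmml").save("/tmp/lr_pmml")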


@inherit_doc
class JavaMLWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
    """

    def __init__(self, instance):
        super(JavaMLWriter, self).__init__()
        _java_obj = instance._to_java()
        self._jwrite = _java_obj.write()

    def save(self, path):
        """Save the ML instance to the input path."""
        if not isinstance(path, basestring):
            raise TypeError("path should be a basestring, got type %s" % type(path))
        self._jwrite.save(path)

    def overwrite(self):
        """Overwrites if the output path already exists."""
        self._jwrite.overwrite()
        return self

    def option(self, key, value):
        self._jwrite.option(key, value)
        return self

    def session(self, sparkSession):
        """Sets the Spark Session to use for saving."""
        self._jwrite.session(sparkSession._jsparkSession)
        return self


@inherit_doc
class GeneralJavaMLWriter(JavaMLWriter):
    """
    (Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types
    """

    def __init__(self, instance):
        super(GeneralJavaMLWriter, self).__init__(instance)

    def format(self, source):
        """
        Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class
        name for export).
        """
        self._jwrite.format(source)
        return self


@inherit_doc
class MLWritable(object):
    """
    Mixin for ML instances that provide :py:class:`MLWriter`.

    .. versionadded:: 2.0.0
    """

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))

    def save(self, path):
        """Save this ML instance to the given path, a shortcut of 'write().save(path)'."""
        self.write().save(path)


@inherit_doc
class JavaMLWritable(MLWritable):
    """
    (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
    """

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)


@inherit_doc
class GeneralJavaMLWritable(JavaMLWritable):
    """
    (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`.
    """

    def write(self):
        """Returns a GeneralJavaMLWriter instance for this ML instance."""
        return GeneralJavaMLWriter(self)


@inherit_doc
class MLReader(BaseReadWrite):
    """
    Utility class that can load ML instances.

    .. versionadded:: 2.0.0
    """

    def __init__(self):
        super(MLReader, self).__init__()

    def load(self, path):
        """Load the ML instance from the input path."""
        raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))


@inherit_doc
class JavaMLReader(MLReader):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
    """

    def __init__(self, clazz):
        super(JavaMLReader, self).__init__()
        self._clazz = clazz
        self._jread = self._load_java_obj(clazz).read()

    def load(self, path):
        """Load the ML instance from the input path."""
        if not isinstance(path, basestring):
            raise TypeError("path should be a basestring, got type %s" % type(path))
        java_obj = self._jread.load(path)
        if not hasattr(self._clazz, "_from_java"):
            raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r"
                                      % self._clazz)
        return self._clazz._from_java(java_obj)

    def session(self, sparkSession):
        """Sets the Spark Session to use for loading."""
        self._jread.session(sparkSession._jsparkSession)
        return self

    @classmethod
    def _java_loader_class(cls, clazz):
        """
        Returns the full class name of the Java ML instance. The default
        implementation replaces "pyspark" by "org.apache.spark" in
        the Python full class name.
        """
        java_package = clazz.__module__.replace("pyspark", "org.apache.spark")
        if clazz.__name__ in ("Pipeline", "PipelineModel"):
            # Remove the last package name "pipeline" for Pipeline and PipelineModel.
            java_package = ".".join(java_package.split(".")[0:-1])
        return java_package + "." + clazz.__name__
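
    # Mapping sketch: pyspark.ml.feature.Binarizer resolves to
    # org.apache.spark.ml.feature.Binarizer, while pyspark.ml.pipeline.Pipeline
    # drops the trailing "pipeline" package and resolves to org.apache.spark.ml.Pipeline.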

    @classmethod
    def _load_java_obj(cls, clazz):
        """Load the peer Java object of the ML instance."""
        java_class = cls._java_loader_class(clazz)
        java_obj = _jvm()
        for name in java_class.split("."):
            java_obj = getattr(java_obj, name)
        return java_obj


@inherit_doc
class MLReadable(object):
    """
    Mixin for instances that provide :py:class:`MLReader`.

    .. versionadded:: 2.0.0
    """

    @classmethod
    def read(cls):
        """Returns an MLReader instance for this class."""
        raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls)

    @classmethod
    def load(cls, path):
        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
        return cls.read().load(path)
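
    # Load sketch for a concrete MLReadable subclass (path illustrative):
    #
    #     from pyspark.ml import PipelineModel
    #     model = PipelineModel.load("/tmp/pipeline_model")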


@inherit_doc
class JavaMLReadable(MLReadable):
    """
    (Private) Mixin for instances that provide :py:class:`JavaMLReader`.
    """

    @classmethod
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)


@inherit_doc
class DefaultParamsWritable(MLWritable):
    """
    Helper trait for making simple :py:class:`Params` types writable. If a :py:class:`Params`
    class stores all data as :py:class:`Param` values, then extending this trait will provide
    a default implementation of writing saved instances of the class.
    This only handles simple :py:class:`Param` types; e.g., it will not handle
    :py:class:`Dataset`. See :py:class:`DefaultParamsReadable`, the counterpart to this trait.

    .. versionadded:: 2.3.0
    """

    def write(self):
        """Returns a DefaultParamsWriter instance for this class."""
        from pyspark.ml.param import Params

        if isinstance(self, Params):
            return DefaultParamsWriter(self)
        else:
            raise TypeError("Cannot use DefaultParamsWritable with type %s because it does not "
                            "extend Params." % type(self))


@inherit_doc
class DefaultParamsWriter(MLWriter):
    """
    Specialization of :py:class:`MLWriter` for :py:class:`Params` types

    Class for writing Estimators and Transformers whose parameters are JSON-serializable.

    .. versionadded:: 2.3.0
    """

    def __init__(self, instance):
        super(DefaultParamsWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path):
        DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)

    @staticmethod
    def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
        """
        Saves metadata + Params to: path + "/metadata"

        - class
        - timestamp
        - sparkVersion
        - uid
        - paramMap
        - defaultParamMap (since 2.4.0)
        - (optionally, extra metadata)

        :param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc.
        :param paramMap: If given, this is saved in the "paramMap" field.
        """
        metadataPath = os.path.join(path, "metadata")
        metadataJson = DefaultParamsWriter._get_metadata_to_save(instance,
                                                                 sc,
                                                                 extraMetadata,
                                                                 paramMap)
        sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath)

    @staticmethod
    def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
        """
        Helper for :py:meth:`DefaultParamsWriter.saveMetadata` which extracts the JSON to save.
        This is useful for ensemble models which need to save metadata for many sub-models.

        .. note:: See :py:meth:`DefaultParamsWriter.saveMetadata` for details on what this
            includes.
        """
        uid = instance.uid
        cls = instance.__module__ + '.' + instance.__class__.__name__

        # User-supplied param values
        params = instance._paramMap
        jsonParams = {}
        if paramMap is not None:
            jsonParams = paramMap
        else:
            for p in params:
                jsonParams[p.name] = params[p]

        # Default param values
        jsonDefaultParams = {}
        for p in instance._defaultParamMap:
            jsonDefaultParams[p.name] = instance._defaultParamMap[p]

        basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)),
                         "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams,
                         "defaultParamMap": jsonDefaultParams}
        if extraMetadata is not None:
            basicMetadata.update(extraMetadata)
        return json.dumps(basicMetadata, separators=[',', ':'])
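
    # Shape sketch of the emitted JSON (written as a single line; values illustrative):
    #
    #     {"class":"pyspark.ml.feature.Binarizer","timestamp":1554012345678,
    #      "sparkVersion":"2.4.0","uid":"Binarizer_4c8e9d2a1b3f",
    #      "paramMap":{"threshold":0.5},"defaultParamMap":{...}}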


@inherit_doc
class DefaultParamsReadable(MLReadable):
    """
    Helper trait for making simple :py:class:`Params` types readable.
    If a :py:class:`Params` class stores all data as :py:class:`Param` values,
    then extending this trait will provide a default implementation of reading saved
    instances of the class. This only handles simple :py:class:`Param` types;
    e.g., it will not handle :py:class:`Dataset`. See :py:class:`DefaultParamsWritable`,
    the counterpart to this trait.

    .. versionadded:: 2.3.0
    """

    @classmethod
    def read(cls):
        """Returns a DefaultParamsReader instance for this class."""
        return DefaultParamsReader(cls)


@inherit_doc
class DefaultParamsReader(MLReader):
    """
    Specialization of :py:class:`MLReader` for :py:class:`Params` types

    Default :py:class:`MLReader` implementation for transformers and estimators that
    contain basic (json-serializable) params and no data. This will not handle
    more complex params or types with data (e.g., models with coefficients).

    .. versionadded:: 2.3.0
    """

    def __init__(self, cls):
        super(DefaultParamsReader, self).__init__()
        self.cls = cls

    @staticmethod
    def __get_class(clazz):
        """
        Loads Python class from its name.
        """
        parts = clazz.split('.')
        module = ".".join(parts[:-1])
        m = __import__(module)
        for comp in parts[1:]:
            m = getattr(m, comp)
        return m
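
    # Resolution sketch: __get_class("pyspark.ml.feature.Binarizer") imports the
    # top-level "pyspark" package, then walks the remaining attributes
    # ("ml", "feature", "Binarizer") to reach the class object.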

    def load(self, path):
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        py_type = DefaultParamsReader.__get_class(metadata['class'])
        instance = py_type()
        instance._resetUid(metadata['uid'])
        DefaultParamsReader.getAndSetParams(instance, metadata)
        return instance

    @staticmethod
    def loadMetadata(path, sc, expectedClassName=""):
        """
        Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata`

        :param expectedClassName: If non-empty, this is checked against the loaded metadata.
        """
        metadataPath = os.path.join(path, "metadata")
        metadataStr = sc.textFile(metadataPath, 1).first()
        loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName)
        return loadedVals

    @staticmethod
    def _parseMetaData(metadataStr, expectedClassName=""):
        """
        Parse metadata JSON string produced by
        :py:meth:`DefaultParamsWriter._get_metadata_to_save`.
        This is a helper function for :py:meth:`DefaultParamsReader.loadMetadata`.

        :param metadataStr: JSON string of metadata
        :param expectedClassName: If non-empty, this is checked against the loaded metadata.
        """
        metadata = json.loads(metadataStr)
        className = metadata['class']
        if len(expectedClassName) > 0:
            assert className == expectedClassName, "Error loading metadata: Expected " + \
                "class name {} but found class name {}".format(expectedClassName, className)
        return metadata

    @staticmethod
    def getAndSetParams(instance, metadata):
        """
        Extract Params from metadata, and set them in the instance.
        """
        # Set user-supplied param values
        for paramName in metadata['paramMap']:
            param = instance.getParam(paramName)
            paramValue = metadata['paramMap'][paramName]
            instance.set(param, paramValue)

        # Set default param values
        majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion'])
        major = majorAndMinorVersions[0]
        minor = majorAndMinorVersions[1]

        # Metadata written before Spark 2.4 has no defaultParamMap section.
        if major > 2 or (major == 2 and minor >= 4):
            assert 'defaultParamMap' in metadata, "Error loading metadata: expected " + \
                "`defaultParamMap` section but it was not found"

            for paramName in metadata['defaultParamMap']:
                paramValue = metadata['defaultParamMap'][paramName]
                instance._setDefault(**{paramName: paramValue})

    @staticmethod
    def loadParamsInstance(path, sc):
        """
        Load a :py:class:`Params` instance from the given path, and return it.
        This assumes the instance inherits from :py:class:`MLReadable`.
        """
        metadata = DefaultParamsReader.loadMetadata(path, sc)
        pythonClassName = metadata['class'].replace("org.apache.spark", "pyspark")
        py_type = DefaultParamsReader.__get_class(pythonClassName)
        instance = py_type.load(path)
        return instance
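
    # Dispatch sketch: since the concrete class comes from the saved metadata, this
    # can load heterogeneous stages without knowing their types up front:
    #
    #     stage = DefaultParamsReader.loadParamsInstance(stagePath, sc)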


@inherit_doc
class HasTrainingSummary(object):
    """
    Base class for models that provide training summaries.

    .. versionadded:: 3.0.0
    """

    @property
    @since("2.1.0")
    def hasSummary(self):
        """
        Indicates whether a training summary exists for this model
        instance.
        """
        return self._call_java("hasSummary")

    @property
    @since("2.1.0")
    def summary(self):
        """
        Gets summary of the model trained on the training set. An exception is thrown if
        no summary exists.
        """
        return self._call_java("summary")
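
    # Guarded access sketch (model being, e.g., a LogisticRegressionModel):
    #
    #     if model.hasSummary:
    #         print(model.summary.accuracy)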