0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 import array
0018 import sys
if sys.version_info[0] >= 3:
    # Python 3 removed these builtins; alias them so the rest of this module
    # can use the Python 2 names on both major versions. version_info is
    # compared numerically, unlike the fragile lexicographic string check.
    basestring = str
    xrange = range
    unicode = str
0023
0024 from abc import ABCMeta
0025 import copy
0026 import numpy as np
0027
0028 from py4j.java_gateway import JavaObject
0029
0030 from pyspark.ml.linalg import DenseVector, Vector, Matrix
0031 from pyspark.ml.util import Identifiable
0032
0033
# Names exported by a wildcard import of this module.
__all__ = [
    'Param',
    'Params',
    'TypeConverters',
]
0035
0036
class Param(object):
    """
    A param with self-contained documentation.

    A param is identified by the uid of its parent together with its name:
    two params with the same parent uid and name compare equal and hash
    identically.

    .. versionadded:: 1.3.0
    """

    def __init__(self, parent, name, doc, typeConverter=None):
        """
        :param parent: the :py:class:`Identifiable` owning this param; only
                       its ``uid`` is stored.
        :param name: param name (coerced to ``str``).
        :param doc: documentation string (coerced to ``str``).
        :param typeConverter: callable applied to values before they are
                              stored; defaults to
                              :py:func:`TypeConverters.identity`.
        :raises TypeError: if ``parent`` is not an :py:class:`Identifiable`.
        """
        if not isinstance(parent, Identifiable):
            raise TypeError("Parent must be an Identifiable but got type %s." % type(parent))
        self.parent = parent.uid
        self.name = str(name)
        self.doc = str(doc)
        self.typeConverter = TypeConverters.identity if typeConverter is None else typeConverter

    def _copy_new_parent(self, parent):
        """
        Copy the current param to a new parent, must be a dummy param.

        :param parent: object whose ``uid`` becomes the copy's parent.
        :return: a shallow copy of this param attached to ``parent``.
        :raises ValueError: if this param is not attached to the dummy
                            ``"undefined"`` parent.
        """
        if self.parent == "undefined":
            param = copy.copy(self)
            param.parent = parent.uid
            return param
        else:
            # Report the parent we refuse to copy *from*; the target parent
            # is irrelevant to why the copy was rejected.
            raise ValueError("Cannot copy from non-dummy parent %s." % self.parent)

    def __str__(self):
        return str(self.parent) + "__" + self.name

    def __repr__(self):
        return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc)

    def __hash__(self):
        # Derived from parent uid and name only, consistent with __eq__.
        return hash(str(self))

    def __eq__(self, other):
        if isinstance(other, Param):
            return self.parent == other.parent and self.name == other.name
        else:
            return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, two equal
        # params would still compare unequal with != there.
        return not self.__eq__(other)
0075
0076
class TypeConverters(object):
    """
    Factory methods for common type conversion functions for `Param.typeConverter`.

    Every converter returns the converted value or raises
    :py:class:`TypeError` when the input cannot be converted.

    .. versionadded:: 2.0.0
    """

    @staticmethod
    def _is_numeric(value):
        """
        Return True if ``value`` is a real, non-boolean Python or NumPy number.

        ``bool``, ``numpy.bool_`` and ``numpy.timedelta64`` (which NumPy
        models as an integer subtype) are rejected.
        """
        vtype = type(value)
        if vtype in [int, float, np.float64, np.int64] or vtype.__name__ == 'long':
            return True
        # Accept every NumPy integer/float width, e.g. int32 (the default
        # integer dtype on some platforms) and float32.
        return (isinstance(value, (np.integer, np.floating))
                and not isinstance(value, np.timedelta64))

    @staticmethod
    def _is_integer(value):
        """
        Return True if ``value`` is numeric and has no fractional part.
        """
        if not TypeConverters._is_numeric(value):
            return False
        vtype = type(value)
        if vtype is int or vtype.__name__ == 'long' or isinstance(value, np.integer):
            # Integral types are integers by definition; skip the float
            # round-trip, which overflows for very large Python ints.
            return True
        return float(value).is_integer()

    @staticmethod
    def _can_convert_to_list(value):
        """
        Return True if :py:func:`toList` accepts ``value``.
        """
        vtype = type(value)
        return vtype in [list, np.ndarray, tuple, xrange, array.array] or isinstance(value, Vector)

    @staticmethod
    def _can_convert_to_string(value):
        """
        Return True if :py:func:`toString` accepts the type of ``value``.
        """
        # numpy.str_ / numpy.bytes_ are the same types as the numpy.unicode_ /
        # numpy.string_ aliases, which NumPy 2.0 removed.
        return isinstance(value, basestring) or type(value) in [np.str_, np.bytes_]

    @staticmethod
    def identity(value):
        """
        Dummy converter that just returns value.
        """
        return value

    @staticmethod
    def toList(value):
        """
        Convert a value to a list, if possible.

        Lists are returned as-is; ndarrays, tuples, ranges and arrays are
        copied into a new list; Vectors are converted via ``toArray()``.
        """
        if type(value) == list:
            return value
        elif type(value) in [np.ndarray, tuple, xrange, array.array]:
            return list(value)
        elif isinstance(value, Vector):
            return list(value.toArray())
        else:
            raise TypeError("Could not convert %s to list" % value)

    @staticmethod
    def toListFloat(value):
        """
        Convert a value to list of floats, if possible.
        """
        if TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            if all(TypeConverters._is_numeric(v) for v in value):
                return [float(v) for v in value]
        raise TypeError("Could not convert %s to list of floats" % value)

    @staticmethod
    def toListListFloat(value):
        """
        Convert a value to list of list of floats, if possible.
        """
        if TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            return [TypeConverters.toListFloat(v) for v in value]
        raise TypeError("Could not convert %s to list of list of floats" % value)

    @staticmethod
    def toListInt(value):
        """
        Convert a value to list of ints, if possible.

        Each element must be numeric with no fractional part.
        """
        if TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            if all(TypeConverters._is_integer(v) for v in value):
                return [int(v) for v in value]
        raise TypeError("Could not convert %s to list of ints" % value)

    @staticmethod
    def toListString(value):
        """
        Convert a value to list of strings, if possible.
        """
        if TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            if all(TypeConverters._can_convert_to_string(v) for v in value):
                return [TypeConverters.toString(v) for v in value]
        raise TypeError("Could not convert %s to list of strings" % value)

    @staticmethod
    def toVector(value):
        """
        Convert a value to a MLlib Vector, if possible.

        Vectors are returned as-is; list-like numeric input becomes a
        :py:class:`DenseVector`.
        """
        if isinstance(value, Vector):
            return value
        elif TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            if all(TypeConverters._is_numeric(v) for v in value):
                return DenseVector(value)
        raise TypeError("Could not convert %s to vector" % value)

    @staticmethod
    def toMatrix(value):
        """
        Convert a value to a MLlib Matrix, if possible.
        """
        if isinstance(value, Matrix):
            return value
        raise TypeError("Could not convert %s to matrix" % value)

    @staticmethod
    def toFloat(value):
        """
        Convert a value to a float, if possible.
        """
        if TypeConverters._is_numeric(value):
            return float(value)
        else:
            raise TypeError("Could not convert %s to float" % value)

    @staticmethod
    def toInt(value):
        """
        Convert a value to an int, if possible.

        Floats are accepted only when they have no fractional part.
        """
        if TypeConverters._is_integer(value):
            return int(value)
        else:
            raise TypeError("Could not convert %s to int" % value)

    @staticmethod
    def toString(value):
        """
        Convert a value to a string, if possible.

        ``numpy.bytes_`` scalars are decoded as UTF-8; undecodable bytes
        raise :py:class:`TypeError`.
        """
        if isinstance(value, basestring):
            return value
        elif type(value) == np.bytes_:
            # str() of a bytes scalar on Python 3 yields "b'...'", not the
            # text it holds, so decode instead.
            try:
                return value.decode('utf-8')
            except UnicodeDecodeError:
                raise TypeError("Could not convert %s to string type" % type(value))
        elif type(value) == np.str_:
            return str(value)
        else:
            raise TypeError("Could not convert %s to string type" % type(value))

    @staticmethod
    def toBoolean(value):
        """
        Convert a value to a boolean, if possible.

        Only ``bool`` and ``numpy.bool_`` are accepted; truthy/falsy values
        of other types are rejected.
        """
        if type(value) == bool:
            return value
        elif isinstance(value, np.bool_):
            return bool(value)
        else:
            raise TypeError("Boolean Param requires value of type bool. Found %s." % type(value))
0232
0233
class Params(Identifiable):
    """
    Components that take parameters. This also provides an internal
    param map to store parameter values attached to the instance.

    .. versionadded:: 1.3.0
    """

    __metaclass__ = ABCMeta

    def __init__(self):
        super(Params, self).__init__()
        # Values explicitly supplied by the user, keyed by Param.
        self._paramMap = {}

        # Default values, keyed by Param.
        self._defaultParamMap = {}

        # Lazily computed cache backing the ``params`` property.
        self._params = None

        # Give this instance its own copies of the class-level params,
        # attached to this instance's uid.
        self._copy_params()

    def _copy_params(self):
        """
        Copy all params defined on the class to current object.

        Each class attribute that is a :py:class:`Param` is re-attached to
        this instance via ``Param._copy_new_parent``, which requires the
        class-level param to belong to the dummy ``"undefined"`` parent.
        """
        cls = type(self)
        src_name_attrs = [(x, getattr(cls, x)) for x in dir(cls)]
        src_params = list(filter(lambda nameAttr: isinstance(nameAttr[1], Param), src_name_attrs))
        for name, param in src_params:
            setattr(self, name, param._copy_new_parent(self))

    @property
    def params(self):
        """
        Returns all params ordered by name. The default implementation
        uses :py:func:`dir` to get all attributes of type
        :py:class:`Param`.
        """
        if self._params is None:
            # Skip "params" itself and every other property so that this scan
            # neither recurses into this getter nor runs other getters.
            self._params = list(filter(lambda attr: isinstance(attr, Param),
                                       [getattr(self, x) for x in dir(self) if x != "params" and
                                        not isinstance(getattr(type(self), x, None), property)]))
        return self._params

    def explainParam(self, param):
        """
        Explains a single param and returns its name, doc, and optional
        default value and user-supplied value in a string.

        :param param: param name or the param instance.
        :return: ``"<name>: <doc> (default: ..., current: ...)"``, or
                 ``"(undefined)"`` in place of the values when neither is set.
        """
        param = self._resolveParam(param)
        values = []
        if self.isDefined(param):
            if param in self._defaultParamMap:
                values.append("default: %s" % self._defaultParamMap[param])
            if param in self._paramMap:
                values.append("current: %s" % self._paramMap[param])
        else:
            values.append("undefined")
        valueStr = "(" + ", ".join(values) + ")"
        return "%s: %s %s" % (param.name, param.doc, valueStr)

    def explainParams(self):
        """
        Returns the documentation of all params with their optionally
        default values and user-supplied values.
        """
        return "\n".join([self.explainParam(param) for param in self.params])

    def getParam(self, paramName):
        """
        Gets a param by its name.

        :raises ValueError: if the attribute exists but is not a Param.
        """
        # NOTE(review): a missing attribute raises AttributeError from
        # getattr here, not ValueError; kept for backward compatibility.
        param = getattr(self, paramName)
        if isinstance(param, Param):
            return param
        else:
            raise ValueError("Cannot find param with name %s." % paramName)

    def isSet(self, param):
        """
        Checks whether a param is explicitly set by user.
        """
        param = self._resolveParam(param)
        return param in self._paramMap

    def hasDefault(self, param):
        """
        Checks whether a param has a default value.
        """
        param = self._resolveParam(param)
        return param in self._defaultParamMap

    def isDefined(self, param):
        """
        Checks whether a param is explicitly set by user or has
        a default value.
        """
        return self.isSet(param) or self.hasDefault(param)

    def hasParam(self, paramName):
        """
        Tests whether this instance contains a param with a given
        (string) name.

        :raises TypeError: if ``paramName`` is not a string.
        """
        if isinstance(paramName, basestring):
            p = getattr(self, paramName, None)
            return isinstance(p, Param)
        else:
            raise TypeError("hasParam(): paramName must be a string")

    def getOrDefault(self, param):
        """
        Gets the value of a param in the user-supplied param map or its
        default value. Raises an error if neither is set.

        :raises KeyError: if the param has neither a user-supplied nor a
                          default value.
        """
        param = self._resolveParam(param)
        if param in self._paramMap:
            return self._paramMap[param]
        else:
            return self._defaultParamMap[param]

    def extractParamMap(self, extra=None):
        """
        Extracts the embedded default param values and user-supplied
        values, and then merges them with extra values from input into
        a flat param map, where the latter value is used if there exist
        conflicts, i.e., with ordering: default param values <
        user-supplied values < extra.

        :param extra: extra param values
        :return: merged param map
        """
        if extra is None:
            extra = dict()
        paramMap = self._defaultParamMap.copy()
        paramMap.update(self._paramMap)
        paramMap.update(extra)
        return paramMap

    def copy(self, extra=None):
        """
        Creates a copy of this instance with the same uid and some
        extra params. The default implementation creates a
        shallow copy using :py:func:`copy.copy`, and then copies the
        embedded and extra parameters over and returns the copy.
        Subclasses should override this method if the default approach
        is not sufficient.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        that = copy.copy(self)
        # Fresh maps so the copy does not share mutable state with self.
        that._paramMap = {}
        that._defaultParamMap = {}
        return self._copyValues(that, extra)

    def set(self, param, value):
        """
        Sets a parameter in the embedded param map.

        :param param: a :py:class:`Param` owned by this instance.
        :param value: the value; passed through ``param.typeConverter``.
        :raises ValueError: if ``param`` does not belong to this instance,
                            or the converter raises ValueError.
        :raises TypeError: if the converter raises TypeError.
        """
        self._shouldOwn(param)
        try:
            value = param.typeConverter(value)
        except ValueError as e:
            raise ValueError('Invalid param value given for param "%s". %s' % (param.name, e))
        except TypeError as e:
            # TypeConverters report bad input with TypeError; add the param
            # name as _set does, keeping the exception type callers expect.
            raise TypeError('Invalid param value given for param "%s". %s' % (param.name, e))
        self._paramMap[param] = value

    def _shouldOwn(self, param):
        """
        Validates that the input param belongs to this Params instance.

        :raises ValueError: if the param's parent uid differs from this
                            instance's uid or no param with its name exists.
        """
        if not (self.uid == param.parent and self.hasParam(param.name)):
            raise ValueError("Param %r does not belong to %r." % (param, self))

    def _resolveParam(self, param):
        """
        Resolves a param and validates the ownership.

        :param param: param name or the param instance, which must
                      belong to this Params instance
        :return: resolved param instance
        :raises ValueError: if ``param`` is neither a Param nor a string, or
                            is a Param not owned by this instance.
        """
        if isinstance(param, Param):
            self._shouldOwn(param)
            return param
        elif isinstance(param, basestring):
            return self.getParam(param)
        else:
            raise ValueError("Cannot resolve %r as a param." % param)

    @staticmethod
    def _dummy():
        """
        Returns a dummy Params instance used as a placeholder to
        generate docs.
        """
        dummy = Params()
        # "undefined" is the parent uid Param._copy_new_parent accepts.
        dummy.uid = "undefined"
        return dummy

    def _set(self, **kwargs):
        """
        Sets user-supplied params.

        Keyword names are param attribute names; ``None`` values are
        skipped (the param is left untouched).

        :return: this instance, for chaining.
        :raises TypeError: if a type converter rejects a value.
        """
        for param, value in kwargs.items():
            p = getattr(self, param)
            if value is not None:
                try:
                    value = p.typeConverter(value)
                except TypeError as e:
                    raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e))
                self._paramMap[p] = value
        return self

    def clear(self, param):
        """
        Clears a param from the param map if it has been explicitly set.

        :param param: param name or the param instance.
        """
        # Resolve first: the map is keyed by Param objects, so deleting with
        # a param *name* would raise KeyError.
        param = self._resolveParam(param)
        if param in self._paramMap:
            del self._paramMap[param]

    def _setDefault(self, **kwargs):
        """
        Sets default params.

        Keyword names are param attribute names. ``None`` and
        :py:class:`JavaObject` values are stored without type conversion.

        :return: this instance, for chaining.
        :raises TypeError: if a type converter rejects a value.
        """
        for param, value in kwargs.items():
            p = getattr(self, param)
            if value is not None and not isinstance(value, JavaObject):
                try:
                    value = p.typeConverter(value)
                except TypeError as e:
                    raise TypeError('Invalid default param value given for param "%s". %s'
                                    % (p.name, e))
            self._defaultParamMap[p] = value
        return self

    def _copyValues(self, to, extra=None):
        """
        Copies param values from this instance to another instance for
        params shared by them.

        :param to: the target instance
        :param extra: extra params to be copied
        :return: the target instance with param values copied
        :raises TypeError: if ``extra`` is not a dict or has non-Param keys.
        """
        paramMap = self._paramMap.copy()
        if isinstance(extra, dict):
            for param, value in extra.items():
                if isinstance(param, Param):
                    paramMap[param] = value
                else:
                    raise TypeError("Expecting a valid instance of Param, but received: {}"
                                    .format(param))
        elif extra is not None:
            raise TypeError("Expecting a dict, but received an object of type {}."
                            .format(type(extra)))
        for param in self.params:
            # Copy default values, re-keyed to the target's own Param.
            if param in self._defaultParamMap and to.hasParam(param.name):
                to._defaultParamMap[to.getParam(param.name)] = self._defaultParamMap[param]
            # Copy explicitly set values (including extra) through _set.
            if param in paramMap and to.hasParam(param.name):
                to._set(**{param.name: paramMap[param]})
        return to

    def _resetUid(self, newUid):
        """
        Changes the uid of this instance. This updates both
        the stored uid and the parent uid of params and param maps.
        This is used by persistence (loading).
        :param newUid: new uid to use, which is converted to unicode
        :return: same instance, but with the uid and Param.parent values
                 updated, including within param maps
        """
        newUid = unicode(newUid)
        self.uid = newUid
        newDefaultParamMap = dict()
        newParamMap = dict()
        for param in self.params:
            # Param hashes depend on parent, so the maps are rebuilt with
            # re-parented keys rather than mutated in place.
            newParam = copy.copy(param)
            newParam.parent = newUid
            if param in self._defaultParamMap:
                newDefaultParamMap[newParam] = self._defaultParamMap[param]
            if param in self._paramMap:
                newParamMap[newParam] = self._paramMap[param]
            param.parent = newUid
        self._defaultParamMap = newDefaultParamMap
        self._paramMap = newParamMap
        return self