Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 import array
0018 import sys
# Python 2/3 compatibility shim: Python 3 has no ``basestring``, ``xrange``
# or ``unicode`` builtins, so alias them to their Python 3 equivalents for
# the rest of this module.
# The numeric ``version_info`` tuple is compared instead of the
# ``sys.version`` string because a lexicographic string comparison would
# misclassify a two-digit major version ("10.0" sorts before "3").
if sys.version_info[0] >= 3:
    basestring = str
    xrange = range
    unicode = str
0023 
0024 from abc import ABCMeta
0025 import copy
0026 import numpy as np
0027 
0028 from py4j.java_gateway import JavaObject
0029 
0030 from pyspark.ml.linalg import DenseVector, Vector, Matrix
0031 from pyspark.ml.util import Identifiable
0032 
0033 
# Names exported by a star-import of this module.
__all__ = ['Param', 'Params', 'TypeConverters']
0035 
0036 
class Param(object):
    """
    A param with self-contained documentation.

    A param is identified by the uid of its parent together with its own
    name; equality and hashing are both derived from those two values.

    .. versionadded:: 1.3.0
    """

    def __init__(self, parent, name, doc, typeConverter=None):
        """
        :param parent: the :py:class:`Identifiable` that owns this param;
                       only its ``uid`` is stored
        :param name: name of the param, coerced with ``str``
        :param doc: documentation of the param, coerced with ``str``
        :param typeConverter: callable used to convert values given for this
                              param; :py:meth:`TypeConverters.identity` is
                              used when it is None
        :raises TypeError: if ``parent`` is not an :py:class:`Identifiable`
        """
        if not isinstance(parent, Identifiable):
            raise TypeError("Parent must be an Identifiable but got type %s." % type(parent))
        self.parent = parent.uid
        self.name = str(name)
        self.doc = str(doc)
        self.typeConverter = TypeConverters.identity if typeConverter is None else typeConverter

    def _copy_new_parent(self, parent):
        """
        Copy the current param to a new parent, must be a dummy param.

        :param parent: object whose ``uid`` becomes the parent of the copy
        :return: a shallow copy of this param attached to ``parent``
        :raises ValueError: if this param's parent is not ``"undefined"``
        """
        if self.parent == "undefined":
            param = copy.copy(self)
            param.parent = parent.uid
            return param
        else:
            # Report the parent this param is currently bound to: that is the
            # non-dummy parent that makes the copy illegal.
            raise ValueError("Cannot copy from non-dummy parent %s." % self.parent)

    def __str__(self):
        return str(self.parent) + "__" + self.name

    def __repr__(self):
        return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc)

    def __hash__(self):
        # Hash the same "<parent>__<name>" string that __str__ produces so
        # that params comparing equal also hash equal.
        return hash(str(self))

    def __eq__(self, other):
        if isinstance(other, Param):
            return self.parent == other.parent and self.name == other.name
        else:
            return False

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; define it
        # explicitly so both operators agree on every supported interpreter.
        return not self.__eq__(other)
0075 
0076 
class TypeConverters(object):
    """
    Factory methods for common type conversion functions for `Param.typeConverter`.

    Apart from :py:meth:`identity`, each converter returns the converted
    value or raises :py:class:`TypeError` when the value is not of a
    supported type.

    .. versionadded:: 2.0.0
    """

    @staticmethod
    def _is_numeric(value):
        """
        Return True if the exact type of value is one of the accepted numeric
        types. The check is on the exact type, so ``bool`` is not accepted.
        """
        vtype = type(value)
        # ``long`` only exists on Python 2, hence the lookup by type name.
        return vtype in [int, float, np.float64, np.int64] or vtype.__name__ == 'long'

    @staticmethod
    def _is_integer(value):
        """
        Return True if value is numeric and has no fractional part.
        """
        if not TypeConverters._is_numeric(value):
            return False
        try:
            return float(value).is_integer()
        except OverflowError:
            # Only an arbitrary-precision integer can be too large for
            # float(); such a value is integral by construction.
            return True

    @staticmethod
    def _can_convert_to_list(value):
        """
        Return True if :py:meth:`toList` accepts value.
        """
        vtype = type(value)
        return vtype in [list, np.ndarray, tuple, xrange, array.array] or isinstance(value, Vector)

    @staticmethod
    def _can_convert_to_string(value):
        """
        Return True if :py:meth:`toString` accepts value.
        """
        if isinstance(value, basestring):
            return True
        # ``np.unicode_`` and ``np.string_`` were removed in NumPy 2.0; fall
        # back to the names that replace them (``np.str_`` / ``np.bytes_``).
        np_unicode = getattr(np, "unicode_", np.str_)
        np_bytes = getattr(np, "string_", np.bytes_)
        return type(value) in [np_unicode, np_bytes, np.str_]

    @staticmethod
    def identity(value):
        """
        Dummy converter that just returns value.
        """
        return value

    @staticmethod
    def toList(value):
        """
        Convert a value to a list, if possible.

        A value whose type is exactly ``list`` is returned as is (not copied).

        :raises TypeError: if value is not a list, numpy array, tuple, range,
                           ``array.array`` or Vector
        """
        if type(value) == list:
            return value
        elif type(value) in [np.ndarray, tuple, xrange, array.array]:
            return list(value)
        elif isinstance(value, Vector):
            return list(value.toArray())
        else:
            raise TypeError("Could not convert %s to list" % value)

    @staticmethod
    def toListFloat(value):
        """
        Convert a value to list of floats, if possible.

        :raises TypeError: if value is not list-like or holds a non-numeric
                           element
        """
        if TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            if all(TypeConverters._is_numeric(v) for v in value):
                return [float(v) for v in value]
        raise TypeError("Could not convert %s to list of floats" % value)

    @staticmethod
    def toListListFloat(value):
        """
        Convert a value to list of list of floats, if possible.

        :raises TypeError: if value is not list-like or any element cannot be
                           converted by :py:meth:`toListFloat`
        """
        if TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            return [TypeConverters.toListFloat(v) for v in value]
        raise TypeError("Could not convert %s to list of list of floats" % value)

    @staticmethod
    def toListInt(value):
        """
        Convert a value to list of ints, if possible.

        :raises TypeError: if value is not list-like or holds an element that
                           is not an integral number
        """
        if TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            if all(TypeConverters._is_integer(v) for v in value):
                return [int(v) for v in value]
        raise TypeError("Could not convert %s to list of ints" % value)

    @staticmethod
    def toListString(value):
        """
        Convert a value to list of strings, if possible.

        :raises TypeError: if value is not list-like or holds an element that
                           is not a string
        """
        if TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            if all(TypeConverters._can_convert_to_string(v) for v in value):
                return [TypeConverters.toString(v) for v in value]
        raise TypeError("Could not convert %s to list of strings" % value)

    @staticmethod
    def toVector(value):
        """
        Convert a value to a MLlib Vector, if possible.

        A Vector is returned unchanged; a list-like value of numbers is
        wrapped in a DenseVector.

        :raises TypeError: for any other value
        """
        if isinstance(value, Vector):
            return value
        elif TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            if all(TypeConverters._is_numeric(v) for v in value):
                return DenseVector(value)
        raise TypeError("Could not convert %s to vector" % value)

    @staticmethod
    def toMatrix(value):
        """
        Convert a value to a MLlib Matrix, if possible.

        Only an existing Matrix is accepted; it is returned unchanged.

        :raises TypeError: if value is not a Matrix
        """
        if isinstance(value, Matrix):
            return value
        raise TypeError("Could not convert %s to matrix" % value)

    @staticmethod
    def toFloat(value):
        """
        Convert a value to a float, if possible.

        :raises TypeError: if value is not numeric
        """
        if TypeConverters._is_numeric(value):
            return float(value)
        else:
            raise TypeError("Could not convert %s to float" % value)

    @staticmethod
    def toInt(value):
        """
        Convert a value to an int, if possible.

        :raises TypeError: if value is not numeric or has a fractional part
        """
        if TypeConverters._is_integer(value):
            return int(value)
        else:
            raise TypeError("Could not convert %s to int" % value)

    @staticmethod
    def toString(value):
        """
        Convert a value to a string, if possible.

        :raises TypeError: if value is neither a string nor a NumPy string
                           scalar
        """
        if isinstance(value, basestring):
            return value
        vtype = type(value)
        # ``np.string_`` and ``np.unicode_`` were removed in NumPy 2.0; fall
        # back to the names that replace them (``np.bytes_`` / ``np.str_``).
        if vtype in [getattr(np, "string_", np.bytes_), np.str_]:
            return str(value)
        elif vtype == getattr(np, "unicode_", np.str_):
            return unicode(value)
        else:
            raise TypeError("Could not convert %s to string type" % type(value))

    @staticmethod
    def toBoolean(value):
        """
        Convert a value to a boolean, if possible.

        Only values whose type is exactly ``bool`` are accepted.

        :raises TypeError: for any other value
        """
        if type(value) == bool:
            return value
        else:
            raise TypeError("Boolean Param requires value of type bool. Found %s." % type(value))
0232 
0233 
class Params(Identifiable):
    """
    Components that take parameters. This also provides an internal
    param map to store parameter values attached to the instance.

    .. versionadded:: 1.3.0
    """

    # Python 2 style metaclass declaration.
    __metaclass__ = ABCMeta

    def __init__(self):
        super(Params, self).__init__()
        #: internal param map for user-supplied values param map
        self._paramMap = {}

        #: internal param map for default values
        self._defaultParamMap = {}

        #: value returned by :py:func:`params`
        self._params = None

        # Copy the params from the class to the object
        self._copy_params()

    def _copy_params(self):
        """
        Copy all params defined on the class to current object.

        Each class-level :py:class:`Param` is re-parented to this instance
        and stored as an instance attribute of the same name.
        """
        cls = type(self)
        src_name_attrs = [(x, getattr(cls, x)) for x in dir(cls)]
        src_params = [nameAttr for nameAttr in src_name_attrs if isinstance(nameAttr[1], Param)]
        for name, param in src_params:
            setattr(self, name, param._copy_new_parent(self))

    @property
    def params(self):
        """
        Returns all params ordered by name. The default implementation
        uses :py:func:`dir` to get all attributes of type
        :py:class:`Param`.

        The list is computed on first access and cached in ``_params``.
        """
        if self._params is None:
            # Properties (including this one) are skipped so that building
            # the list does not evaluate them.
            self._params = list(filter(lambda attr: isinstance(attr, Param),
                                       [getattr(self, x) for x in dir(self) if x != "params" and
                                        not isinstance(getattr(type(self), x, None), property)]))
        return self._params

    def explainParam(self, param):
        """
        Explains a single param and returns its name, doc, and optional
        default value and user-supplied value in a string.

        :param param: param name or param instance belonging to this instance
        :return: string of the form ``"<name>: <doc> (<values>)"``
        """
        param = self._resolveParam(param)
        values = []
        if self.isDefined(param):
            if param in self._defaultParamMap:
                values.append("default: %s" % self._defaultParamMap[param])
            if param in self._paramMap:
                values.append("current: %s" % self._paramMap[param])
        else:
            values.append("undefined")
        valueStr = "(" + ", ".join(values) + ")"
        return "%s: %s %s" % (param.name, param.doc, valueStr)

    def explainParams(self):
        """
        Returns the documentation of all params with their optionally
        default values and user-supplied values.

        :return: one :py:meth:`explainParam` line per param, joined by newlines
        """
        return "\n".join([self.explainParam(param) for param in self.params])

    def getParam(self, paramName):
        """
        Gets a param by its name.

        :raises ValueError: if the attribute named ``paramName`` exists but
                            is not a :py:class:`Param`
        :raises AttributeError: if there is no attribute named ``paramName``
        """
        param = getattr(self, paramName)
        if isinstance(param, Param):
            return param
        else:
            raise ValueError("Cannot find param with name %s." % paramName)

    def isSet(self, param):
        """
        Checks whether a param is explicitly set by user.
        """
        param = self._resolveParam(param)
        return param in self._paramMap

    def hasDefault(self, param):
        """
        Checks whether a param has a default value.
        """
        param = self._resolveParam(param)
        return param in self._defaultParamMap

    def isDefined(self, param):
        """
        Checks whether a param is explicitly set by user or has
        a default value.
        """
        return self.isSet(param) or self.hasDefault(param)

    def hasParam(self, paramName):
        """
        Tests whether this instance contains a param with a given
        (string) name.

        :raises TypeError: if ``paramName`` is not a string
        """
        if isinstance(paramName, basestring):
            p = getattr(self, paramName, None)
            return isinstance(p, Param)
        else:
            raise TypeError("hasParam(): paramName must be a string")

    def getOrDefault(self, param):
        """
        Gets the value of a param in the user-supplied param map or its
        default value. Raises an error if neither is set.

        :raises KeyError: if the param has neither a user-supplied nor a
                          default value
        """
        param = self._resolveParam(param)
        if param in self._paramMap:
            return self._paramMap[param]
        else:
            return self._defaultParamMap[param]

    def extractParamMap(self, extra=None):
        """
        Extracts the embedded default param values and user-supplied
        values, and then merges them with extra values from input into
        a flat param map, where the latter value is used if there exist
        conflicts, i.e., with ordering: default param values <
        user-supplied values < extra.

        :param extra: extra param values
        :return: merged param map
        """
        if extra is None:
            extra = dict()
        paramMap = self._defaultParamMap.copy()
        paramMap.update(self._paramMap)
        paramMap.update(extra)
        return paramMap

    def copy(self, extra=None):
        """
        Creates a copy of this instance with the same uid and some
        extra params. The default implementation creates a
        shallow copy using :py:func:`copy.copy`, and then copies the
        embedded and extra parameters over and returns the copy.
        Subclasses should override this method if the default approach
        is not sufficient.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        that = copy.copy(self)
        # The shallow copy shares the map objects with this instance; give
        # the copy fresh maps before the values are copied over.
        that._paramMap = {}
        that._defaultParamMap = {}
        return self._copyValues(that, extra)

    def set(self, param, value):
        """
        Sets a parameter in the embedded param map.

        :param param: param instance belonging to this instance
        :param value: value to store, after conversion by the param's
                      type converter
        :raises ValueError: if the param does not belong to this instance, or
                            if the type converter raises ValueError
        :raises TypeError: if the type converter raises TypeError
        """
        self._shouldOwn(param)
        try:
            value = param.typeConverter(value)
        except ValueError as e:
            raise ValueError('Invalid param value given for param "%s". %s' % (param.name, e))
        except TypeError as e:
            # The converters in TypeConverters raise TypeError; add the same
            # param context to it that _set() adds, keeping the exception type.
            raise TypeError('Invalid param value given for param "%s". %s' % (param.name, e))
        self._paramMap[param] = value

    def _shouldOwn(self, param):
        """
        Validates that the input param belongs to this Params instance.

        :raises ValueError: if the param's parent is not this instance's uid
                            or this instance has no param of that name
        """
        if not (self.uid == param.parent and self.hasParam(param.name)):
            raise ValueError("Param %r does not belong to %r." % (param, self))

    def _resolveParam(self, param):
        """
        Resolves a param and validates the ownership.

        :param param: param name or the param instance, which must
                      belong to this Params instance
        :return: resolved param instance
        :raises ValueError: if ``param`` is neither a param nor a string
        """
        if isinstance(param, Param):
            self._shouldOwn(param)
            return param
        elif isinstance(param, basestring):
            return self.getParam(param)
        else:
            raise ValueError("Cannot resolve %r as a param." % param)

    @staticmethod
    def _dummy():
        """
        Returns a dummy Params instance used as a placeholder to
        generate docs.

        Its uid is ``"undefined"``, the parent value that
        :py:meth:`Param._copy_new_parent` requires of a param before it may
        be copied to a new parent.
        """
        dummy = Params()
        dummy.uid = "undefined"
        return dummy

    def _set(self, **kwargs):
        """
        Sets user-supplied params.

        :param kwargs: param names mapped to values; a value of None is
                       stored without being passed to the type converter
        :return: this instance
        :raises TypeError: if a type converter rejects a value
        """
        for param, value in kwargs.items():
            p = getattr(self, param)
            if value is not None:
                try:
                    value = p.typeConverter(value)
                except TypeError as e:
                    raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e))
            self._paramMap[p] = value
        return self

    def clear(self, param):
        """
        Clears a param from the param map if it has been explicitly set.

        :param param: param name or param instance belonging to this instance
        """
        # Resolve first so that a param given by name is removed under its
        # Param key; the map is keyed by Param instances, not by strings.
        param = self._resolveParam(param)
        if param in self._paramMap:
            del self._paramMap[param]

    def _setDefault(self, **kwargs):
        """
        Sets default params.

        :param kwargs: param names mapped to default values; None and
                       JavaObject values are stored without being passed to
                       the type converter
        :return: this instance
        :raises TypeError: if a type converter rejects a value
        """
        for param, value in kwargs.items():
            p = getattr(self, param)
            if value is not None and not isinstance(value, JavaObject):
                try:
                    value = p.typeConverter(value)
                except TypeError as e:
                    raise TypeError('Invalid default param value given for param "%s". %s'
                                    % (p.name, e))
            self._defaultParamMap[p] = value
        return self

    def _copyValues(self, to, extra=None):
        """
        Copies param values from this instance to another instance for
        params shared by them.

        :param to: the target instance
        :param extra: extra params to be copied
        :return: the target instance with param values copied
        :raises TypeError: if ``extra`` is neither None nor a dict, or if one
                           of its keys is not a :py:class:`Param`
        """
        paramMap = self._paramMap.copy()
        if isinstance(extra, dict):
            for param, value in extra.items():
                if isinstance(param, Param):
                    paramMap[param] = value
                else:
                    raise TypeError("Expecting a valid instance of Param, but received: {}"
                                    .format(param))
        elif extra is not None:
            raise TypeError("Expecting a dict, but received an object of type {}."
                            .format(type(extra)))
        for param in self.params:
            # copy default params
            if param in self._defaultParamMap and to.hasParam(param.name):
                to._defaultParamMap[to.getParam(param.name)] = self._defaultParamMap[param]
            # copy explicitly set params
            if param in paramMap and to.hasParam(param.name):
                to._set(**{param.name: paramMap[param]})
        return to

    def _resetUid(self, newUid):
        """
        Changes the uid of this instance. This updates both
        the stored uid and the parent uid of params and param maps.
        This is used by persistence (loading).

        :param newUid: new uid to use, which is converted to unicode
        :return: same instance, but with the uid and Param.parent values
                 updated, including within param maps
        """
        newUid = unicode(newUid)
        self.uid = newUid
        newDefaultParamMap = dict()
        newParamMap = dict()
        for param in self.params:
            # A Param hashes on its parent uid, so the maps are rebuilt with
            # re-parented copies as keys. The lookups below must run before
            # ``param.parent`` is changed, while the old keys still match.
            newParam = copy.copy(param)
            newParam.parent = newUid
            if param in self._defaultParamMap:
                newDefaultParamMap[newParam] = self._defaultParamMap[param]
            if param in self._paramMap:
                newParamMap[newParam] = self._paramMap[param]
            param.parent = newUid
        self._defaultParamMap = newDefaultParamMap
        self._paramMap = newParamMap
        return self