Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 from abc import ABCMeta, abstractmethod
0019 import sys
0020 if sys.version >= '3':
0021     xrange = range
0022 
0023 from pyspark import since
0024 from pyspark import SparkContext
0025 from pyspark.sql import DataFrame
0026 from pyspark.ml import Estimator, Transformer, Model
0027 from pyspark.ml.param import Params
0028 from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol
0029 from pyspark.ml.util import _jvm
0030 from pyspark.ml.common import inherit_doc, _java2py, _py2java
0031 
0032 
class JavaWrapper(object):
    """
    Wrapper around a companion object that lives on the JVM side.
    """

    def __init__(self, java_obj=None):
        super(JavaWrapper, self).__init__()
        # Py4J reference to the companion JVM object (may be attached later).
        self._java_obj = java_obj

    def __del__(self):
        # Drop the gateway's reference to the JVM object so it can be
        # garbage-collected on the Java side.  Only possible while an
        # active SparkContext (and hence a gateway) still exists.
        sc = SparkContext._active_spark_context
        if sc and self._java_obj is not None:
            sc._gateway.detach(self._java_obj)

    @classmethod
    def _create_from_java_class(cls, java_class, *args):
        """
        Instantiate the named Java class with the given arguments and
        wrap the resulting JVM object in an instance of ``cls``.
        """
        return cls(JavaWrapper._new_java_obj(java_class, *args))

    def _call_java(self, name, *args):
        # Invoke method `name` on the wrapped JVM object, converting the
        # arguments to Java values and the result back to Python.
        sc = SparkContext._active_spark_context
        method = getattr(self._java_obj, name)
        converted = [_py2java(sc, a) for a in args]
        return _java2py(sc, method(*converted))

    @staticmethod
    def _new_java_obj(java_class, *args):
        """
        Construct and return a new JVM object from its fully qualified
        class name, converting each argument with ``_py2java``.
        """
        sc = SparkContext._active_spark_context
        target = _jvm()
        # Walk the dotted path (e.g. org.apache.spark.ml...) through the
        # Py4J JVM view to reach the class, then call its constructor.
        for part in java_class.split("."):
            target = getattr(target, part)
        return target(*[_py2java(sc, a) for a in args])

    @staticmethod
    def _new_java_array(pylist, java_class):
        """
        Build a Java array of the given ``java_class`` element type.
        Useful for calling a method expecting a Scala Array from Python
        via Py4J.  A 2D Python list yields a 2D Java array: a square,
        non-jagged array wide enough for the longest inner list, with
        the unused trailing slots left as null.

        :param pylist:
          Python list to convert to a Java Array.
        :param java_class:
          Java class giving the array's element type; should come from
          sc._gateway.jvm.* (sc being a valid Spark Context).
        :return:
          Java Array holding the converted elements of pylist.

        Example primitive Java classes:
          - basestring -> sc._gateway.jvm.java.lang.String
          - int -> sc._gateway.jvm.java.lang.Integer
          - float -> sc._gateway.jvm.java.lang.Double
          - bool -> sc._gateway.jvm.java.lang.Boolean
        """
        sc = SparkContext._active_spark_context
        if pylist and isinstance(pylist[0], list):
            # 2D case: size the inner dimension to the widest row so the
            # resulting Java array is rectangular (non-jagged).
            width = max(len(row) for row in pylist)
            java_array = sc._gateway.new_array(java_class, len(pylist), width)
            for i, row in enumerate(pylist):
                for j, elem in enumerate(row):
                    java_array[i][j] = elem
        else:
            java_array = sc._gateway.new_array(java_class, len(pylist))
            for i, elem in enumerate(pylist):
                java_array[i] = elem
        return java_array
0112 
0113 
@inherit_doc
class JavaParams(JavaWrapper, Params):
    """
    Utility class to help create wrapper classes from Java/Scala
    implementations of pipeline components.
    """
    #: The param values in the Java object should be
    #: synced with the Python wrapper in fit/transform/evaluate/copy.

    __metaclass__ = ABCMeta

    def _make_java_param_pair(self, param, value):
        """
        Makes a Java param pair.

        :param param: Python Param (or its name) owned by this instance.
        :param value: Python value to bind; converted via _py2java.
        :return: Java ParamPair built by calling ``w(value)`` on the
                 corresponding Java Param.
        """
        sc = SparkContext._active_spark_context
        # Resolve a string name to the actual Param object if needed.
        param = self._resolveParam(param)
        java_param = self._java_obj.getParam(param.name)
        java_value = _py2java(sc, value)
        return java_param.w(java_value)

    def _transfer_params_to_java(self):
        """
        Transforms the embedded params to the companion Java object.

        Explicitly-set params are pushed with ``set`` (one call per
        param); default values are collected and pushed in a single
        ``setDefault`` call so they remain defaults (not user-set) on
        the Java side.
        """
        pair_defaults = []
        for param in self.params:
            if self.isSet(param):
                pair = self._make_java_param_pair(param, self._paramMap[param])
                self._java_obj.set(pair)
            if self.hasDefault(param):
                pair = self._make_java_param_pair(param, self._defaultParamMap[param])
                pair_defaults.append(pair)
        if len(pair_defaults) > 0:
            sc = SparkContext._active_spark_context
            # setDefault takes a Scala varargs Seq of ParamPairs.
            pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults)
            self._java_obj.setDefault(pair_defaults_seq)

    def _transfer_param_map_to_java(self, pyParamMap):
        """
        Transforms a Python ParamMap into a Java ParamMap.

        :param pyParamMap: dict mapping this instance's Params to values.
        :return: Java org.apache.spark.ml.param.ParamMap with the same pairs.
        """
        paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
        for param in self.params:
            if param in pyParamMap:
                pair = self._make_java_param_pair(param, pyParamMap[param])
                paramMap.put([pair])
        return paramMap

    def _create_params_from_java(self):
        """
        SPARK-10931: Temporary fix to create params that are defined in the Java obj but not here
        """
        java_params = list(self._java_obj.params())
        from pyspark.ml.param import Param
        for java_param in java_params:
            java_param_name = java_param.name()
            if not hasattr(self, java_param_name):
                param = Param(self, java_param_name, java_param.doc())
                # Mark the Param so it can be distinguished from ones
                # declared natively on the Python class.
                setattr(param, "created_from_java_param", True)
                setattr(self, java_param_name, param)
                self._params = None  # need to reset so self.params will discover new params

    def _transfer_params_from_java(self):
        """
        Transforms the embedded params from the companion Java object.
        """
        sc = SparkContext._active_spark_context
        for param in self.params:
            if self._java_obj.hasParam(param.name):
                java_param = self._java_obj.getParam(param.name)
                # SPARK-14931: Only check set params back to avoid default params mismatch.
                if self._java_obj.isSet(java_param):
                    value = _java2py(sc, self._java_obj.getOrDefault(java_param))
                    self._set(**{param.name: value})
                # SPARK-10931: Temporary fix for params that have a default in Java
                if self._java_obj.hasDefault(java_param) and not self.isDefined(param):
                    # getDefault returns a Scala Option; .get() unwraps it.
                    value = _java2py(sc, self._java_obj.getDefault(java_param)).get()
                    self._setDefault(**{param.name: value})

    def _transfer_param_map_from_java(self, javaParamMap):
        """
        Transforms a Java ParamMap into a Python ParamMap.

        Pairs whose param has no counterpart on this Python instance are
        silently skipped.
        """
        sc = SparkContext._active_spark_context
        paramMap = dict()
        for pair in javaParamMap.toList():
            param = pair.param()
            if self.hasParam(str(param.name())):
                paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
        return paramMap

    @staticmethod
    def _empty_java_param_map():
        """
        Returns an empty Java ParamMap reference.
        """
        return _jvm().org.apache.spark.ml.param.ParamMap()

    def _to_java(self):
        """
        Transfer this instance's Params to the wrapped Java object, and return the Java object.
        Used for ML persistence.

        Meta-algorithms such as Pipeline should override this method.

        :return: Java object equivalent to this instance.
        """
        self._transfer_params_to_java()
        return self._java_obj

    @staticmethod
    def _from_java(java_stage):
        """
        Given a Java object, create and return a Python wrapper of it.
        Used for ML persistence.

        Meta-algorithms such as Pipeline should override this method as a classmethod.

        :raises NotImplementedError: if no Python counterpart exists for
            the Java stage's class.
        """
        def __get_class(clazz):
            """
            Loads Python class from its name.
            """
            parts = clazz.split('.')
            module = ".".join(parts[:-1])
            m = __import__(module)
            for comp in parts[1:]:
                m = getattr(m, comp)
            return m
        # Map the Java class name onto the mirrored pyspark namespace,
        # e.g. org.apache.spark.ml.X -> pyspark.ml.X.
        stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
        # Generate a default new instance from the stage_name class.
        py_type = __get_class(stage_name)
        if issubclass(py_type, JavaParams):
            # Load information from java_stage to the instance.
            py_stage = py_type()
            py_stage._java_obj = java_stage

            # SPARK-10931: Temporary fix so that persisted models would own params from Estimator
            if issubclass(py_type, JavaModel):
                py_stage._create_params_from_java()

            py_stage._resetUid(java_stage.uid())
            py_stage._transfer_params_from_java()
        elif hasattr(py_type, "_from_java"):
            # Delegate to the Python class's own loader (e.g. Pipeline).
            py_stage = py_type._from_java(java_stage)
        else:
            raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r"
                                      % stage_name)
        return py_stage

    def copy(self, extra=None):
        """
        Creates a copy of this instance with the same uid and some
        extra params. This implementation first calls Params.copy and
        then make a copy of the companion Java pipeline component with
        extra params. So both the Python wrapper and the Java pipeline
        component get copied.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        that = super(JavaParams, self).copy(extra)
        if self._java_obj is not None:
            that._java_obj = self._java_obj.copy(self._empty_java_param_map())
            # Keep the copied Java object in sync with the copied params.
            that._transfer_params_to_java()
        return that

    def clear(self, param):
        """
        Clears a param from the param map if it has been explicitly set.

        Clears on both the Python side and the companion Java object.
        """
        super(JavaParams, self).clear(param)
        java_param = self._java_obj.getParam(param.name)
        self._java_obj.clear(java_param)
0290 
0291 
@inherit_doc
class JavaEstimator(JavaParams, Estimator):
    """
    Base class for :py:class:`Estimator`s backed by Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def _create_model(self, java_model):
        """
        Builds the Python model wrapper for the given Java model reference.
        """
        raise NotImplementedError()

    def _fit_java(self, dataset):
        """
        Fits the companion Java estimator to the input dataset.

        :param dataset: input dataset, an instance of
                        :py:class:`pyspark.sql.DataFrame`
        :return: fitted Java model
        """
        # Sync Python-side params to the JVM object before fitting.
        self._transfer_params_to_java()
        return self._java_obj.fit(dataset._jdf)

    def _fit(self, dataset):
        # Fit on the JVM, wrap the resulting Java model, then copy this
        # estimator's param values onto the new Python model wrapper.
        fitted = self._fit_java(dataset)
        return self._copyValues(self._create_model(fitted))
0324 
0325 
@inherit_doc
class JavaTransformer(JavaParams, Transformer):
    """
    Base class for :py:class:`Transformer`s that delegate to Java/Scala
    implementations. Subclasses must make the transformer Java object
    available as ``_java_obj``.
    """

    __metaclass__ = ABCMeta

    def _transform(self, dataset):
        # Push current Python-side params to the JVM, then wrap the
        # transformed Java DataFrame back into a Python DataFrame.
        self._transfer_params_to_java()
        transformed = self._java_obj.transform(dataset._jdf)
        return DataFrame(transformed, dataset.sql_ctx)
0339 
0340 
@inherit_doc
class JavaModel(JavaTransformer, Model):
    """
    Base class for :py:class:`Model`s that wrap Java/Scala
    implementations. Subclasses should inherit this class before
    param mix-ins, because this sets the UID from the Java model.
    """

    __metaclass__ = ABCMeta

    def __init__(self, java_model=None):
        """
        Initialize this instance with a Java model object.
        Subclasses should call this constructor, initialize params,
        and then call _transfer_params_from_java.

        ``java_model`` may be omitted and assigned later; that path is
        only used by :py:class:`JavaMLReader` when loading models.  It
        is a bit of a hack, but the easiest option: a proper fix would
        require MLReader (in pyspark.ml.util) to depend on these
        wrappers, while these wrappers already depend on
        pyspark.ml.util (both directly and via other ML classes).
        """
        super(JavaModel, self).__init__(java_model)
        if java_model is None:
            return

        # SPARK-10931: This is a temporary fix to allow models to own params
        # from estimators. Eventually, these params should be in models through
        # using common base classes between estimators and models.
        self._create_params_from_java()

        self._resetUid(java_model.uid())

    def __repr__(self):
        # Delegate to the Java model's toString for a meaningful repr.
        return self._call_java("toString")
0377 
0378 
@inherit_doc
class _JavaPredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
    """
    Shared params mixin for :py:class:`JavaPredictor` and
    :py:class:`JavaPredictorModel`.

    .. versionadded:: 3.0.0
    """
    pass
0387 
0388 
@inherit_doc
class JavaPredictor(JavaEstimator, _JavaPredictorParams):
    """
    (Private) Java Estimator for prediction tasks (regression and classification).
    """

    @since("3.0.0")
    def setLabelCol(self, value):
        """
        Sets the value of :py:attr:`labelCol`.
        """
        kwargs = {"labelCol": value}
        return self._set(**kwargs)

    @since("3.0.0")
    def setFeaturesCol(self, value):
        """
        Sets the value of :py:attr:`featuresCol`.
        """
        kwargs = {"featuresCol": value}
        return self._set(**kwargs)

    @since("3.0.0")
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        kwargs = {"predictionCol": value}
        return self._set(**kwargs)
0415 
0416 
@inherit_doc
class JavaPredictionModel(JavaModel, _JavaPredictorParams):
    """
    (Private) Java Model for prediction tasks (regression and classification).
    """

    @since("3.0.0")
    def setFeaturesCol(self, value):
        """
        Sets the value of :py:attr:`featuresCol`.
        """
        kwargs = {"featuresCol": value}
        return self._set(**kwargs)

    @since("3.0.0")
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        kwargs = {"predictionCol": value}
        return self._set(**kwargs)

    @property
    @since("2.1.0")
    def numFeatures(self):
        """
        Returns the number of features the model was trained on. If unknown, returns -1
        """
        return self._call_java("numFeatures")

    @since("3.0.0")
    def predict(self, value):
        """
        Predict label for the given features.
        """
        return self._call_java("predict", value)