0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 from abc import ABCMeta, abstractmethod
0019 import sys
# Python 3 removed ``xrange``; alias it to ``range`` so the 2/3-compatible
# loops below work on both interpreters.
# NOTE: compare ``sys.version_info``, not the version string -- lexicographic
# string comparison breaks for two-digit majors ("10..." < "3...").
if sys.version_info[0] >= 3:
    xrange = range
0022
0023 from pyspark import since
0024 from pyspark import SparkContext
0025 from pyspark.sql import DataFrame
0026 from pyspark.ml import Estimator, Transformer, Model
0027 from pyspark.ml.param import Params
0028 from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol
0029 from pyspark.ml.util import _jvm
0030 from pyspark.ml.common import inherit_doc, _java2py, _py2java
0031
0032
class JavaWrapper(object):
    """
    Wrapper class for a Java companion object
    """
    def __init__(self, java_obj=None):
        super(JavaWrapper, self).__init__()
        # Py4J handle to the companion JVM object (may be attached later).
        self._java_obj = java_obj

    def __del__(self):
        # Release the Py4J reference on the JVM side, but only while an
        # active context (and therefore a live gateway) still exists.
        active_sc = SparkContext._active_spark_context
        if active_sc and self._java_obj is not None:
            active_sc._gateway.detach(self._java_obj)

    @classmethod
    def _create_from_java_class(cls, java_class, *args):
        """
        Construct this object from given Java classname and arguments
        """
        return cls(JavaWrapper._new_java_obj(java_class, *args))

    def _call_java(self, name, *args):
        # Invoke a method on the wrapped JVM object, converting arguments
        # to Java and the result back to Python.
        sc = SparkContext._active_spark_context
        method = getattr(self._java_obj, name)
        converted = [_py2java(sc, a) for a in args]
        return _java2py(sc, method(*converted))

    @staticmethod
    def _new_java_obj(java_class, *args):
        """
        Returns a new Java object.
        """
        sc = SparkContext._active_spark_context
        # Walk the dotted class name down from the JVM root view.
        target = _jvm()
        for part in java_class.split("."):
            target = getattr(target, part)
        return target(*[_py2java(sc, a) for a in args])

    @staticmethod
    def _new_java_array(pylist, java_class):
        """
        Create a Java array of given java_class type. Useful for
        calling a method with a Scala Array from Python with Py4J.
        If the param pylist is a 2D array, then a 2D java array will be returned.
        The returned 2D java array is a square, non-jagged 2D array that is big
        enough for all elements. The empty slots in the inner Java arrays will
        be filled with null to make the non-jagged 2D array.

        :param pylist:
          Python list to convert to a Java Array.
        :param java_class:
          Java class to specify the type of Array. Should be in the
          form of sc._gateway.jvm.* (sc is a valid Spark Context).
        :return:
          Java Array of converted pylist.

        Example primitive Java classes:
          - basestring -> sc._gateway.jvm.java.lang.String
          - int -> sc._gateway.jvm.java.lang.Integer
          - float -> sc._gateway.jvm.java.lang.Double
          - bool -> sc._gateway.jvm.java.lang.Boolean
        """
        sc = SparkContext._active_spark_context
        if len(pylist) > 0 and isinstance(pylist[0], list):
            # 2D case: size the inner dimension to the widest row so the
            # Java array is rectangular; unfilled trailing slots stay null.
            width = max(len(row) for row in pylist)
            java_array = sc._gateway.new_array(java_class, len(pylist), width)
            for i, row in enumerate(pylist):
                for j, item in enumerate(row):
                    java_array[i][j] = item
        else:
            java_array = sc._gateway.new_array(java_class, len(pylist))
            for i, item in enumerate(pylist):
                java_array[i] = item
        return java_array
0112
0113
@inherit_doc
class JavaParams(JavaWrapper, Params):
    """
    Utility class to help create wrapper classes from Java/Scala
    implementations of pipeline components.
    """

    #: Pairs a Python Params instance with a companion JVM object
    #: (``self._java_obj``) and keeps the two param maps in sync.
    __metaclass__ = ABCMeta

    def _make_java_param_pair(self, param, value):
        """
        Makes a Java param pair.
        """
        sc = SparkContext._active_spark_context
        # Accept either a Param object or a param name string.
        param = self._resolveParam(param)
        java_param = self._java_obj.getParam(param.name)
        # Convert the Python value to its Java form and bind it to the
        # Java Param via Param.w(), yielding a Java ParamPair.
        java_value = _py2java(sc, value)
        return java_param.w(java_value)

    def _transfer_params_to_java(self):
        """
        Transforms the embedded params to the companion Java object.
        """
        pair_defaults = []
        for param in self.params:
            # Explicitly-set values are pushed to the JVM one at a time...
            if self.isSet(param):
                pair = self._make_java_param_pair(param, self._paramMap[param])
                self._java_obj.set(pair)
            # ...while default values are collected and pushed in one batch.
            if self.hasDefault(param):
                pair = self._make_java_param_pair(param, self._defaultParamMap[param])
                pair_defaults.append(pair)
        if len(pair_defaults) > 0:
            sc = SparkContext._active_spark_context
            pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults)
            self._java_obj.setDefault(pair_defaults_seq)

    def _transfer_param_map_to_java(self, pyParamMap):
        """
        Transforms a Python ParamMap into a Java ParamMap.
        """
        paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
        for param in self.params:
            if param in pyParamMap:
                pair = self._make_java_param_pair(param, pyParamMap[param])
                paramMap.put([pair])
        return paramMap

    def _create_params_from_java(self):
        """
        SPARK-10931: Temporary fix to create params that are defined in the Java obj but not here
        """
        java_params = list(self._java_obj.params())
        from pyspark.ml.param import Param
        for java_param in java_params:
            java_param_name = java_param.name()
            if not hasattr(self, java_param_name):
                param = Param(self, java_param_name, java_param.doc())
                # Tag the param so callers can tell it was mirrored from Java.
                setattr(param, "created_from_java_param", True)
                setattr(self, java_param_name, param)
        # Invalidate the cached param list so it is rebuilt on next access.
        self._params = None

    def _transfer_params_from_java(self):
        """
        Transforms the embedded params from the companion Java object.
        """
        sc = SparkContext._active_spark_context
        for param in self.params:
            if self._java_obj.hasParam(param.name):
                java_param = self._java_obj.getParam(param.name)
                # Copy back only explicitly-set Java values, to avoid
                # promoting Java defaults into the Python param map.
                if self._java_obj.isSet(java_param):
                    value = _java2py(sc, self._java_obj.getOrDefault(java_param))
                    self._set(**{param.name: value})
                # Mirror the Java-side default when Python defines no value.
                # getDefault returns a Scala Option, hence the .get().
                if self._java_obj.hasDefault(java_param) and not self.isDefined(param):
                    value = _java2py(sc, self._java_obj.getDefault(java_param)).get()
                    self._setDefault(**{param.name: value})

    def _transfer_param_map_from_java(self, javaParamMap):
        """
        Transforms a Java ParamMap into a Python ParamMap.
        """
        sc = SparkContext._active_spark_context
        paramMap = dict()
        for pair in javaParamMap.toList():
            param = pair.param()
            # Skip params this Python wrapper does not know about.
            if self.hasParam(str(param.name())):
                paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
        return paramMap

    @staticmethod
    def _empty_java_param_map():
        """
        Returns an empty Java ParamMap reference.
        """
        return _jvm().org.apache.spark.ml.param.ParamMap()

    def _to_java(self):
        """
        Transfer this instance's Params to the wrapped Java object, and return the Java object.
        Used for ML persistence.

        Meta-algorithms such as Pipeline should override this method.

        :return: Java object equivalent to this instance.
        """
        self._transfer_params_to_java()
        return self._java_obj

    @staticmethod
    def _from_java(java_stage):
        """
        Given a Java object, create and return a Python wrapper of it.
        Used for ML persistence.

        Meta-algorithms such as Pipeline should override this method as a classmethod.
        """
        def __get_class(clazz):
            """
            Loads Python class from its name.
            """
            parts = clazz.split('.')
            module = ".".join(parts[:-1])
            m = __import__(module)
            for comp in parts[1:]:
                m = getattr(m, comp)
            return m
        # Map e.g. "org.apache.spark.ml.feature.X" -> "pyspark.ml.feature.X".
        stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
        # Generate a default new instance from the stage_name class.
        py_type = __get_class(stage_name)
        if issubclass(py_type, JavaParams):
            # Load information from java_stage into the new instance.
            py_stage = py_type()
            py_stage._java_obj = java_stage

            # Mirror Java-only params (see _create_params_from_java) for models.
            if issubclass(py_type, JavaModel):
                py_stage._create_params_from_java()

            py_stage._resetUid(java_stage.uid())
            py_stage._transfer_params_from_java()
        elif hasattr(py_type, "_from_java"):
            # Delegate to the wrapper's own loader (e.g. meta-algorithms).
            py_stage = py_type._from_java(java_stage)
        else:
            raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r"
                                      % stage_name)
        return py_stage

    def copy(self, extra=None):
        """
        Creates a copy of this instance with the same uid and some
        extra params. This implementation first calls Params.copy and
        then make a copy of the companion Java pipeline component with
        extra params. So both the Python wrapper and the Java pipeline
        component get copied.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        that = super(JavaParams, self).copy(extra)
        if self._java_obj is not None:
            # Also clone the JVM object and re-sync params onto the clone.
            that._java_obj = self._java_obj.copy(self._empty_java_param_map())
            that._transfer_params_to_java()
        return that

    def clear(self, param):
        """
        Clears a param from the param map if it has been explicitly set.
        """
        super(JavaParams, self).clear(param)
        # Keep the companion Java object consistent by clearing it there too.
        java_param = self._java_obj.getParam(param.name)
        self._java_obj.clear(java_param)
0290
0291
@inherit_doc
class JavaEstimator(JavaParams, Estimator):
    """
    Base class for :py:class:`Estimator`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def _create_model(self, java_model):
        """
        Creates a model from the input Java model reference.
        """
        raise NotImplementedError()

    def _fit_java(self, dataset):
        """
        Fits a Java model to the input dataset.

        :param dataset: input dataset, which is an instance of
                        :py:class:`pyspark.sql.DataFrame`
        :param params: additional params (overwriting embedded values)
        :return: fitted Java model
        """
        # Sync Python-side params to the JVM before fitting.
        self._transfer_params_to_java()
        return self._java_obj.fit(dataset._jdf)

    def _fit(self, dataset):
        # Fit on the JVM, wrap the resulting Java model, then copy this
        # estimator's param values onto the wrapper.
        return self._copyValues(self._create_model(self._fit_java(dataset)))
0324
0325
@inherit_doc
class JavaTransformer(JavaParams, Transformer):
    """
    Base class for :py:class:`Transformer`s that wrap Java/Scala
    implementations. Subclasses should ensure they have the transformer Java object
    available as _java_obj.
    """

    __metaclass__ = ABCMeta

    def _transform(self, dataset):
        # Push any Python-side param changes to the JVM, run the Java
        # transformer, and re-wrap the resulting JVM DataFrame.
        self._transfer_params_to_java()
        transformed_jdf = self._java_obj.transform(dataset._jdf)
        return DataFrame(transformed_jdf, dataset.sql_ctx)
0339
0340
@inherit_doc
class JavaModel(JavaTransformer, Model):
    """
    Base class for :py:class:`Model`s that wrap Java/Scala
    implementations. Subclasses should inherit this class before
    param mix-ins, because this sets the UID from the Java model.
    """

    __metaclass__ = ABCMeta

    def __init__(self, java_model=None):
        """
        Initialize this instance with a Java model object.
        Subclasses should call this constructor, initialize params,
        and then call _transfer_params_from_java.

        This instance can be instantiated without specifying java_model,
        it will be assigned after that, but this scenario only used by
        :py:class:`JavaMLReader` to load models. This is a bit of a
        hack, but it is easiest since a proper fix would require
        MLReader (in pyspark.ml.util) to depend on these wrappers, but
        these wrappers depend on pyspark.ml.util (both directly and via
        other ML classes).
        """
        super(JavaModel, self).__init__(java_model)
        if java_model is not None:
            # Mirror params that exist only on the Java side, then adopt
            # the Java model's uid so both halves stay consistent.
            self._create_params_from_java()
            self._resetUid(java_model.uid())

    def __repr__(self):
        # Delegate to the JVM object's toString() for a faithful repr.
        return self._call_java("toString")
0377
0378
@inherit_doc
class _JavaPredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
    """
    Params for :py:class:`JavaPredictor` and :py:class:`JavaPredictorModel`.

    .. versionadded:: 3.0.0
    """
    # Pure mixin: contributes the shared label/features/prediction column
    # params; no behavior of its own.
    pass
0387
0388
@inherit_doc
class JavaPredictor(JavaEstimator, _JavaPredictorParams):
    """
    (Private) Java Estimator for prediction tasks (regression and classification).
    """

    # Fluent setters: each stores the column name and returns this estimator.

    @since("3.0.0")
    def setLabelCol(self, value):
        """
        Sets the value of :py:attr:`labelCol`.
        """
        return self._set(labelCol=value)

    @since("3.0.0")
    def setFeaturesCol(self, value):
        """
        Sets the value of :py:attr:`featuresCol`.
        """
        return self._set(featuresCol=value)

    @since("3.0.0")
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self._set(predictionCol=value)
0415
0416
@inherit_doc
class JavaPredictionModel(JavaModel, _JavaPredictorParams):
    """
    (Private) Java Model for prediction tasks (regression and classification).
    """

    # Fluent setters: each stores the column name and returns this model.

    @since("3.0.0")
    def setFeaturesCol(self, value):
        """
        Sets the value of :py:attr:`featuresCol`.
        """
        return self._set(featuresCol=value)

    @since("3.0.0")
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self._set(predictionCol=value)

    @property
    @since("2.1.0")
    def numFeatures(self):
        """
        Returns the number of features the model was trained on. If unknown, returns -1
        """
        return self._call_java("numFeatures")

    @since("3.0.0")
    def predict(self, value):
        """
        Predict label for the given features.
        """
        return self._call_java("predict", value)