0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import sys
0019 import os
0020
# Python 3 has no ``basestring``; alias it to ``str`` so code written against
# the Python 2 name keeps working.  The check uses ``sys.version_info``
# instead of comparing the ``sys.version`` string, because a string
# comparison is lexicographic and would misclassify a major version >= 10.
if sys.version_info[0] >= 3:
    basestring = str
0023
0024 from pyspark import since, keyword_only, SparkContext
0025 from pyspark.ml.base import Estimator, Model, Transformer
0026 from pyspark.ml.param import Param, Params
0027 from pyspark.ml.util import *
0028 from pyspark.ml.wrapper import JavaParams, JavaWrapper
0029 from pyspark.ml.common import inherit_doc, _java2py, _py2java
0030
0031
@inherit_doc
class Pipeline(Estimator, MLReadable, MLWritable):
    """
    An estimator made of an ordered list of stages, where each stage is
    either an :py:class:`Estimator` or a :py:class:`Transformer`.

    Calling :py:meth:`Pipeline.fit` walks the stages in order. An
    :py:class:`Estimator` stage is fitted on the current dataset and the
    resulting model is used to transform the dataset handed to the next
    stage; a :py:class:`Transformer` stage simply transforms the dataset
    for the next stage. The result of fitting is a
    :py:class:`PipelineModel` holding the fitted models and the
    transformers in stage order. With an empty list of stages the fitted
    model returns its input dataset unchanged.

    .. versionadded:: 1.3.0
    """

    stages = Param(Params._dummy(), "stages", "a list of pipeline stages")

    @keyword_only
    def __init__(self, stages=None):
        """
        __init__(self, stages=None)
        """
        super(Pipeline, self).__init__()
        # ``keyword_only`` stores the keyword arguments of this call on the instance.
        self.setParams(**self._input_kwargs)

    @since("1.3.0")
    def setStages(self, value):
        """
        Set pipeline stages.

        :param value: a list of transformers or estimators
        :return: the pipeline instance
        """
        return self._set(stages=value)

    @since("1.3.0")
    def getStages(self):
        """
        Get pipeline stages.
        """
        return self.getOrDefault(self.stages)

    @keyword_only
    @since("1.3.0")
    def setParams(self, stages=None):
        """
        setParams(self, stages=None)
        Sets params for Pipeline.
        """
        return self._set(**self._input_kwargs)

    def _fit(self, dataset):
        """
        Fit the stages in order and return the resulting :py:class:`PipelineModel`.

        :param dataset: input dataset passed through the stages
        :return: a :py:class:`PipelineModel` of fitted models and transformers
        :raises TypeError: if a stage is neither an Estimator nor a Transformer
        """
        pipeline_stages = self.getStages()

        # Single pre-pass: reject unknown stage types before doing any work and
        # remember where the last estimator sits (-1 when there is none).
        last_estimator_idx = -1
        for idx, stage in enumerate(pipeline_stages):
            if not isinstance(stage, (Estimator, Transformer)):
                raise TypeError(
                    "Cannot recognize a pipeline stage of type %s." % type(stage))
            if isinstance(stage, Estimator):
                last_estimator_idx = idx

        fitted = []
        for idx, stage in enumerate(pipeline_stages):
            if idx > last_estimator_idx:
                # Nothing left to fit downstream, so the dataset is no longer needed.
                fitted.append(stage)
                continue
            if isinstance(stage, Transformer):
                fitted.append(stage)
                dataset = stage.transform(dataset)
                continue
            model = stage.fit(dataset)
            fitted.append(model)
            # The last estimator's output is not consumed by any later fit.
            if idx < last_estimator_idx:
                dataset = model.transform(dataset)
        return PipelineModel(fitted)

    @since("1.4.0")
    def copy(self, extra=None):
        """
        Creates a copy of this instance.

        :param extra: extra parameters
        :returns: new instance
        """
        extra = dict() if extra is None else extra
        duplicate = Params.copy(self, extra)
        # Each stage is copied too, with the same extra params.
        copied_stages = [stage.copy(extra) for stage in duplicate.getStages()]
        return duplicate.setStages(copied_stages)

    @since("2.0.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        # JavaMLWriter is used only when every stage is a JavaMLWritable.
        if PipelineSharedReadWrite.checkStagesForJava(self.getStages()):
            return JavaMLWriter(self)
        return PipelineWriter(self)

    @classmethod
    @since("2.0.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return PipelineReader(cls)

    @classmethod
    def _from_java(cls, java_stage):
        """
        Given a Java Pipeline, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        pipeline = cls()
        pipeline.setStages([JavaParams._from_java(js) for js in java_stage.getStages()])
        pipeline._resetUid(java_stage.uid())
        return pipeline

    def _to_java(self):
        """
        Transfer this instance to a Java Pipeline.  Used for ML persistence.

        :return: Java object equivalent to this instance.
        """
        py_stages = self.getStages()
        stage_cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
        java_stages = SparkContext._gateway.new_array(stage_cls, len(py_stages))
        for pos, stage in enumerate(py_stages):
            java_stages[pos] = stage._to_java()

        java_pipeline = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
        java_pipeline.setStages(java_stages)
        return java_pipeline

    def _make_java_param_pair(self, param, value):
        """
        Makes a Java param pair.

        :param param: the param (resolved through ``_resolveParam``)
        :param value: the Python value to pair with the param
        :return: the result of calling ``w`` on the Java param with the Java value
        """
        ctx = SparkContext._active_spark_context
        resolved = self._resolveParam(param)
        java_param = ctx._jvm.org.apache.spark.ml.param.Param(
            resolved.parent, resolved.name, resolved.doc)
        # A value that is itself a Params exposing ``_to_java`` is converted by
        # its own method; anything else goes through the generic converter.
        if isinstance(value, Params) and hasattr(value, "_to_java"):
            java_value = value._to_java()
        else:
            java_value = _py2java(ctx, value)
        return java_param.w(java_value)

    def _transfer_param_map_to_java(self, pyParamMap):
        """
        Transforms a Python ParamMap into a Java ParamMap.

        Only params that belong to this instance and are present in
        ``pyParamMap`` are transferred.
        """
        java_map = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
        for param in self.params:
            if param not in pyParamMap:
                continue
            java_map.put([self._make_java_param_pair(param, pyParamMap[param])])
        return java_map

    def _transfer_param_map_from_java(self, javaParamMap):
        """
        Transforms a Java ParamMap into a Python ParamMap.

        Entries whose param name is not a param of this instance are skipped.
        """
        ctx = SparkContext._active_spark_context
        py_map = dict()
        for pair in javaParamMap.toList():
            java_param = pair.param()
            if not self.hasParam(str(java_param.name())):
                continue
            java_value = pair.value()
            stage_class = ctx._jvm.Class.forName("org.apache.spark.ml.PipelineStage")
            # Java PipelineStage values are wrapped; other values use the
            # generic Java-to-Python conversion.
            if stage_class.isInstance(java_value):
                py_value = JavaParams._from_java(java_value)
            else:
                py_value = _java2py(ctx, java_value)
            py_map[self.getParam(java_param.name())] = py_value
        return py_map
0225
0226
@inherit_doc
class PipelineWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types
    """

    def __init__(self, instance):
        """Remember the :py:class:`Pipeline` that this writer will save."""
        super(PipelineWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path):
        """
        Validate the pipeline's stages, then save the pipeline to ``path``
        through :py:class:`PipelineSharedReadWrite`.
        """
        pipeline = self.instance
        pipeline_stages = pipeline.getStages()
        # Validation happens first so an unwritable stage aborts the save early.
        PipelineSharedReadWrite.validateStages(pipeline_stages)
        PipelineSharedReadWrite.saveImpl(pipeline, pipeline_stages, self.sc, path)
0241
0242
@inherit_doc
class PipelineReader(MLReader):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types
    """

    def __init__(self, cls):
        """Remember the class handed to :py:class:`JavaMLReader` on the Java path."""
        super(PipelineReader, self).__init__()
        self.cls = cls

    def load(self, path):
        """
        Load a :py:class:`Pipeline` from ``path``.

        Metadata whose ``paramMap`` carries ``language == 'Python'`` is loaded
        through :py:class:`PipelineSharedReadWrite`; anything else is delegated
        to :py:class:`JavaMLReader`.
        """
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        param_map = metadata['paramMap']
        saved_from_python = 'language' in param_map and param_map['language'] == 'Python'
        if saved_from_python:
            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
            return Pipeline(stages=stages)._resetUid(uid)
        return JavaMLReader(self.cls).load(path)
0260
0261
@inherit_doc
class PipelineModelWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types
    """

    def __init__(self, instance):
        """Remember the :py:class:`PipelineModel` that this writer will save."""
        super(PipelineModelWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path):
        """
        Validate the model's stages, then save the model to ``path``
        through :py:class:`PipelineSharedReadWrite`.
        """
        model = self.instance
        model_stages = model.stages
        # Validation happens first so an unwritable stage aborts the save early.
        PipelineSharedReadWrite.validateStages(model_stages)
        PipelineSharedReadWrite.saveImpl(model, model_stages, self.sc, path)
0276
0277
@inherit_doc
class PipelineModelReader(MLReader):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types
    """

    def __init__(self, cls):
        """Remember the class handed to :py:class:`JavaMLReader` on the Java path."""
        super(PipelineModelReader, self).__init__()
        self.cls = cls

    def load(self, path):
        """
        Load a :py:class:`PipelineModel` from ``path``.

        Metadata whose ``paramMap`` carries ``language == 'Python'`` is loaded
        through :py:class:`PipelineSharedReadWrite`; anything else is delegated
        to :py:class:`JavaMLReader`.
        """
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        param_map = metadata['paramMap']
        saved_from_python = 'language' in param_map and param_map['language'] == 'Python'
        if saved_from_python:
            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
            return PipelineModel(stages=stages)._resetUid(uid)
        return JavaMLReader(self.cls).load(path)
0295
0296
@inherit_doc
class PipelineModel(Model, MLReadable, MLWritable):
    """
    Represents a compiled pipeline with transformers and fitted models.

    .. versionadded:: 1.3.0
    """

    def __init__(self, stages):
        """
        :param stages: ordered list of stages; each one is applied through its
                       ``transform`` method
        """
        super(PipelineModel, self).__init__()
        self.stages = stages

    def _transform(self, dataset):
        """Apply every stage in order, feeding each one the previous stage's output."""
        current = dataset
        for transformer in self.stages:
            current = transformer.transform(current)
        return current

    @since("1.4.0")
    def copy(self, extra=None):
        """
        Creates a copy of this instance.

        :param extra: extra parameters
        :returns: new instance
        """
        extra_params = dict() if extra is None else extra
        # The copy is a new PipelineModel built from copies of the stages.
        return PipelineModel([stage.copy(extra_params) for stage in self.stages])

    @since("2.0.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        # JavaMLWriter is used only when every stage is a JavaMLWritable.
        if PipelineSharedReadWrite.checkStagesForJava(self.stages):
            return JavaMLWriter(self)
        return PipelineModelWriter(self)

    @classmethod
    @since("2.0.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return PipelineModelReader(cls)

    @classmethod
    def _from_java(cls, java_stage):
        """
        Given a Java PipelineModel, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        wrapped_stages = [JavaParams._from_java(js) for js in java_stage.stages()]
        model = cls(wrapped_stages)
        model._resetUid(java_stage.uid())
        return model

    def _to_java(self):
        """
        Transfer this instance to a Java PipelineModel.  Used for ML persistence.

        :return: Java object equivalent to this instance.
        """
        transformer_cls = SparkContext._jvm.org.apache.spark.ml.Transformer
        java_stages = SparkContext._gateway.new_array(transformer_cls, len(self.stages))
        for pos, stage in enumerate(self.stages):
            java_stages[pos] = stage._to_java()

        return JavaParams._new_java_obj(
            "org.apache.spark.ml.PipelineModel", self.uid, java_stages)
0371
0372
@inherit_doc
class PipelineSharedReadWrite():
    """
    Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
    :py:class:`Pipeline` and :py:class:`PipelineModel`

    .. versionadded:: 2.3.0
    """

    @staticmethod
    def checkStagesForJava(stages):
        """
        Return True when every stage is a :py:class:`JavaMLWritable`
        (also True for an empty list of stages).
        """
        return all(isinstance(stage, JavaMLWritable) for stage in stages)

    @staticmethod
    def validateStages(stages):
        """
        Check that all stages are Writable

        :param stages: list of pipeline stages
        :raises ValueError: for the first stage that is not an :py:class:`MLWritable`
        """
        for stage in stages:
            if not isinstance(stage, MLWritable):
                # Interpolate with ``%`` so the message actually names the stage.
                # Passing the values as extra ValueError arguments would leave
                # the ``%s`` placeholders unformatted.
                raise ValueError("Pipeline write will fail on this pipeline "
                                 "because stage %s of type %s is not MLWritable"
                                 % (stage.uid, type(stage)))

    @staticmethod
    def saveImpl(instance, stages, sc, path):
        """
        Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
        - save metadata to path/metadata
        - save stages to stages/IDX_UID

        :param instance: the pipeline or pipeline model being saved
        :param stages: the stages of ``instance``, in order
        :param sc: Spark context forwarded to the metadata writer
        :param path: root directory of the saved instance
        """
        stageUids = [stage.uid for stage in stages]
        # The 'language' marker is what the readers check to choose the Python load path.
        jsonParams = {'stageUids': stageUids, 'language': 'Python'}
        DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
        stagesDir = os.path.join(path, "stages")
        numStages = len(stages)
        for index, stage in enumerate(stages):
            stagePath = PipelineSharedReadWrite.getStagePath(
                stage.uid, index, numStages, stagesDir)
            stage.write().save(stagePath)

    @staticmethod
    def load(metadata, sc, path):
        """
        Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`

        :param metadata: loaded metadata; ``metadata['paramMap']['stageUids']``
                         lists the stage UIDs in order
        :param sc: Spark context forwarded to the stage loader
        :param path: root directory of the saved instance
        :return: (UID, list of stages)
        """
        stagesDir = os.path.join(path, "stages")
        stageUids = metadata['paramMap']['stageUids']
        numStages = len(stageUids)
        stages = []
        for index, stageUid in enumerate(stageUids):
            stagePath = PipelineSharedReadWrite.getStagePath(
                stageUid, index, numStages, stagesDir)
            stages.append(DefaultParamsReader.loadParamsInstance(stagePath, sc))
        return (metadata['uid'], stages)

    @staticmethod
    def getStagePath(stageUid, stageIdx, numStages, stagesDir):
        """
        Get path for saving the given stage.

        The directory name is the stage index, zero-padded to the number of
        digits in ``numStages``, followed by ``_`` and the stage UID.
        """
        stageIdxDigits = len(str(numStages))
        stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid
        return os.path.join(stagesDir, stageDir)