#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABCMeta, abstractmethod

import copy
import threading

from pyspark import since
from pyspark.ml.param.shared import *
from pyspark.ml.common import inherit_doc
from pyspark.sql.functions import udf
from pyspark.sql.types import StructField, StructType


class _FitMultipleIterator(object):
    """
    Used by the default implementation of Estimator.fitMultiple to produce models in a
    thread-safe iterator. This class handles the simple case of fitMultiple where each
    param map should be fit independently.

    :param fitSingleModel: Function: (int => Model) which fits an estimator to a dataset.
        `fitSingleModel` may be called up to `numModels` times, with a unique index each time.
        Each call to `fitSingleModel` with an index should return the Model associated with
        that index.
    :param numModels: Number of models this iterator should produce.

    See Estimator.fitMultiple for more info.
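
    A sketch of typical consumption (illustrative only; ``fitSingleModel`` here is any
    function mapping an index to a fitted model)::

        modelIter = _FitMultipleIterator(fitSingleModel, numModels)
        for index, model in modelIter:
            pass  # indices may arrive out of order when consumed from multiple threads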
    """
    def __init__(self, fitSingleModel, numModels):
        self.fitSingleModel = fitSingleModel
        self.numModels = numModels
        self.counter = 0
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            index = self.counter
            if index >= self.numModels:
                raise StopIteration("No models remaining.")
            self.counter += 1
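        # Fit the model outside the lock so multiple models can be fit concurrently.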
        return index, self.fitSingleModel(index)

    def next(self):
        """For python2 compatibility."""
        return self.__next__()


@inherit_doc
class Estimator(Params):
    """
    Abstract class for estimators that fit models to data.

    .. versionadded:: 1.3.0
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def _fit(self, dataset):
        """
        Fits a model to the input dataset. This is called by the default implementation of fit.

        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
        :returns: fitted model
        """
        raise NotImplementedError()
    @since("2.3.0")
    def fitMultiple(self, dataset, paramMaps):
        """
        Fits a model to the input dataset for each param map in `paramMaps`.

        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`.
        :param paramMaps: A Sequence of param maps.
        :returns: A thread-safe iterable which contains one model for each param map. Each
            call to `next(modelIterator)` will return `(index, model)` where model was fit
            using `paramMaps[index]`. `index` values may not be sequential.
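
        For example, a sketch of collecting the fitted models in order (illustrative
        only; ``est`` is any Estimator, ``dataset`` a DataFrame and ``paramMaps`` a
        list of param maps)::

            models = [None] * len(paramMaps)
            for index, model in est.fitMultiple(dataset, paramMaps):
                models[index] = model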
        """
        estimator = self.copy()

        def fitSingleModel(index):
            return estimator.fit(dataset, paramMaps[index])

        return _FitMultipleIterator(fitSingleModel, len(paramMaps))

    @since("1.3.0")
    def fit(self, dataset, params=None):
        """
        Fits a model to the input dataset with optional parameters.

        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
        :param params: an optional param map that overrides embedded params. If a list/tuple of
            param maps is given, this calls fit on each param map and returns a list of
            models.
        :returns: fitted model(s)
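
        For example, overriding a single embedded param for one call (illustrative
        only; ``lr`` is any estimator exposing a ``maxIter`` param and ``df`` a
        DataFrame)::

            model = lr.fit(df, params={lr.maxIter: 10})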
        """
        if params is None:
            params = dict()
        if isinstance(params, (list, tuple)):
            models = [None] * len(params)
            for index, model in self.fitMultiple(dataset, params):
                models[index] = model
            return models
        elif isinstance(params, dict):
            if params:
                return self.copy(params)._fit(dataset)
            else:
                return self._fit(dataset)
        else:
            raise ValueError("Params must be either a param map or a list/tuple of param maps, "
                             "but got %s." % type(params))


@inherit_doc
class Transformer(Params):
    """
    Abstract class for transformers that transform one dataset into another.

    .. versionadded:: 1.3.0
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def _transform(self, dataset):
        """
        Transforms the input dataset.

        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
        :returns: transformed dataset
        """
        raise NotImplementedError()

    @since("1.3.0")
    def transform(self, dataset, params=None):
        """
        Transforms the input dataset with optional parameters.

        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
        :param params: an optional param map that overrides embedded params.
        :returns: transformed dataset
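
        For example, overriding an embedded param for a single call (illustrative
        only; ``binarizer`` is any transformer exposing a ``threshold`` param and
        ``df`` a DataFrame)::

            output = binarizer.transform(df, params={binarizer.threshold: 0.5})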
        """
        if params is None:
            params = dict()
        if isinstance(params, dict):
            if params:
                return self.copy(params)._transform(dataset)
            else:
                return self._transform(dataset)
        else:
            raise ValueError("Params must be a param map but got %s." % type(params))


@inherit_doc
class Model(Transformer):
    """
    Abstract class for models that are fitted by estimators.

    .. versionadded:: 1.4.0
    """

    __metaclass__ = ABCMeta


@inherit_doc
class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
    """
    Abstract class for transformers that take one input column, apply a transformation,
    and output the result as a new column.
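
    A minimal illustrative subclass (hypothetical, not part of pyspark), which computes
    the length of a string column::

        from pyspark.sql.types import IntegerType, StringType

        class StringLength(UnaryTransformer):
            def createTransformFunc(self):
                return lambda s: len(s)

            def outputDataType(self):
                return IntegerType()

            def validateInputType(self, inputType):
                if inputType != StringType():
                    raise TypeError("Expected StringType but got %s." % inputType)

    Such a transformer could then be configured with the setters below, e.g.
    ``StringLength().setInputCol("text").setOutputCol("text_len")``.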

    .. versionadded:: 2.3.0
    """

    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    @abstractmethod
    def createTransformFunc(self):
        """
        Creates the transform function using the transformer's current param values.
        The returned function is applied to each value of the input column to produce
        the output column.
        """
        raise NotImplementedError()

    @abstractmethod
    def outputDataType(self):
        """
        Returns the data type of the output column.
        """
        raise NotImplementedError()

    @abstractmethod
    def validateInputType(self, inputType):
        """
        Validates the input type. Throws an exception if it is invalid.
        """
        raise NotImplementedError()

    def transformSchema(self, schema):
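        """
        Validates the data type of the input column and returns the input schema with
        the output column appended as a non-nullable field.
        """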
        inputType = schema[self.getInputCol()].dataType
        self.validateInputType(inputType)
        if self.getOutputCol() in schema.names:
            raise ValueError("Output column %s already exists." % self.getOutputCol())
        outputFields = copy.copy(schema.fields)
        outputFields.append(StructField(self.getOutputCol(),
                                        self.outputDataType(),
                                        nullable=False))
        return StructType(outputFields)

    def _transform(self, dataset):
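        """
        Wraps the function from :py:meth:`createTransformFunc` in a UDF and applies it
        to the input column, returning the dataset with the output column appended.
        """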
        self.transformSchema(dataset.schema)
        transformUDF = udf(self.createTransformFunc(), self.outputDataType())
        transformedDataset = dataset.withColumn(self.getOutputCol(),
                                                transformUDF(dataset[self.getInputCol()]))
        return transformedDataset