#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABCMeta, abstractmethod

import copy
import threading

from pyspark import since
from pyspark.ml.param import Params
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.common import inherit_doc
from pyspark.sql.functions import udf
from pyspark.sql.types import StructField, StructType


class _FitMultipleIterator(object):
    """
    Used by the default implementation of Estimator.fitMultiple to produce models in a
    thread safe iterator. This class handles the simple case of fitMultiple where each
    param map should be fit independently.

    :param fitSingleModel: Function: (int => Model) which fits an estimator to a dataset.
        `fitSingleModel` may be called up to `numModels` times, with a unique index each time.
        Each call to `fitSingleModel` with an index should return the Model associated with
        that index.
    :param numModels: Number of models this iterator should produce.

    See Estimator.fitMultiple for more info.
    """
    def __init__(self, fitSingleModel, numModels):
        self.fitSingleModel = fitSingleModel
        self.numModels = numModels
        self.counter = 0
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        # Claim the next index under the lock so concurrent consumers never fit
        # the same param map twice; the model itself is fit outside the lock.
        with self.lock:
            index = self.counter
            if index >= self.numModels:
                raise StopIteration("No models remaining.")
            self.counter += 1
        return index, self.fitSingleModel(index)

    def next(self):
        """For python2 compatibility."""
        return self.__next__()

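# A minimal sketch of how this iterator is driven (the surrounding names are
# illustrative, not part of this module): `fitSingleModel` closes over the
# estimator and param maps, and each consumer pulls a unique (index, model) pair.
#
#     def fitSingleModel(index):
#         return estimator.fit(dataset, paramMaps[index])
#
#     fitted = _FitMultipleIterator(fitSingleModel, len(paramMaps))
#     for index, model in fitted:
#         models[index] = model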

@inherit_doc
class Estimator(Params):
    """
    Abstract class for estimators that fit models to data.

    .. versionadded:: 1.3.0
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def _fit(self, dataset):
        """
        Fits a model to the input dataset. This is called by the default implementation of fit.

        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
        :returns: fitted model
        """
        raise NotImplementedError()
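
    # A sketch of a concrete subclass (illustrative only; `MeanModel` and the
    # "feature" column are hypothetical):
    #
    #     class MeanEstimator(Estimator):
    #         def _fit(self, dataset):
    #             mean = dataset.groupBy().avg("feature").first()[0]
    #             return MeanModel(mean)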

    @since("2.3.0")
    def fitMultiple(self, dataset, paramMaps):
        """
        Fits a model to the input dataset for each param map in `paramMaps`.

        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`.
        :param paramMaps: A Sequence of param maps.
        :return: A thread safe iterable which contains one model for each param map. Each
                 call to `next(modelIterator)` will return `(index, model)` where model was fit
                 using `paramMaps[index]`. `index` values may not be sequential.
        """
        estimator = self.copy()

        def fitSingleModel(index):
            return estimator.fit(dataset, paramMaps[index])

        return _FitMultipleIterator(fitSingleModel, len(paramMaps))
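
    # Usage sketch (`est`, `df`, and `est.maxIter` are hypothetical): the
    # returned iterator yields (index, model) pairs, and with multiple
    # consumer threads the pairs may arrive out of index order.
    #
    #     maps = [{est.maxIter: 5}, {est.maxIter: 10}]
    #     for index, model in est.fitMultiple(df, maps):
    #         models[index] = model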

    @since("1.3.0")
    def fit(self, dataset, params=None):
        """
        Fits a model to the input dataset with optional parameters.

        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
        :param params: an optional param map that overrides embedded params. If a list/tuple of
                       param maps is given, this calls fit on each param map and returns a list of
                       models.
        :returns: fitted model(s)
        """
        if params is None:
            params = dict()
        if isinstance(params, (list, tuple)):
            models = [None] * len(params)
            for index, model in self.fitMultiple(dataset, params):
                models[index] = model
            return models
        elif isinstance(params, dict):
            if params:
                return self.copy(params)._fit(dataset)
            else:
                return self._fit(dataset)
        else:
            raise ValueError("Params must be either a param map or a list/tuple of param maps, "
                             "but got %s." % type(params))
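
    # Usage sketch (hypothetical `est` and `df`): a single param map yields a
    # single model; a list/tuple of param maps yields a list of models.
    #
    #     model = est.fit(df, {est.maxIter: 10})
    #     models = est.fit(df, [{est.maxIter: 5}, {est.maxIter: 10}])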


@inherit_doc
class Transformer(Params):
    """
    Abstract class for transformers that transform one dataset into another.

    .. versionadded:: 1.3.0
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def _transform(self, dataset):
        """
        Transforms the input dataset.

        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
        :returns: transformed dataset
        """
        raise NotImplementedError()

    @since("1.3.0")
    def transform(self, dataset, params=None):
        """
        Transforms the input dataset with optional parameters.

        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
        :param params: an optional param map that overrides embedded params.
        :returns: transformed dataset
        """
        if params is None:
            params = dict()
        if isinstance(params, dict):
            if params:
                return self.copy(params)._transform(dataset)
            else:
                return self._transform(dataset)
        else:
            raise ValueError("Params must be a param map but got %s." % type(params))
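
    # Usage sketch (hypothetical `tf` and `df`): params passed here override
    # the embedded params for this single call only.
    #
    #     transformed = tf.transform(df, {tf.outputCol: "features_out"})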


@inherit_doc
class Model(Transformer):
    """
    Abstract class for models that are fitted by estimators.

    .. versionadded:: 1.4.0
    """

    __metaclass__ = ABCMeta


@inherit_doc
class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
    """
    Abstract class for transformers that take one input column, apply a transformation,
    and output the result as a new column.

    .. versionadded:: 2.3.0
    """

    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    @abstractmethod
    def createTransformFunc(self):
        """
        Creates the transform function using the given param map. The input param map already
        takes into account the embedded param map, so the param values should be determined
        solely by the input param map.
        """
        raise NotImplementedError()

    @abstractmethod
    def outputDataType(self):
        """
        Returns the data type of the output column.
        """
        raise NotImplementedError()

    @abstractmethod
    def validateInputType(self, inputType):
        """
        Validates the input type. Throws an exception if it is invalid.
        """
        raise NotImplementedError()

    def transformSchema(self, schema):
        """
        Derives the output schema by validating the input column's type and
        appending the output column to the input schema.
        """
        inputType = schema[self.getInputCol()].dataType
        self.validateInputType(inputType)
        if self.getOutputCol() in schema.names:
            raise ValueError("Output column %s already exists." % self.getOutputCol())
        outputFields = copy.copy(schema.fields)
        outputFields.append(StructField(self.getOutputCol(),
                                        self.outputDataType(),
                                        nullable=False))
        return StructType(outputFields)
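
    # For example, given an input schema StructType([StructField("x", DoubleType(), True)])
    # and an output column "x_squared" (both names illustrative, assuming a
    # DoubleType output), the result is:
    #
    #     StructType([StructField("x", DoubleType(), True),
    #                 StructField("x_squared", DoubleType(), False)])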

    def _transform(self, dataset):
        # Validate the schema up front so a bad input column fails fast,
        # then apply the transform function as a UDF over the input column.
        self.transformSchema(dataset.schema)
        transformUDF = udf(self.createTransformFunc(), self.outputDataType())
        transformedDataset = dataset.withColumn(self.getOutputCol(),
                                                transformUDF(dataset[self.getInputCol()]))
        return transformedDataset
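

# A minimal, self-contained sketch of a concrete UnaryTransformer (the class
# and column names are illustrative, not part of this module):
#
#     from pyspark.sql.types import DoubleType
#
#     class SquareTransformer(UnaryTransformer):
#         def createTransformFunc(self):
#             return lambda x: float(x) ** 2
#
#         def outputDataType(self):
#             return DoubleType()
#
#         def validateInputType(self, inputType):
#             if inputType != DoubleType():
#                 raise TypeError("Expected DoubleType, got %s." % inputType)
#
#     squarer = SquareTransformer().setInputCol("x").setOutputCol("x_squared")
#     result = squarer.transform(df)  # `df` is a hypothetical DataFrame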