0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 import itertools
0018 import sys
0019 from multiprocessing.pool import ThreadPool
0020
0021 import numpy as np
0022
0023 from pyspark import since, keyword_only
0024 from pyspark.ml import Estimator, Model
0025 from pyspark.ml.common import _py2java, _java2py
0026 from pyspark.ml.param import Params, Param, TypeConverters
0027 from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
0028 from pyspark.ml.util import *
0029 from pyspark.ml.wrapper import JavaParams
0030 from pyspark.sql.functions import rand
0031
# Public API of this module: the param-grid builder plus the two model-selection
# estimators and the models they produce.
__all__ = [
    'ParamGridBuilder',
    'CrossValidator',
    'CrossValidatorModel',
    'TrainValidationSplit',
    'TrainValidationSplitModel',
]
0034
0035
def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel):
    """
    Creates a list of callables which can be called from different threads to fit and evaluate
    an estimator in parallel. Each callable returns an `(index, metric, subModel)` triple.

    All callables share the single iterator returned by ``est.fitMultiple(train, epm)``; each
    call advances it by one ``(index, model)`` pair. The index a call reports is therefore
    decided by that iterator, not by the callable's position in the returned list.
    NOTE(review): concurrent use relies on the fitMultiple iterator being safe to advance
    from several threads -- confirm for custom estimators.

    :param est: Estimator, the estimator to be fit.
    :param train: DataFrame, training data set, used for fitting.
    :param eva: Evaluator, used to compute `metric`
    :param validation: DataFrame, validation data set, used for evaluation.
    :param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation.
    :param collectSubModel: Whether to collect sub model.
    :return: list of ``len(epm)`` callables. Each returns ``(int, float, subModel)``: an index
             into `epm`, the associated metric value, and the fitted model when
             `collectSubModel` is true (``None`` otherwise).
    """
    modelIter = est.fitMultiple(train, epm)

    def singleTask():
        index, model = next(modelIter)
        # Transform with the same param map the model was fit with before evaluating.
        metric = eva.evaluate(model.transform(validation, epm[index]))
        # Only hand the fitted model back when sub-models are being collected.
        return index, metric, (model if collectSubModel else None)

    # One entry per param map; every entry is the same callable driven by the shared iterator.
    return [singleTask] * len(epm)
0057
0058
class ParamGridBuilder(object):
    r"""
    Builder for a param grid used in grid search-based model selection.

    >>> from pyspark.ml.classification import LogisticRegression
    >>> lr = LogisticRegression()
    >>> output = ParamGridBuilder() \
    ...     .baseOn({lr.labelCol: 'l'}) \
    ...     .baseOn([lr.predictionCol, 'p']) \
    ...     .addGrid(lr.regParam, [1.0, 2.0]) \
    ...     .addGrid(lr.maxIter, [1, 5]) \
    ...     .build()
    >>> expected = [
    ...     {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
    >>> len(output) == len(expected)
    True
    >>> all([m in expected for m in output])
    True

    .. versionadded:: 1.4.0
    """

    def __init__(self):
        # Maps each Param to the list of candidate values to try for it.
        self._param_grid = {}

    @since("1.4.0")
    def addGrid(self, param, values):
        """
        Adds the given param to this grid together with the candidate values to try for it.
        Calling this again for the same param replaces its previous candidate values.

        param must be an instance of Param associated with an instance of Params
        (such as Estimator or Transformer).
        """
        if isinstance(param, Param):
            # Materialize the values so that a one-shot iterable (e.g. a generator) is not
            # exhausted by the first call to build().
            self._param_grid[param] = list(values)
        else:
            raise TypeError("param must be an instance of Param")

        return self

    @since("1.4.0")
    def baseOn(self, *args):
        """
        Sets the given parameters in this grid to fixed values.
        Accepts either a parameter dictionary or a list of (parameter, value) pairs.
        Calling it with no arguments or an empty dictionary leaves the grid unchanged.
        """
        # Guard against an empty call (e.g. baseOn() or baseOn({})), which used to raise
        # IndexError on args[0].
        if args and isinstance(args[0], dict):
            self.baseOn(*args[0].items())
        else:
            for (param, value) in args:
                # A fixed value is just a grid dimension with a single candidate.
                self.addGrid(param, [value])

        return self

    @since("1.4.0")
    def build(self):
        """
        Builds and returns all combinations of parameters specified
        by the param grid, as a list of param-to-value dictionaries. Each value is passed
        through its param's type converter. An empty grid yields a single empty dictionary.
        """
        keys = self._param_grid.keys()
        grid_values = self._param_grid.values()

        def to_key_value_pairs(keys, values):
            return [(key, key.typeConverter(value)) for key, value in zip(keys, values)]

        # keys() and values() of the same unmodified dict iterate in matching order.
        return [dict(to_key_value_pairs(keys, prod)) for prod in itertools.product(*grid_values)]
0129
0130
class _ValidatorParams(HasSeed):
    """
    Common params for TrainValidationSplit and CrossValidator.
    """

    estimator = Param(parent=Params._dummy(), name="estimator",
                      doc="estimator to be cross-validated")
    estimatorParamMaps = Param(parent=Params._dummy(), name="estimatorParamMaps",
                               doc="estimator param maps")
    evaluator = Param(
        parent=Params._dummy(), name="evaluator",
        doc="evaluator used to select hyper-parameters that maximize the validator metric")

    @since("2.0.0")
    def getEstimator(self):
        """
        Gets the value of estimator or its default value.
        """
        estimator_param = self.estimator
        return self.getOrDefault(estimator_param)

    @since("2.0.0")
    def getEstimatorParamMaps(self):
        """
        Gets the value of estimatorParamMaps or its default value.
        """
        param_maps_param = self.estimatorParamMaps
        return self.getOrDefault(param_maps_param)

    @since("2.0.0")
    def getEvaluator(self):
        """
        Gets the value of evaluator or its default value.
        """
        evaluator_param = self.evaluator
        return self.getOrDefault(evaluator_param)

    @classmethod
    def _from_java_impl(cls, java_stage):
        """
        Build the Python estimator, the list of estimator param maps and the Python evaluator
        corresponding to a Java ValidatorParams stage.

        :return: tuple ``(estimator, estimatorParamMaps, evaluator)`` of Python objects.
        """
        # The Python estimator wrapper is needed first: it converts each Java ParamMap back
        # into a Python param map.
        py_estimator = JavaParams._from_java(java_stage.getEstimator())
        py_evaluator = JavaParams._from_java(java_stage.getEvaluator())
        py_param_maps = []
        for java_param_map in java_stage.getEstimatorParamMaps():
            py_param_maps.append(py_estimator._transfer_param_map_from_java(java_param_map))
        return py_estimator, py_param_maps, py_evaluator

    def _to_java_impl(self):
        """
        Convert this instance's estimator, estimator param maps and evaluator into their Java
        counterparts.

        :return: tuple ``(java_estimator, java_epms, java_evaluator)``, where ``java_epms`` is
                 a Java array of ``org.apache.spark.ml.param.ParamMap``.
        """
        py_estimator = self.getEstimator()
        py_param_maps = self.getEstimatorParamMaps()

        param_map_cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
        java_epms = SparkContext._gateway.new_array(param_map_cls, len(py_param_maps))
        for position, param_map in enumerate(py_param_maps):
            java_epms[position] = py_estimator._transfer_param_map_to_java(param_map)

        # Tuple elements are evaluated left to right: estimator first, then evaluator.
        return py_estimator._to_java(), java_epms, self.getEvaluator()._to_java()
0191
0192
class _CrossValidatorParams(_ValidatorParams):
    """
    Params shared by :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`.

    .. versionadded:: 3.0.0
    """

    numFolds = Param(parent=Params._dummy(), name="numFolds",
                     doc="number of folds for cross validation",
                     typeConverter=TypeConverters.toInt)

    @since("1.4.0")
    def getNumFolds(self):
        """
        Gets the value of numFolds or its default value.
        """
        folds_param = self.numFolds
        return self.getOrDefault(folds_param)
0209
0210
class CrossValidator(Estimator, _CrossValidatorParams, HasParallelism, HasCollectSubModels,
                     MLReadable, MLWritable):
    """
    K-fold cross validation performs model selection by splitting the dataset into a set of
    non-overlapping randomly partitioned folds which are used as separate training and test
    datasets. E.g., with k=3 folds, K-fold cross validation will generate 3 (training, test)
    dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. Each
    fold is used as the test set exactly once.

    >>> from pyspark.ml.classification import LogisticRegression
    >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.tuning import CrossValidatorModel
    >>> import tempfile
    >>> dataset = spark.createDataFrame(
    ...     [(Vectors.dense([0.0]), 0.0),
    ...      (Vectors.dense([0.4]), 1.0),
    ...      (Vectors.dense([0.5]), 0.0),
    ...      (Vectors.dense([0.6]), 1.0),
    ...      (Vectors.dense([1.0]), 1.0)] * 10,
    ...     ["features", "label"])
    >>> lr = LogisticRegression()
    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    >>> evaluator = BinaryClassificationEvaluator()
    >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
    ...     parallelism=2)
    >>> cvModel = cv.fit(dataset)
    >>> cvModel.getNumFolds()
    3
    >>> cvModel.avgMetrics[0]
    0.5
    >>> path = tempfile.mkdtemp()
    >>> model_path = path + "/model"
    >>> cvModel.write().save(model_path)
    >>> cvModelRead = CrossValidatorModel.read().load(model_path)
    >>> cvModelRead.avgMetrics
    [0.5, ...
    >>> evaluator.evaluate(cvModel.transform(dataset))
    0.8333...

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
                 seed=None, parallelism=1, collectSubModels=False):
        """
        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
                 seed=None, parallelism=1, collectSubModels=False)
        """
        super(CrossValidator, self).__init__()
        self._setDefault(numFolds=3, parallelism=1)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
                  seed=None, parallelism=1, collectSubModels=False):
        """
        setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
                  seed=None, parallelism=1, collectSubModels=False)
        Sets params for cross validator.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.0.0")
    def setEstimator(self, value):
        """
        Sets the value of :py:attr:`estimator`.
        """
        return self._set(estimator=value)

    @since("2.0.0")
    def setEstimatorParamMaps(self, value):
        """
        Sets the value of :py:attr:`estimatorParamMaps`.
        """
        return self._set(estimatorParamMaps=value)

    @since("2.0.0")
    def setEvaluator(self, value):
        """
        Sets the value of :py:attr:`evaluator`.
        """
        return self._set(evaluator=value)

    @since("1.4.0")
    def setNumFolds(self, value):
        """
        Sets the value of :py:attr:`numFolds`.
        """
        return self._set(numFolds=value)

    def setSeed(self, value):
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)

    def setParallelism(self, value):
        """
        Sets the value of :py:attr:`parallelism`.
        """
        return self._set(parallelism=value)

    def setCollectSubModels(self, value):
        """
        Sets the value of :py:attr:`collectSubModels`.
        """
        return self._set(collectSubModels=value)

    def _fit(self, dataset):
        """
        Run k-fold cross validation over every estimator param map and refit the best one.

        A uniform random column (seeded by ``seed``) assigns rows to folds. Rows whose value
        falls in ``[i / k, (i + 1) / k)`` form the validation set of fold ``i``, and the
        remaining rows form its training set. Each param map's metric is averaged over the
        folds. The estimator is then refit on the full ``dataset`` with the best param map,
        chosen according to ``evaluator.isLargerBetter()``.
        """
        est = self.getOrDefault(self.estimator)
        epm = self.getOrDefault(self.estimatorParamMaps)
        numModels = len(epm)
        eva = self.getOrDefault(self.evaluator)
        nFolds = self.getOrDefault(self.numFolds)
        seed = self.getOrDefault(self.seed)
        h = 1.0 / nFolds
        randCol = self.uid + "_rand"
        df = dataset.select("*", rand(seed).alias(randCol))
        metrics = [0.0] * numModels

        subModels = None
        collectSubModelsParam = self.getCollectSubModels()
        if collectSubModelsParam:
            subModels = [[None] * numModels for _ in range(nFolds)]

        pool = ThreadPool(processes=min(self.getParallelism(), numModels))
        try:
            for i in range(nFolds):
                validateLB = i * h
                validateUB = (i + 1) * h
                condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
                validation = df.filter(condition).cache()
                train = df.filter(~condition).cache()
                try:
                    tasks = _parallelFitTasks(est, train, eva, validation, epm,
                                              collectSubModelsParam)
                    for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
                        metrics[j] += (metric / nFolds)
                        if collectSubModelsParam:
                            subModels[i][j] = subModel
                finally:
                    # Release this fold's cached data even if fitting/evaluation raised.
                    validation.unpersist()
                    train.unpersist()
        finally:
            # The pool was previously never shut down, leaking its worker threads. On the
            # success path every task has already completed; on failure this also drops
            # fits that were queued but not yet started.
            pool.terminate()

        if eva.isLargerBetter():
            bestIndex = np.argmax(metrics)
        else:
            bestIndex = np.argmin(metrics)
        bestModel = est.fit(dataset, epm[bestIndex])
        return self._copyValues(CrossValidatorModel(bestModel, metrics, subModels))

    @since("1.4.0")
    def copy(self, extra=None):
        """
        Creates a copy of this instance with a randomly generated uid
        and some extra params. This creates a deep copy of
        the embedded paramMap, and copies the embedded and extra parameters over.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        newCV = Params.copy(self, extra)
        if self.isSet(self.estimator):
            newCV.setEstimator(self.getEstimator().copy(extra))
        # estimatorParamMaps is copied as-is by Params.copy above.
        if self.isSet(self.evaluator):
            newCV.setEvaluator(self.getEvaluator().copy(extra))
        return newCV

    @since("2.3.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)

    @classmethod
    def _from_java(cls, java_stage):
        """
        Given a Java CrossValidator, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage)
        numFolds = java_stage.getNumFolds()
        seed = java_stage.getSeed()
        parallelism = java_stage.getParallelism()
        collectSubModels = java_stage.getCollectSubModels()

        py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
                       numFolds=numFolds, seed=seed, parallelism=parallelism,
                       collectSubModels=collectSubModels)
        # Keep the uid of the persisted stage instead of the freshly generated one.
        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self):
        """
        Transfer this instance to a Java CrossValidator. Used for ML persistence.

        :return: Java object equivalent to this instance.
        """
        estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()

        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
        _java_obj.setEstimatorParamMaps(epms)
        _java_obj.setEvaluator(evaluator)
        _java_obj.setEstimator(estimator)
        _java_obj.setSeed(self.getSeed())
        _java_obj.setNumFolds(self.getNumFolds())
        _java_obj.setParallelism(self.getParallelism())
        _java_obj.setCollectSubModels(self.getCollectSubModels())

        return _java_obj
0436
0437
class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable):
    """
    CrossValidatorModel contains the model with the best average cross-validation
    metric across folds (as judged by the evaluator) and uses this model to transform input
    data. CrossValidatorModel also tracks the metrics for each param map evaluated.

    .. versionadded:: 1.4.0
    """

    def __init__(self, bestModel, avgMetrics=None, subModels=None):
        super(CrossValidatorModel, self).__init__()
        # Best model, refit on the full dataset with the best param map.
        self.bestModel = bestModel
        # Average cross-validation metric for each param map in estimatorParamMaps, in the
        # same order. A fresh list replaces the former shared mutable default ([]).
        self.avgMetrics = avgMetrics if avgMetrics is not None else []
        # Sub-models per fold (nFolds x numModels) when collectSubModels was set, else None.
        self.subModels = subModels

    def _transform(self, dataset):
        return self.bestModel.transform(dataset)

    @since("1.4.0")
    def copy(self, extra=None):
        """
        Creates a copy of this instance with a randomly generated uid
        and some extra params. This copies the underlying bestModel,
        creates a deep copy of the embedded paramMap, and
        copies the embedded and extra parameters over.
        And, this creates a shallow copy of the avgMetrics.
        It does not copy the extra Params into the subModels.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        bestModel = self.bestModel.copy(extra)
        # Copy the list so mutating one model's metrics does not affect the other
        # (consistent with TrainValidationSplitModel.copy).
        avgMetrics = list(self.avgMetrics)
        subModels = self.subModels
        # Carry over estimator/evaluator/estimatorParamMaps/numFolds/seed. Without this, the
        # copy has none of them set and cannot be persisted via _to_java.
        return self._copyValues(CrossValidatorModel(bestModel, avgMetrics, subModels),
                                extra=extra)

    @since("2.3.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)

    @classmethod
    def _from_java(cls, java_stage):
        """
        Given a Java CrossValidatorModel, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        sc = SparkContext._active_spark_context
        bestModel = JavaParams._from_java(java_stage.bestModel())
        avgMetrics = _java2py(sc, java_stage.avgMetrics())
        estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)

        # NOTE(review): only estimator, estimatorParamMaps and evaluator are restored here;
        # numFolds/seed are not read back from java_stage.
        py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics)._set(estimator=estimator)
        py_stage = py_stage._set(estimatorParamMaps=epms)._set(evaluator=evaluator)

        if java_stage.hasSubModels():
            py_stage.subModels = [[JavaParams._from_java(sub_model)
                                   for sub_model in fold_sub_models]
                                  for fold_sub_models in java_stage.subModels()]

        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self):
        """
        Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.

        :return: Java object equivalent to this instance.
        """
        sc = SparkContext._active_spark_context
        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
                                             self.uid,
                                             self.bestModel._to_java(),
                                             _py2java(sc, self.avgMetrics))
        estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()

        _java_obj.set("evaluator", evaluator)
        _java_obj.set("estimator", estimator)
        _java_obj.set("estimatorParamMaps", epms)

        if self.subModels is not None:
            java_sub_models = [[sub_model._to_java() for sub_model in fold_sub_models]
                               for fold_sub_models in self.subModels]
            _java_obj.setSubModels(java_sub_models)
        return _java_obj
0536
0537
class _TrainValidationSplitParams(_ValidatorParams):
    """
    Params for :py:class:`TrainValidationSplit` and :py:class:`TrainValidationSplitModel`.

    .. versionadded:: 3.0.0
    """

    # Implicit string concatenation replaces the former backslash continuation inside the
    # literal, which either glued "and" to "validation" or embedded indentation whitespace
    # into the param doc.
    trainRatio = Param(Params._dummy(), "trainRatio",
                       "Param for ratio between train and "
                       "validation data. Must be between 0 and 1.",
                       typeConverter=TypeConverters.toFloat)

    @since("2.0.0")
    def getTrainRatio(self):
        """
        Gets the value of trainRatio or its default value.
        """
        return self.getOrDefault(self.trainRatio)
0554
0555
class TrainValidationSplit(Estimator, _TrainValidationSplitParams, HasParallelism,
                           HasCollectSubModels, MLReadable, MLWritable):
    """
    Validation for hyper-parameter tuning. Randomly splits the input dataset into train and
    validation sets, and uses evaluation metric on the validation set to select the best model.
    Similar to :class:`CrossValidator`, but only splits the set once.

    >>> from pyspark.ml.classification import LogisticRegression
    >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.tuning import TrainValidationSplitModel
    >>> import tempfile
    >>> dataset = spark.createDataFrame(
    ...     [(Vectors.dense([0.0]), 0.0),
    ...      (Vectors.dense([0.4]), 1.0),
    ...      (Vectors.dense([0.5]), 0.0),
    ...      (Vectors.dense([0.6]), 1.0),
    ...      (Vectors.dense([1.0]), 1.0)] * 10,
    ...     ["features", "label"]).repartition(1)
    >>> lr = LogisticRegression()
    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    >>> evaluator = BinaryClassificationEvaluator()
    >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
    ...     parallelism=1, seed=42)
    >>> tvsModel = tvs.fit(dataset)
    >>> tvsModel.getTrainRatio()
    0.75
    >>> tvsModel.validationMetrics
    [0.5, ...
    >>> path = tempfile.mkdtemp()
    >>> model_path = path + "/model"
    >>> tvsModel.write().save(model_path)
    >>> tvsModelRead = TrainValidationSplitModel.read().load(model_path)
    >>> tvsModelRead.validationMetrics
    [0.5, ...
    >>> evaluator.evaluate(tvsModel.transform(dataset))
    0.833...

    .. versionadded:: 2.0.0
    """

    @keyword_only
    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,
                 parallelism=1, collectSubModels=False, seed=None):
        """
        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\
                 parallelism=1, collectSubModels=False, seed=None)
        """
        super(TrainValidationSplit, self).__init__()
        self._setDefault(trainRatio=0.75, parallelism=1)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @since("2.0.0")
    @keyword_only
    def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,
                  parallelism=1, collectSubModels=False, seed=None):
        """
        setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\
                  parallelism=1, collectSubModels=False, seed=None)
        Sets params for the train validation split.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.0.0")
    def setEstimator(self, value):
        """
        Sets the value of :py:attr:`estimator`.
        """
        return self._set(estimator=value)

    @since("2.0.0")
    def setEstimatorParamMaps(self, value):
        """
        Sets the value of :py:attr:`estimatorParamMaps`.
        """
        return self._set(estimatorParamMaps=value)

    @since("2.0.0")
    def setEvaluator(self, value):
        """
        Sets the value of :py:attr:`evaluator`.
        """
        return self._set(evaluator=value)

    @since("2.0.0")
    def setTrainRatio(self, value):
        """
        Sets the value of :py:attr:`trainRatio`.
        """
        return self._set(trainRatio=value)

    def setSeed(self, value):
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)

    def setParallelism(self, value):
        """
        Sets the value of :py:attr:`parallelism`.
        """
        return self._set(parallelism=value)

    def setCollectSubModels(self, value):
        """
        Sets the value of :py:attr:`collectSubModels`.
        """
        return self._set(collectSubModels=value)

    def _fit(self, dataset):
        """
        Fit and evaluate every estimator param map on a single random train/validation split,
        then refit the best one.

        A uniform random column (seeded by ``seed``) splits the rows. Rows whose value is
        ``>= trainRatio`` form the validation set, and the rest form the training set. The
        estimator is refit on the full ``dataset`` with the best param map, chosen according
        to ``evaluator.isLargerBetter()``.
        """
        est = self.getOrDefault(self.estimator)
        epm = self.getOrDefault(self.estimatorParamMaps)
        numModels = len(epm)
        eva = self.getOrDefault(self.evaluator)
        tRatio = self.getOrDefault(self.trainRatio)
        seed = self.getOrDefault(self.seed)
        randCol = self.uid + "_rand"
        df = dataset.select("*", rand(seed).alias(randCol))
        condition = (df[randCol] >= tRatio)
        validation = df.filter(condition).cache()
        train = df.filter(~condition).cache()

        subModels = None
        collectSubModelsParam = self.getCollectSubModels()
        if collectSubModelsParam:
            subModels = [None] * numModels

        metrics = [None] * numModels
        try:
            tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
            pool = ThreadPool(processes=min(self.getParallelism(), numModels))
            try:
                for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
                    metrics[j] = metric
                    if collectSubModelsParam:
                        subModels[j] = subModel
            finally:
                # The pool was previously never shut down, leaking its worker threads. On the
                # success path all tasks have completed; on failure this drops queued fits.
                pool.terminate()
        finally:
            # Release the cached splits even if fitting/evaluation raised.
            train.unpersist()
            validation.unpersist()

        if eva.isLargerBetter():
            bestIndex = np.argmax(metrics)
        else:
            bestIndex = np.argmin(metrics)
        bestModel = est.fit(dataset, epm[bestIndex])
        return self._copyValues(TrainValidationSplitModel(bestModel, metrics, subModels))

    @since("2.0.0")
    def copy(self, extra=None):
        """
        Creates a copy of this instance with a randomly generated uid
        and some extra params. This creates a deep copy of
        the embedded paramMap, and copies the embedded and extra parameters over.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        newTVS = Params.copy(self, extra)
        if self.isSet(self.estimator):
            newTVS.setEstimator(self.getEstimator().copy(extra))
        # estimatorParamMaps is copied as-is by Params.copy above.
        if self.isSet(self.evaluator):
            newTVS.setEvaluator(self.getEvaluator().copy(extra))
        return newTVS

    @since("2.3.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)

    @classmethod
    def _from_java(cls, java_stage):
        """
        Given a Java TrainValidationSplit, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage)
        trainRatio = java_stage.getTrainRatio()
        seed = java_stage.getSeed()
        parallelism = java_stage.getParallelism()
        collectSubModels = java_stage.getCollectSubModels()

        py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
                       trainRatio=trainRatio, seed=seed, parallelism=parallelism,
                       collectSubModels=collectSubModels)
        # Keep the uid of the persisted stage instead of the freshly generated one.
        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self):
        """
        Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.

        :return: Java object equivalent to this instance.
        """
        estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()

        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
                                             self.uid)
        _java_obj.setEstimatorParamMaps(epms)
        _java_obj.setEvaluator(evaluator)
        _java_obj.setEstimator(estimator)
        _java_obj.setTrainRatio(self.getTrainRatio())
        _java_obj.setSeed(self.getSeed())
        _java_obj.setParallelism(self.getParallelism())
        _java_obj.setCollectSubModels(self.getCollectSubModels())
        return _java_obj
0771
0772
class TrainValidationSplitModel(Model, _TrainValidationSplitParams, MLReadable, MLWritable):
    """
    Model from train validation split.

    .. versionadded:: 2.0.0
    """

    def __init__(self, bestModel, validationMetrics=None, subModels=None):
        super(TrainValidationSplitModel, self).__init__()
        # Best model, refit on the full dataset with the best param map.
        self.bestModel = bestModel
        # Validation metric for each param map in estimatorParamMaps, in the same order.
        # A fresh list replaces the former shared mutable default ([]).
        self.validationMetrics = validationMetrics if validationMetrics is not None else []
        # One sub-model per param map when collectSubModels was set, else None.
        self.subModels = subModels

    def _transform(self, dataset):
        return self.bestModel.transform(dataset)

    @since("2.0.0")
    def copy(self, extra=None):
        """
        Creates a copy of this instance with a randomly generated uid
        and some extra params. This copies the underlying bestModel,
        creates a deep copy of the embedded paramMap, and
        copies the embedded and extra parameters over.
        And, this creates a shallow copy of the validationMetrics.
        It does not copy the extra Params into the subModels.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        bestModel = self.bestModel.copy(extra)
        validationMetrics = list(self.validationMetrics)
        subModels = self.subModels
        # Carry over estimator/evaluator/estimatorParamMaps/trainRatio/seed. Without this, the
        # copy has none of them set and cannot be persisted via _to_java.
        return self._copyValues(
            TrainValidationSplitModel(bestModel, validationMetrics, subModels), extra=extra)

    @since("2.3.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)

    @classmethod
    def _from_java(cls, java_stage):
        """
        Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        sc = SparkContext._active_spark_context
        bestModel = JavaParams._from_java(java_stage.bestModel())
        validationMetrics = _java2py(sc, java_stage.validationMetrics())
        estimator, epms, evaluator = super(TrainValidationSplitModel,
                                           cls)._from_java_impl(java_stage)

        # NOTE(review): only estimator, estimatorParamMaps and evaluator are restored here;
        # trainRatio/seed are not read back from java_stage.
        py_stage = cls(bestModel=bestModel,
                       validationMetrics=validationMetrics)._set(estimator=estimator)
        py_stage = py_stage._set(estimatorParamMaps=epms)._set(evaluator=evaluator)

        if java_stage.hasSubModels():
            py_stage.subModels = [JavaParams._from_java(sub_model)
                                  for sub_model in java_stage.subModels()]

        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self):
        """
        Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.

        :return: Java object equivalent to this instance.
        """
        sc = SparkContext._active_spark_context
        _java_obj = JavaParams._new_java_obj(
            "org.apache.spark.ml.tuning.TrainValidationSplitModel",
            self.uid,
            self.bestModel._to_java(),
            _py2java(sc, self.validationMetrics))
        estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()

        _java_obj.set("evaluator", evaluator)
        _java_obj.set("estimator", estimator)
        _java_obj.set("estimatorParamMaps", epms)

        if self.subModels is not None:
            java_sub_models = [sub_model._to_java() for sub_model in self.subModels]
            _java_obj.setSubModels(java_sub_models)

        return _java_obj
0871
0872
if __name__ == "__main__":
    import doctest

    from pyspark.sql import SparkSession

    # Snapshot the module namespace so doctests can refer to names defined in this module.
    globs = globals().copy()

    # The doctests expect a local two-thread session exposed as `spark` and `sc`.
    spark = (SparkSession.builder
             .master("local[2]")
             .appName("ml.tuning tests")
             .getOrCreate())
    sc = spark.sparkContext
    globs.update(sc=sc, spark=spark)

    failure_count, test_count = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    spark.stop()
    if failure_count:
        sys.exit(-1)