0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 import itertools
0018 import sys
0019 from multiprocessing.pool import ThreadPool
0020 
0021 import numpy as np
0022 
0023 from pyspark import since, keyword_only
0024 from pyspark.ml import Estimator, Model
0025 from pyspark.ml.common import _py2java, _java2py
0026 from pyspark.ml.param import Params, Param, TypeConverters
0027 from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
0028 from pyspark.ml.util import *
0029 from pyspark.ml.wrapper import JavaParams
0030 from pyspark.sql.functions import rand
0031 
0032 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
0033            'TrainValidationSplitModel']
0034 
0035 
0036 def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel):
0037     """
0038     Creates a list of callables which can be called from different threads to fit and evaluate
0039     an estimator in parallel. Each callable returns an `(index, metric, subModel)` tuple.
0040 
0041     :param est: Estimator, the estimator to be fit.
0042     :param train: DataFrame, training data set, used for fitting.
0043     :param eva: Evaluator, used to compute `metric`
0044     :param validation: DataFrame, validation data set, used for evaluation.
0045     :param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation.
0046     :param collectSubModel: Whether to collect the fitted sub-model.
0047     :return: (int, float, subModel), an index into `epm`, its metric, and the sub-model (or None).
0048     """
0049     modelIter = est.fitMultiple(train, epm)
0050 
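     # fitMultiple returns an iterator whose next() is expected to be thread
     # safe, so the identical callables built below can be handed to a thread
     # pool; each call pulls one (index, model) pair off the shared iterator.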
0051     def singleTask():
0052         index, model = next(modelIter)
0053         metric = eva.evaluate(model.transform(validation, epm[index]))
0054         return index, metric, model if collectSubModel else None
0055 
0056     return [singleTask] * len(epm)
0057 
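     # A minimal sketch of how these tasks are typically consumed (see
     # CrossValidator._fit and TrainValidationSplit._fit below), assuming
     # `est`, `train`, `eva`, `validation` and `epm` are already defined:
     #
     #     tasks = _parallelFitTasks(est, train, eva, validation, epm, False)
     #     pool = ThreadPool(processes=2)
     #     for index, metric, _ in pool.imap_unordered(lambda f: f(), tasks):
     #         print(index, metric)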
0058 
0059 class ParamGridBuilder(object):
0060     r"""
0061     Builder for a param grid used in grid search-based model selection.
0062 
0063     >>> from pyspark.ml.classification import LogisticRegression
0064     >>> lr = LogisticRegression()
0065     >>> output = ParamGridBuilder() \
0066     ...     .baseOn({lr.labelCol: 'l'}) \
0067     ...     .baseOn([lr.predictionCol, 'p']) \
0068     ...     .addGrid(lr.regParam, [1.0, 2.0]) \
0069     ...     .addGrid(lr.maxIter, [1, 5]) \
0070     ...     .build()
0071     >>> expected = [
0072     ...     {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
0073     ...     {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
0074     ...     {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
0075     ...     {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
0076     >>> len(output) == len(expected)
0077     True
0078     >>> all([m in expected for m in output])
0079     True
0080 
0081     .. versionadded:: 1.4.0
0082     """
0083 
0084     def __init__(self):
0085         self._param_grid = {}
0086 
0087     @since("1.4.0")
0088     def addGrid(self, param, values):
0089         """
0090         Adds a param with multiple candidate values to explore in the grid.
0091 
0092         param must be an instance of Param associated with an instance of Params
0093         (such as Estimator or Transformer).
0094         """
0095         if isinstance(param, Param):
0096             self._param_grid[param] = values
0097         else:
0098             raise TypeError("param must be an instance of Param")
0099 
0100         return self
0101 
0102     @since("1.4.0")
0103     def baseOn(self, *args):
0104         """
0105         Sets the given parameters in this grid to fixed values.
0106         Accepts either a parameter dictionary or a list of (parameter, value) pairs.
0107         """
0108         if isinstance(args[0], dict):
0109             self.baseOn(*args[0].items())
0110         else:
0111             for (param, value) in args:
0112                 self.addGrid(param, [value])
0113 
0114         return self
0115 
0116     @since("1.4.0")
0117     def build(self):
0118         """
0119         Builds and returns all combinations of parameters specified
0120         by the param grid.
0121         """
0122         keys = self._param_grid.keys()
0123         grid_values = self._param_grid.values()
0124 
0125         def to_key_value_pairs(keys, values):
0126             return [(key, key.typeConverter(value)) for key, value in zip(keys, values)]
0127 
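             # Take the Cartesian product of the candidate value lists; each
             # combination becomes one param map (a dict of Param -> value).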
0128         return [dict(to_key_value_pairs(keys, prod)) for prod in itertools.product(*grid_values)]
0129 
0130 
0131 class _ValidatorParams(HasSeed):
0132     """
0133     Common params for TrainValidationSplit and CrossValidator.
0134     """
0135 
0136     estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
0137     estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
0138     evaluator = Param(
0139         Params._dummy(), "evaluator",
0140         "evaluator used to select hyper-parameters that maximize the validator metric")
0141 
0142     @since("2.0.0")
0143     def getEstimator(self):
0144         """
0145         Gets the value of estimator or its default value.
0146         """
0147         return self.getOrDefault(self.estimator)
0148 
0149     @since("2.0.0")
0150     def getEstimatorParamMaps(self):
0151         """
0152         Gets the value of estimatorParamMaps or its default value.
0153         """
0154         return self.getOrDefault(self.estimatorParamMaps)
0155 
0156     @since("2.0.0")
0157     def getEvaluator(self):
0158         """
0159         Gets the value of evaluator or its default value.
0160         """
0161         return self.getOrDefault(self.evaluator)
0162 
0163     @classmethod
0164     def _from_java_impl(cls, java_stage):
0165         """
0166         Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams.
0167         """
0168 
0169         # Load information from java_stage to the instance.
0170         estimator = JavaParams._from_java(java_stage.getEstimator())
0171         evaluator = JavaParams._from_java(java_stage.getEvaluator())
0172         epms = [estimator._transfer_param_map_from_java(epm)
0173                 for epm in java_stage.getEstimatorParamMaps()]
0174         return estimator, epms, evaluator
0175 
0176     def _to_java_impl(self):
0177         """
0178         Return Java estimator, estimatorParamMaps, and evaluator from this Python instance.
0179         """
0180 
0181         gateway = SparkContext._gateway
0182         cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
0183 
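             # Convert each Python param map into a Java ParamMap and collect
             # them into a Java array for the JVM-side validator.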
0184         java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
0185         for idx, epm in enumerate(self.getEstimatorParamMaps()):
0186             java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
0187 
0188         java_estimator = self.getEstimator()._to_java()
0189         java_evaluator = self.getEvaluator()._to_java()
0190         return java_estimator, java_epms, java_evaluator
0191 
0192 
0193 class _CrossValidatorParams(_ValidatorParams):
0194     """
0195     Params for :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`.
0196 
0197     .. versionadded:: 3.0.0
0198     """
0199 
0200     numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation",
0201                      typeConverter=TypeConverters.toInt)
0202 
0203     @since("1.4.0")
0204     def getNumFolds(self):
0205         """
0206         Gets the value of numFolds or its default value.
0207         """
0208         return self.getOrDefault(self.numFolds)
0209 
0210 
0211 class CrossValidator(Estimator, _CrossValidatorParams, HasParallelism, HasCollectSubModels,
0212                      MLReadable, MLWritable):
0213     """
0214 
0215     K-fold cross validation performs model selection by splitting the dataset into a set of
0216     non-overlapping, randomly partitioned folds which are used as separate training and test
0217     datasets. For example, with k=3 folds, K-fold cross validation generates 3 (training, test)
0218     dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. Each
0219     fold is used as the test set exactly once.
0220 
0221 
0222     >>> from pyspark.ml.classification import LogisticRegression
0223     >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
0224     >>> from pyspark.ml.linalg import Vectors
0225     >>> from pyspark.ml.tuning import CrossValidatorModel
0226     >>> import tempfile
0227     >>> dataset = spark.createDataFrame(
0228     ...     [(Vectors.dense([0.0]), 0.0),
0229     ...      (Vectors.dense([0.4]), 1.0),
0230     ...      (Vectors.dense([0.5]), 0.0),
0231     ...      (Vectors.dense([0.6]), 1.0),
0232     ...      (Vectors.dense([1.0]), 1.0)] * 10,
0233     ...     ["features", "label"])
0234     >>> lr = LogisticRegression()
0235     >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
0236     >>> evaluator = BinaryClassificationEvaluator()
0237     >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
0238     ...     parallelism=2)
0239     >>> cvModel = cv.fit(dataset)
0240     >>> cvModel.getNumFolds()
0241     3
0242     >>> cvModel.avgMetrics[0]
0243     0.5
0244     >>> path = tempfile.mkdtemp()
0245     >>> model_path = path + "/model"
0246     >>> cvModel.write().save(model_path)
0247     >>> cvModelRead = CrossValidatorModel.read().load(model_path)
0248     >>> cvModelRead.avgMetrics
0249     [0.5, ...
0250     >>> evaluator.evaluate(cvModel.transform(dataset))
0251     0.8333...
0252 
0253     .. versionadded:: 1.4.0
0254     """
0255 
0256     @keyword_only
0257     def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
0258                  seed=None, parallelism=1, collectSubModels=False):
0259         """
0260         __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
0261                  seed=None, parallelism=1, collectSubModels=False)
0262         """
0263         super(CrossValidator, self).__init__()
0264         self._setDefault(numFolds=3, parallelism=1)
0265         kwargs = self._input_kwargs
0266         self._set(**kwargs)
0267 
0268     @keyword_only
0269     @since("1.4.0")
0270     def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
0271                   seed=None, parallelism=1, collectSubModels=False):
0272         """
0273         setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
0274                   seed=None, parallelism=1, collectSubModels=False)
0275         Sets params for cross validator.
0276         """
0277         kwargs = self._input_kwargs
0278         return self._set(**kwargs)
0279 
0280     @since("2.0.0")
0281     def setEstimator(self, value):
0282         """
0283         Sets the value of :py:attr:`estimator`.
0284         """
0285         return self._set(estimator=value)
0286 
0287     @since("2.0.0")
0288     def setEstimatorParamMaps(self, value):
0289         """
0290         Sets the value of :py:attr:`estimatorParamMaps`.
0291         """
0292         return self._set(estimatorParamMaps=value)
0293 
0294     @since("2.0.0")
0295     def setEvaluator(self, value):
0296         """
0297         Sets the value of :py:attr:`evaluator`.
0298         """
0299         return self._set(evaluator=value)
0300 
0301     @since("1.4.0")
0302     def setNumFolds(self, value):
0303         """
0304         Sets the value of :py:attr:`numFolds`.
0305         """
0306         return self._set(numFolds=value)
0307 
0308     def setSeed(self, value):
0309         """
0310         Sets the value of :py:attr:`seed`.
0311         """
0312         return self._set(seed=value)
0313 
0314     def setParallelism(self, value):
0315         """
0316         Sets the value of :py:attr:`parallelism`.
0317         """
0318         return self._set(parallelism=value)
0319 
0320     def setCollectSubModels(self, value):
0321         """
0322         Sets the value of :py:attr:`collectSubModels`.
0323         """
0324         return self._set(collectSubModels=value)
0325 
0326     def _fit(self, dataset):
0327         est = self.getOrDefault(self.estimator)
0328         epm = self.getOrDefault(self.estimatorParamMaps)
0329         numModels = len(epm)
0330         eva = self.getOrDefault(self.evaluator)
0331         nFolds = self.getOrDefault(self.numFolds)
0332         seed = self.getOrDefault(self.seed)
0333         h = 1.0 / nFolds
0334         randCol = self.uid + "_rand"
0335         df = dataset.select("*", rand(seed).alias(randCol))
0336         metrics = [0.0] * numModels
0337 
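             # Cap the thread pool at the number of param maps; extra workers
             # would have nothing to do.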
0338         pool = ThreadPool(processes=min(self.getParallelism(), numModels))
0339         subModels = None
0340         collectSubModelsParam = self.getCollectSubModels()
0341         if collectSubModelsParam:
0342             subModels = [[None for j in range(numModels)] for i in range(nFolds)]
0343 
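             # Each row was tagged with a uniform random value in [0, 1); rows
             # whose value falls in [i * h, (i + 1) * h) form the validation
             # set for fold i and the remaining rows form the training set.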
0344         for i in range(nFolds):
0345             validateLB = i * h
0346             validateUB = (i + 1) * h
0347             condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
0348             validation = df.filter(condition).cache()
0349             train = df.filter(~condition).cache()
0350 
0351             tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
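                 # Accumulate each param map's metric, averaged over the folds.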
0352             for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
0353                 metrics[j] += (metric / nFolds)
0354                 if collectSubModelsParam:
0355                     subModels[i][j] = subModel
0356 
0357             validation.unpersist()
0358             train.unpersist()
0359 
0360         if eva.isLargerBetter():
0361             bestIndex = np.argmax(metrics)
0362         else:
0363             bestIndex = np.argmin(metrics)
0364         bestModel = est.fit(dataset, epm[bestIndex])
0365         return self._copyValues(CrossValidatorModel(bestModel, metrics, subModels))
0366 
0367     @since("1.4.0")
0368     def copy(self, extra=None):
0369         """
0370         Creates a copy of this instance with a randomly generated uid
0371         and some extra params. This creates a deep copy of
0372         the embedded paramMap, and copies the embedded and extra parameters over.
0373 
0374         :param extra: Extra parameters to copy to the new instance
0375         :return: Copy of this instance
0376         """
0377         if extra is None:
0378             extra = dict()
0379         newCV = Params.copy(self, extra)
0380         if self.isSet(self.estimator):
0381             newCV.setEstimator(self.getEstimator().copy(extra))
0382         # estimatorParamMaps remain the same
0383         if self.isSet(self.evaluator):
0384             newCV.setEvaluator(self.getEvaluator().copy(extra))
0385         return newCV
0386 
0387     @since("2.3.0")
0388     def write(self):
0389         """Returns an MLWriter instance for this ML instance."""
0390         return JavaMLWriter(self)
0391 
0392     @classmethod
0393     @since("2.3.0")
0394     def read(cls):
0395         """Returns an MLReader instance for this class."""
0396         return JavaMLReader(cls)
0397 
0398     @classmethod
0399     def _from_java(cls, java_stage):
0400         """
0401         Given a Java CrossValidator, create and return a Python wrapper of it.
0402         Used for ML persistence.
0403         """
0404 
0405         estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage)
0406         numFolds = java_stage.getNumFolds()
0407         seed = java_stage.getSeed()
0408         parallelism = java_stage.getParallelism()
0409         collectSubModels = java_stage.getCollectSubModels()
0410         # Create a new instance of this stage.
0411         py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
0412                        numFolds=numFolds, seed=seed, parallelism=parallelism,
0413                        collectSubModels=collectSubModels)
0414         py_stage._resetUid(java_stage.uid())
0415         return py_stage
0416 
0417     def _to_java(self):
0418         """
0419         Transfer this instance to a Java CrossValidator. Used for ML persistence.
0420 
0421         :return: Java object equivalent to this instance.
0422         """
0423 
0424         estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
0425 
0426         _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
0427         _java_obj.setEstimatorParamMaps(epms)
0428         _java_obj.setEvaluator(evaluator)
0429         _java_obj.setEstimator(estimator)
0430         _java_obj.setSeed(self.getSeed())
0431         _java_obj.setNumFolds(self.getNumFolds())
0432         _java_obj.setParallelism(self.getParallelism())
0433         _java_obj.setCollectSubModels(self.getCollectSubModels())
0434 
0435         return _java_obj
0436 
0437 
0438 class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable):
0439     """
0440 
0441     CrossValidatorModel contains the model with the highest average cross-validation
0442     metric across folds and uses this model to transform input data. CrossValidatorModel
0443     also tracks the metrics for each param map evaluated.
0444 
0445     .. versionadded:: 1.4.0
0446     """
0447 
0448     def __init__(self, bestModel, avgMetrics=[], subModels=None):
0449         super(CrossValidatorModel, self).__init__()
0450         #: best model from cross validation
0451         self.bestModel = bestModel
0452         #: Average cross-validation metrics for each paramMap in
0453         #: CrossValidator.estimatorParamMaps, in the corresponding order.
0454         self.avgMetrics = avgMetrics
0455         #: sub model list from cross validation
0456         self.subModels = subModels
0457 
0458     def _transform(self, dataset):
0459         return self.bestModel.transform(dataset)
0460 
0461     @since("1.4.0")
0462     def copy(self, extra=None):
0463         """
0464         Creates a copy of this instance with a randomly generated uid
0465         and some extra params. This copies the underlying bestModel,
0466         creates a deep copy of the embedded paramMap, and
0467         copies the embedded and extra parameters over.
0468         It does not copy the extra Params into the subModels.
0469 
0470         :param extra: Extra parameters to copy to the new instance
0471         :return: Copy of this instance
0472         """
0473         if extra is None:
0474             extra = dict()
0475         bestModel = self.bestModel.copy(extra)
0476         avgMetrics = self.avgMetrics
0477         subModels = self.subModels
0478         return CrossValidatorModel(bestModel, avgMetrics, subModels)
0479 
0480     @since("2.3.0")
0481     def write(self):
0482         """Returns an MLWriter instance for this ML instance."""
0483         return JavaMLWriter(self)
0484 
0485     @classmethod
0486     @since("2.3.0")
0487     def read(cls):
0488         """Returns an MLReader instance for this class."""
0489         return JavaMLReader(cls)
0490 
0491     @classmethod
0492     def _from_java(cls, java_stage):
0493         """
0494         Given a Java CrossValidatorModel, create and return a Python wrapper of it.
0495         Used for ML persistence.
0496         """
0497         sc = SparkContext._active_spark_context
0498         bestModel = JavaParams._from_java(java_stage.bestModel())
0499         avgMetrics = _java2py(sc, java_stage.avgMetrics())
0500         estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
0501 
0502         py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics)._set(estimator=estimator)
0503         py_stage = py_stage._set(estimatorParamMaps=epms)._set(evaluator=evaluator)
0504 
0505         if java_stage.hasSubModels():
0506             py_stage.subModels = [[JavaParams._from_java(sub_model)
0507                                    for sub_model in fold_sub_models]
0508                                   for fold_sub_models in java_stage.subModels()]
0509 
0510         py_stage._resetUid(java_stage.uid())
0511         return py_stage
0512 
0513     def _to_java(self):
0514         """
0515         Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.
0516 
0517         :return: Java object equivalent to this instance.
0518         """
0519 
0520         sc = SparkContext._active_spark_context
0521         _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
0522                                              self.uid,
0523                                              self.bestModel._to_java(),
0524                                              _py2java(sc, self.avgMetrics))
0525         estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
0526 
0527         _java_obj.set("evaluator", evaluator)
0528         _java_obj.set("estimator", estimator)
0529         _java_obj.set("estimatorParamMaps", epms)
0530 
0531         if self.subModels is not None:
0532             java_sub_models = [[sub_model._to_java() for sub_model in fold_sub_models]
0533                                for fold_sub_models in self.subModels]
0534             _java_obj.setSubModels(java_sub_models)
0535         return _java_obj
0536 
0537 
0538 class _TrainValidationSplitParams(_ValidatorParams):
0539     """
0540     Params for :py:class:`TrainValidationSplit` and :py:class:`TrainValidationSplitModel`.
0541 
0542     .. versionadded:: 3.0.0
0543     """
0544 
0545     trainRatio = Param(Params._dummy(), "trainRatio", "Param for ratio between train and\
0546      validation data. Must be between 0 and 1.", typeConverter=TypeConverters.toFloat)
0547 
0548     @since("2.0.0")
0549     def getTrainRatio(self):
0550         """
0551         Gets the value of trainRatio or its default value.
0552         """
0553         return self.getOrDefault(self.trainRatio)
0554 
0555 
0556 class TrainValidationSplit(Estimator, _TrainValidationSplitParams, HasParallelism,
0557                            HasCollectSubModels, MLReadable, MLWritable):
0558     """
0559     Validation for hyper-parameter tuning. Randomly splits the input dataset into train and
0560     validation sets, and uses the evaluation metric on the validation set to select the best
0561     model. Similar to :class:`CrossValidator`, but only splits the set once.
0562 
0563     >>> from pyspark.ml.classification import LogisticRegression
0564     >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
0565     >>> from pyspark.ml.linalg import Vectors
0566     >>> from pyspark.ml.tuning import TrainValidationSplitModel
0567     >>> import tempfile
0568     >>> dataset = spark.createDataFrame(
0569     ...     [(Vectors.dense([0.0]), 0.0),
0570     ...      (Vectors.dense([0.4]), 1.0),
0571     ...      (Vectors.dense([0.5]), 0.0),
0572     ...      (Vectors.dense([0.6]), 1.0),
0573     ...      (Vectors.dense([1.0]), 1.0)] * 10,
0574     ...     ["features", "label"]).repartition(1)
0575     >>> lr = LogisticRegression()
0576     >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
0577     >>> evaluator = BinaryClassificationEvaluator()
0578     >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
0579     ...     parallelism=1, seed=42)
0580     >>> tvsModel = tvs.fit(dataset)
0581     >>> tvsModel.getTrainRatio()
0582     0.75
0583     >>> tvsModel.validationMetrics
0584     [0.5, ...
0585     >>> path = tempfile.mkdtemp()
0586     >>> model_path = path + "/model"
0587     >>> tvsModel.write().save(model_path)
0588     >>> tvsModelRead = TrainValidationSplitModel.read().load(model_path)
0589     >>> tvsModelRead.validationMetrics
0590     [0.5, ...
0591     >>> evaluator.evaluate(tvsModel.transform(dataset))
0592     0.833...
0593 
0594     .. versionadded:: 2.0.0
0595     """
0596 
0597     @keyword_only
0598     def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,
0599                  parallelism=1, collectSubModels=False, seed=None):
0600         """
0601         __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\
0602                  parallelism=1, collectSubModels=False, seed=None)
0603         """
0604         super(TrainValidationSplit, self).__init__()
0605         self._setDefault(trainRatio=0.75, parallelism=1)
0606         kwargs = self._input_kwargs
0607         self._set(**kwargs)
0608 
0609     @since("2.0.0")
0610     @keyword_only
0611     def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,
0612                   parallelism=1, collectSubModels=False, seed=None):
0613         """
0614         setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\
0615                   parallelism=1, collectSubModels=False, seed=None)
0616         Sets params for the train validation split.
0617         """
0618         kwargs = self._input_kwargs
0619         return self._set(**kwargs)
0620 
0621     @since("2.0.0")
0622     def setEstimator(self, value):
0623         """
0624         Sets the value of :py:attr:`estimator`.
0625         """
0626         return self._set(estimator=value)
0627 
0628     @since("2.0.0")
0629     def setEstimatorParamMaps(self, value):
0630         """
0631         Sets the value of :py:attr:`estimatorParamMaps`.
0632         """
0633         return self._set(estimatorParamMaps=value)
0634 
0635     @since("2.0.0")
0636     def setEvaluator(self, value):
0637         """
0638         Sets the value of :py:attr:`evaluator`.
0639         """
0640         return self._set(evaluator=value)
0641 
0642     @since("2.0.0")
0643     def setTrainRatio(self, value):
0644         """
0645         Sets the value of :py:attr:`trainRatio`.
0646         """
0647         return self._set(trainRatio=value)
0648 
0649     def setSeed(self, value):
0650         """
0651         Sets the value of :py:attr:`seed`.
0652         """
0653         return self._set(seed=value)
0654 
0655     def setParallelism(self, value):
0656         """
0657         Sets the value of :py:attr:`parallelism`.
0658         """
0659         return self._set(parallelism=value)
0660 
0661     def setCollectSubModels(self, value):
0662         """
0663         Sets the value of :py:attr:`collectSubModels`.
0664         """
0665         return self._set(collectSubModels=value)
0666 
0667     def _fit(self, dataset):
0668         est = self.getOrDefault(self.estimator)
0669         epm = self.getOrDefault(self.estimatorParamMaps)
0670         numModels = len(epm)
0671         eva = self.getOrDefault(self.evaluator)
0672         tRatio = self.getOrDefault(self.trainRatio)
0673         seed = self.getOrDefault(self.seed)
0674         randCol = self.uid + "_rand"
0675         df = dataset.select("*", rand(seed).alias(randCol))
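             # Rows whose random value is >= trainRatio form the validation
             # set; the rest (roughly a trainRatio fraction of the data) are
             # used for training.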
0676         condition = (df[randCol] >= tRatio)
0677         validation = df.filter(condition).cache()
0678         train = df.filter(~condition).cache()
0679 
0680         subModels = None
0681         collectSubModelsParam = self.getCollectSubModels()
0682         if collectSubModelsParam:
0683             subModels = [None for i in range(numModels)]
0684 
0685         tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
0686         pool = ThreadPool(processes=min(self.getParallelism(), numModels))
0687         metrics = [None] * numModels
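             # Each task returns (param map index, metric, optional sub-model).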
0688         for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
0689             metrics[j] = metric
0690             if collectSubModelsParam:
0691                 subModels[j] = subModel
0692 
0693         train.unpersist()
0694         validation.unpersist()
0695 
0696         if eva.isLargerBetter():
0697             bestIndex = np.argmax(metrics)
0698         else:
0699             bestIndex = np.argmin(metrics)
0700         bestModel = est.fit(dataset, epm[bestIndex])
0701         return self._copyValues(TrainValidationSplitModel(bestModel, metrics, subModels))
0702 
0703     @since("2.0.0")
0704     def copy(self, extra=None):
0705         """
0706         Creates a copy of this instance with a randomly generated uid
0707         and some extra params. This creates a deep copy of
0708         the embedded paramMap, and copies the embedded and extra parameters over.
0709 
0710         :param extra: Extra parameters to copy to the new instance
0711         :return: Copy of this instance
0712         """
0713         if extra is None:
0714             extra = dict()
0715         newTVS = Params.copy(self, extra)
0716         if self.isSet(self.estimator):
0717             newTVS.setEstimator(self.getEstimator().copy(extra))
0718         # estimatorParamMaps remain the same
0719         if self.isSet(self.evaluator):
0720             newTVS.setEvaluator(self.getEvaluator().copy(extra))
0721         return newTVS
0722 
0723     @since("2.3.0")
0724     def write(self):
0725         """Returns an MLWriter instance for this ML instance."""
0726         return JavaMLWriter(self)
0727 
0728     @classmethod
0729     @since("2.3.0")
0730     def read(cls):
0731         """Returns an MLReader instance for this class."""
0732         return JavaMLReader(cls)
0733 
0734     @classmethod
0735     def _from_java(cls, java_stage):
0736         """
0737         Given a Java TrainValidationSplit, create and return a Python wrapper of it.
0738         Used for ML persistence.
0739         """
0740 
0741         estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage)
0742         trainRatio = java_stage.getTrainRatio()
0743         seed = java_stage.getSeed()
0744         parallelism = java_stage.getParallelism()
0745         collectSubModels = java_stage.getCollectSubModels()
0746         # Create a new instance of this stage.
0747         py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
0748                        trainRatio=trainRatio, seed=seed, parallelism=parallelism,
0749                        collectSubModels=collectSubModels)
0750         py_stage._resetUid(java_stage.uid())
0751         return py_stage
0752 
0753     def _to_java(self):
0754         """
0755         Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.
0756         :return: Java object equivalent to this instance.
0757         """
0758 
0759         estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
0760 
0761         _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
0762                                              self.uid)
0763         _java_obj.setEstimatorParamMaps(epms)
0764         _java_obj.setEvaluator(evaluator)
0765         _java_obj.setEstimator(estimator)
0766         _java_obj.setTrainRatio(self.getTrainRatio())
0767         _java_obj.setSeed(self.getSeed())
0768         _java_obj.setParallelism(self.getParallelism())
0769         _java_obj.setCollectSubModels(self.getCollectSubModels())
0770         return _java_obj
0771 
0772 
0773 class TrainValidationSplitModel(Model, _TrainValidationSplitParams, MLReadable, MLWritable):
0774     """
0775     Model from train validation split.
0776 
0777     .. versionadded:: 2.0.0
0778     """
0779 
0780     def __init__(self, bestModel, validationMetrics=[], subModels=None):
0781         super(TrainValidationSplitModel, self).__init__()
0782         #: best model from train validation split
0783         self.bestModel = bestModel
0784         #: evaluated validation metrics
0785         self.validationMetrics = validationMetrics
0786         #: sub models from train validation split
0787         self.subModels = subModels
0788 
0789     def _transform(self, dataset):
0790         return self.bestModel.transform(dataset)
0791 
0792     @since("2.0.0")
0793     def copy(self, extra=None):
0794         """
0795         Creates a copy of this instance with a randomly generated uid
0796         and some extra params. This copies the underlying bestModel,
0797         creates a deep copy of the embedded paramMap, and
0798         copies the embedded and extra parameters over.
0799         It also creates a shallow copy of the validationMetrics.
0800         It does not copy the extra Params into the subModels.
0801 
0802         :param extra: Extra parameters to copy to the new instance
0803         :return: Copy of this instance
0804         """
0805         if extra is None:
0806             extra = dict()
0807         bestModel = self.bestModel.copy(extra)
0808         validationMetrics = list(self.validationMetrics)
0809         subModels = self.subModels
0810         return TrainValidationSplitModel(bestModel, validationMetrics, subModels)
0811 
0812     @since("2.3.0")
0813     def write(self):
0814         """Returns an MLWriter instance for this ML instance."""
0815         return JavaMLWriter(self)
0816 
0817     @classmethod
0818     @since("2.3.0")
0819     def read(cls):
0820         """Returns an MLReader instance for this class."""
0821         return JavaMLReader(cls)
0822 
0823     @classmethod
0824     def _from_java(cls, java_stage):
0825         """
0826         Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
0827         Used for ML persistence.
0828         """
0829 
0830         # Load information from java_stage to the instance.
0831         sc = SparkContext._active_spark_context
0832         bestModel = JavaParams._from_java(java_stage.bestModel())
0833         validationMetrics = _java2py(sc, java_stage.validationMetrics())
0834         estimator, epms, evaluator = super(TrainValidationSplitModel,
0835                                            cls)._from_java_impl(java_stage)
0836         # Create a new instance of this stage.
0837         py_stage = cls(bestModel=bestModel,
0838                        validationMetrics=validationMetrics)._set(estimator=estimator)
0839         py_stage = py_stage._set(estimatorParamMaps=epms)._set(evaluator=evaluator)
0840 
0841         if java_stage.hasSubModels():
0842             py_stage.subModels = [JavaParams._from_java(sub_model)
0843                                   for sub_model in java_stage.subModels()]
0844 
0845         py_stage._resetUid(java_stage.uid())
0846         return py_stage
0847 
0848     def _to_java(self):
0849         """
0850         Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.
0851         :return: Java object equivalent to this instance.
0852         """
0853 
0854         sc = SparkContext._active_spark_context
0855         _java_obj = JavaParams._new_java_obj(
0856             "org.apache.spark.ml.tuning.TrainValidationSplitModel",
0857             self.uid,
0858             self.bestModel._to_java(),
0859             _py2java(sc, self.validationMetrics))
0860         estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
0861 
0862         _java_obj.set("evaluator", evaluator)
0863         _java_obj.set("estimator", estimator)
0864         _java_obj.set("estimatorParamMaps", epms)
0865 
0866         if self.subModels is not None:
0867             java_sub_models = [sub_model._to_java() for sub_model in self.subModels]
0868             _java_obj.setSubModels(java_sub_models)
0869 
0870         return _java_obj
0871 
0872 
0873 if __name__ == "__main__":
0874     import doctest
0875 
0876     from pyspark.sql import SparkSession
0877     globs = globals().copy()
0878 
0879     # Run the doctests above against a local SparkSession with two
0880     # worker threads:
0881     spark = SparkSession.builder\
0882         .master("local[2]")\
0883         .appName("ml.tuning tests")\
0884         .getOrCreate()
0885     sc = spark.sparkContext
0886     globs['sc'] = sc
0887     globs['spark'] = spark
0888     (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
0889     spark.stop()
0890     if failure_count:
0891         sys.exit(-1)