---
layout: global
title: "ML Tuning"
displayTitle: "ML Tuning: model selection and hyperparameter tuning"
license: |
  Licensed to the Apache Software Foundation (ASF) under one or more
  contributor license agreements.  See the NOTICE file distributed with
  this work for additional information regarding copyright ownership.
  The ASF licenses this file to You under the Apache License, Version 2.0
  (the "License"); you may not use this file except in compliance with
  the License.  You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
---

`\[
\newcommand{\R}{\mathbb{R}}
\newcommand{\E}{\mathbb{E}}
\newcommand{\x}{\mathbf{x}}
\newcommand{\y}{\mathbf{y}}
\newcommand{\wv}{\mathbf{w}}
\newcommand{\av}{\mathbf{\alpha}}
\newcommand{\bv}{\mathbf{b}}
\newcommand{\N}{\mathbb{N}}
\newcommand{\id}{\mathbf{I}}
\newcommand{\ind}{\mathbf{1}}
\newcommand{\0}{\mathbf{0}}
\newcommand{\unit}{\mathbf{e}}
\newcommand{\one}{\mathbf{1}}
\newcommand{\zero}{\mathbf{0}}
\]`

This section describes how to use MLlib's tooling for tuning ML algorithms and Pipelines.
Built-in Cross-Validation and other tooling allow users to optimize hyperparameters in algorithms and Pipelines.

**Table of contents**

* This will become a table of contents (this text will be scraped).
{:toc}

# Model selection (a.k.a. hyperparameter tuning)

An important task in ML is *model selection*, or using data to find the best model or parameters for a given task.  This is also called *tuning*.
Tuning may be done for individual `Estimator`s such as `LogisticRegression`, or for entire `Pipeline`s which include multiple algorithms, featurization, and other steps.  Users can tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately.

MLlib supports model selection using tools such as [`CrossValidator`](api/scala/org/apache/spark/ml/tuning/CrossValidator.html) and [`TrainValidationSplit`](api/scala/org/apache/spark/ml/tuning/TrainValidationSplit.html).
These tools require the following items:

* [`Estimator`](api/scala/org/apache/spark/ml/Estimator.html): algorithm or `Pipeline` to tune
* Set of `ParamMap`s: parameters to choose from, sometimes called a "parameter grid" to search over
* [`Evaluator`](api/scala/org/apache/spark/ml/evaluation/Evaluator.html): metric to measure how well a fitted `Model` does on held-out test data

At a high level, these model selection tools work as follows:

* They split the input data into separate training and test datasets.
* For each (training, test) pair, they iterate through the set of `ParamMap`s:
  * For each `ParamMap`, they fit the `Estimator` using those parameters, get the fitted `Model`, and evaluate the `Model`'s performance using the `Evaluator`.
* They select the `Model` produced by the best-performing set of parameters.

The `Evaluator` can be a [`RegressionEvaluator`](api/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.html)
for regression problems, a [`BinaryClassificationEvaluator`](api/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.html)
for binary data, a [`MulticlassClassificationEvaluator`](api/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.html)
for multiclass problems, a [`MultilabelClassificationEvaluator`](api/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluator.html)
for multi-label classification, or a
[`RankingEvaluator`](api/scala/org/apache/spark/ml/evaluation/RankingEvaluator.html) for ranking problems. The default metric used to
choose the best `ParamMap` can be overridden by the `setMetricName` method in each of these evaluators.
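
For instance, `RegressionEvaluator` scores predictions with RMSE unless told otherwise. A minimal sketch of overriding the metric, using the spark.ml default column names:

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator

// RegressionEvaluator defaults to RMSE ("rmse"); override to mean absolute error.
// "label" and "prediction" are the spark.ml default column names.
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("mae")
```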

To help construct the parameter grid, users can use the [`ParamGridBuilder`](api/scala/org/apache/spark/ml/tuning/ParamGridBuilder.html) utility.
By default, sets of parameters from the parameter grid are evaluated in serial. Parameter evaluation can be done in parallel by setting `parallelism` to 2 or more (a value of 1 is serial) before running model selection with `CrossValidator` or `TrainValidationSplit`.
The value of `parallelism` should be chosen carefully to maximize parallelism without exceeding cluster resources; larger values do not always improve performance.  Generally speaking, a value up to 10 should be sufficient for most clusters.
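
As a sketch, here is a grid over two `LogisticRegression` parameters, with the resulting `CrossValidator` asked to evaluate two parameter settings at a time (the grid values are illustrative):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression()

// 2 regParam values x 3 elasticNetParam values = 6 ParamMaps to search over.
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setParallelism(2)  // evaluate up to 2 ParamMaps at once; 1 means serial
```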

# Cross-Validation

`CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets. E.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing.  To evaluate a particular `ParamMap`, `CrossValidator` computes the average evaluation metric for the 3 `Model`s produced by fitting the `Estimator` on the 3 different (training, test) dataset pairs.

After identifying the best `ParamMap`, `CrossValidator` finally re-fits the `Estimator` using the best `ParamMap` and the entire dataset.

**Examples: model selection via cross-validation**

The following example demonstrates using `CrossValidator` to select from a grid of parameters.

Note that cross-validation over a grid of parameters is expensive.
E.g., in the example below, the parameter grid has 3 values for `hashingTF.numFeatures` and 2 values for `lr.regParam`, and `CrossValidator` uses 2 folds.  This multiplies out to `$(3 \times 2) \times 2 = 12$` different models being trained.
In realistic settings, it can be common to try many more parameters and use more folds (`$k=3$` and `$k=10$` are common).
In other words, using `CrossValidator` can be very expensive.
However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning.
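
A condensed sketch of that arithmetic, where `hashingTF`, `lr`, `pipeline`, and `training` are assumed to be the feature stage, classifier, `Pipeline`, and training DataFrame of the full example below, and the grid values are illustrative:

```scala
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// 3 values x 2 values = 6 ParamMaps; with 2 folds, 6 x 2 = 12 models trained.
val paramGrid = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(2)

// Fits 12 models, then re-fits the Estimator on all of `training`
// with the best ParamMap.
val cvModel = cv.fit(training)
```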

<div class="codetabs">

<div data-lang="scala" markdown="1">

Refer to the [`CrossValidator` Scala docs](api/scala/org/apache/spark/ml/tuning/CrossValidator.html) for details on the API.

{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala %}
</div>

<div data-lang="java" markdown="1">

Refer to the [`CrossValidator` Java docs](api/java/org/apache/spark/ml/tuning/CrossValidator.html) for details on the API.

{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java %}
</div>

<div data-lang="python" markdown="1">

Refer to the [`CrossValidator` Python docs](api/python/pyspark.ml.html#pyspark.ml.tuning.CrossValidator) for more details on the API.

{% include_example python/ml/cross_validator.py %}
</div>

</div>

# Train-Validation Split

In addition to `CrossValidator`, Spark also offers `TrainValidationSplit` for hyperparameter tuning.
`TrainValidationSplit` evaluates each combination of parameters only once, as opposed to k times in
the case of `CrossValidator`. It is therefore less expensive,
but it will not produce as reliable results when the training dataset is not sufficiently large.

Unlike `CrossValidator`, `TrainValidationSplit` creates a single (training, test) dataset pair.
It splits the dataset into these two parts using the `trainRatio` parameter. For example, with `$trainRatio=0.75$`,
`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation.

Like `CrossValidator`, `TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset.
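
A minimal sketch matching the `$trainRatio=0.75$` split above, assuming a DataFrame `data` of (label, features) rows; note that `addGrid` on a boolean param expands to both `true` and `false`:

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val lr = new LinearRegression()

val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.fitIntercept)  // boolean param: expands to true and false
  .build()

val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.75)  // 75% for training, 25% for validation

val model = tvs.fit(data)  // re-fits on all of `data` with the best ParamMap
```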

**Examples: model selection via train validation split**

<div class="codetabs">

<div data-lang="scala" markdown="1">

Refer to the [`TrainValidationSplit` Scala docs](api/scala/org/apache/spark/ml/tuning/TrainValidationSplit.html) for details on the API.

{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala %}
</div>

<div data-lang="java" markdown="1">

Refer to the [`TrainValidationSplit` Java docs](api/java/org/apache/spark/ml/tuning/TrainValidationSplit.html) for details on the API.

{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java %}
</div>

<div data-lang="python" markdown="1">

Refer to the [`TrainValidationSplit` Python docs](api/python/pyspark.ml.html#pyspark.ml.tuning.TrainValidationSplit) for more details on the API.

{% include_example python/ml/train_validation_split.py %}
</div>

</div>