---
layout: global
title: "ML Tuning"
displayTitle: "ML Tuning: model selection and hyperparameter tuning"
license: |
  Licensed to the Apache Software Foundation (ASF) under one or more
  contributor license agreements. See the NOTICE file distributed with
  this work for additional information regarding copyright ownership.
  The ASF licenses this file to You under the Apache License, Version 2.0
  (the "License"); you may not use this file except in compliance with
  the License. You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
---

`\[
\newcommand{\R}{\mathbb{R}}
\newcommand{\E}{\mathbb{E}}
\newcommand{\x}{\mathbf{x}}
\newcommand{\y}{\mathbf{y}}
\newcommand{\wv}{\mathbf{w}}
\newcommand{\av}{\mathbf{\alpha}}
\newcommand{\bv}{\mathbf{b}}
\newcommand{\N}{\mathbb{N}}
\newcommand{\id}{\mathbf{I}}
\newcommand{\ind}{\mathbf{1}}
\newcommand{\0}{\mathbf{0}}
\newcommand{\unit}{\mathbf{e}}
\newcommand{\one}{\mathbf{1}}
\newcommand{\zero}{\mathbf{0}}
\]`

This section describes how to use MLlib's tooling for tuning ML algorithms and Pipelines.
Built-in cross-validation and other tools allow users to optimize hyperparameters in algorithms and Pipelines.

**Table of contents**

* This will become a table of contents (this text will be scraped).
{:toc}

# Model selection (a.k.a. hyperparameter tuning)

An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*.
Tuning may be done for individual `Estimator`s such as `LogisticRegression`, or for entire `Pipeline`s which include multiple algorithms, featurization, and other steps. Users can tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately.

MLlib supports model selection using tools such as [`CrossValidator`](api/scala/org/apache/spark/ml/tuning/CrossValidator.html) and [`TrainValidationSplit`](api/scala/org/apache/spark/ml/tuning/TrainValidationSplit.html).
These tools require the following items:

* [`Estimator`](api/scala/org/apache/spark/ml/Estimator.html): algorithm or `Pipeline` to tune
* Set of `ParamMap`s: parameters to choose from, sometimes called a "parameter grid" to search over
* [`Evaluator`](api/scala/org/apache/spark/ml/evaluation/Evaluator.html): metric to measure how well a fitted `Model` does on held-out test data

At a high level, these model selection tools work as follows:

* They split the input data into separate training and test datasets.
* For each (training, test) pair, they iterate through the set of `ParamMap`s:
  * For each `ParamMap`, they fit the `Estimator` using those parameters, get the fitted `Model`, and evaluate the `Model`'s performance using the `Evaluator`.
* They select the `Model` produced by the best-performing set of parameters.
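
The loop above can be sketched in a few lines of plain Python. This is a conceptual illustration under stated assumptions, not Spark's implementation: a `ParamMap` is modeled as a `dict`, `fit` and `evaluate` are hypothetical callables standing in for fitting the `Estimator` and running the `Evaluator`, and the `larger_is_better` flag plays the role of an evaluator's `isLargerBetter` setting (a larger area under the ROC curve is better, while a smaller RMSE is better). Each `ParamMap` is scored by its metric averaged over all (training, test) pairs.

```python
from statistics import mean
from typing import Any, Callable, Dict, List, Sequence, Tuple

ParamMap = Dict[str, Any]  # hypothetical stand-in for a Spark ParamMap


def select_best(
    splits: Sequence[Tuple[Any, Any]],
    param_maps: Sequence[ParamMap],
    fit: Callable[[Any, ParamMap], Any],
    evaluate: Callable[[Any, Any], float],
    larger_is_better: bool = True,
) -> Tuple[ParamMap, List[float]]:
    """Return the best ParamMap and the average metric of every ParamMap."""
    avg_metrics = []
    for params in param_maps:
        # Fit one model per (training, test) pair and score it on the held-out part.
        scores = [evaluate(fit(train, params), test) for train, test in splits]
        avg_metrics.append(mean(scores))
    # The metric direction decides whether the highest or lowest average wins.
    pick = max if larger_is_better else min
    best_index = pick(range(len(param_maps)), key=lambda i: avg_metrics[i])
    return param_maps[best_index], avg_metrics
```

For example, with two (training, test) pairs and a grid of three candidate constants, passing a mean-squared-error `evaluate` together with `larger_is_better=False` selects the candidate with the lowest average error.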

The `Evaluator` can be a [`RegressionEvaluator`](api/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.html)
for regression problems, a [`BinaryClassificationEvaluator`](api/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.html)
for binary classification problems, a [`MulticlassClassificationEvaluator`](api/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.html)
for multiclass classification problems, a [`MultilabelClassificationEvaluator`](api/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluator.html)
for multi-label classification problems, or a
[`RankingEvaluator`](api/scala/org/apache/spark/ml/evaluation/RankingEvaluator.html) for ranking problems. Each evaluator has a default metric used to
choose the best `ParamMap`; it can be overridden with the evaluator's `setMetricName` method.

To help construct the parameter grid, users can use the [`ParamGridBuilder`](api/scala/org/apache/spark/ml/tuning/ParamGridBuilder.html) utility.
By default, sets of parameters from the parameter grid are evaluated serially. Parameter evaluation can be done in parallel by setting `parallelism` to a value of 2 or more (a value of 1 means serial evaluation) before running model selection with `CrossValidator` or `TrainValidationSplit`.
Choose the value of `parallelism` carefully: it should maximize parallelism without exceeding cluster resources, and larger values do not always improve performance. Generally speaking, a value of up to 10 should be sufficient for most clusters.
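
`ParamGridBuilder` expands the candidate values into every combination, so 3 values for one parameter and 2 for another yield 6 `ParamMap`s. The plain-Python sketch below illustrates that expansion and a bounded worker pool for `parallelism`. It is an illustration under assumptions, not the MLlib API: parameters are keyed by hypothetical string names instead of `Param` objects, the ordering of combinations may differ from `ParamGridBuilder`'s, and `ThreadPoolExecutor` merely stands in for however Spark schedules concurrent fits.

```python
import itertools
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Sequence


def build_grid(values: Dict[str, Sequence[Any]]) -> List[Dict[str, Any]]:
    """Cartesian product of the candidate values, one dict per combination."""
    names = list(values)
    combos = itertools.product(*(values[name] for name in names))
    return [dict(zip(names, combo)) for combo in combos]


def evaluate_grid(
    grid: Sequence[Dict[str, Any]],
    score: Callable[[Dict[str, Any]], float],
    parallelism: int = 1,
) -> List[float]:
    """Score every ParamMap; parallelism=1 runs serially, >= 2 uses a worker pool."""
    if parallelism < 1:
        raise ValueError("parallelism must be >= 1")
    if parallelism == 1:
        return [score(params) for params in grid]
    with ThreadPoolExecutor(max_workers=parallelism) as pool:
        # map() yields results in input order, so metrics line up with grid entries.
        return list(pool.map(score, grid))
```

Here `parallelism=1` takes the serial path, and any larger value caps how many candidates are scored at once; the returned metrics stay in grid order either way.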

# Cross-Validation

`CrossValidator` begins by splitting the dataset into a set of *folds*, which are used to build separate training and test datasets. For example, with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. To evaluate a particular `ParamMap`, `CrossValidator` computes the average evaluation metric for the 3 `Model`s produced by fitting the `Estimator` on the 3 different (training, test) dataset pairs.

After identifying the best `ParamMap`, `CrossValidator` finally re-fits the `Estimator` using the best `ParamMap` and the entire dataset.
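
The fold construction can be sketched in plain Python as follows. Rows are shuffled with a fixed seed and dealt round-robin into `k` disjoint folds; fold `i` serves as the test set of pair `i`, while the remaining folds form its training set. The round-robin dealing is an assumption made for a small, deterministic sketch, not Spark's code: in Spark the default fold assignment is random, so fold sizes are only approximately equal. `fit` and `evaluate` are hypothetical callables.

```python
import random
from statistics import mean
from typing import Any, Callable, List, Sequence, Tuple


def k_fold_pairs(rows: Sequence[Any], k: int, seed: int = 42) -> List[Tuple[List[Any], List[Any]]]:
    """Split rows into k disjoint folds; fold i is the test set of pair i."""
    if not 2 <= k <= len(rows):
        raise ValueError("need 2 <= k <= number of rows")
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    folds = [shuffled[i::k] for i in range(k)]  # round-robin assignment
    pairs = []
    for i in range(k):
        # Training data is the union of every fold except the held-out one.
        train = [row for j, fold in enumerate(folds) if j != i for row in fold]
        pairs.append((train, folds[i]))
    return pairs


def cross_validate(rows: Sequence[Any], k: int, fit: Callable, evaluate: Callable) -> float:
    """Average metric of k models, each trained on k-1 folds and tested on the remaining one."""
    return mean(evaluate(fit(train), test) for train, test in k_fold_pairs(rows, k))
```

With 9 rows and `k=3`, each pair trains on 6 rows and tests on 3. After the best candidate is chosen, the final model would be trained on all rows, matching the re-fit step described above.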

**Examples: model selection via cross-validation**

The following example demonstrates using `CrossValidator` to select from a grid of parameters.

Note that cross-validation over a grid of parameters is expensive.
In the example below, the parameter grid has 3 values for `hashingTF.numFeatures` and 2 values for `lr.regParam`, and `CrossValidator` uses 2 folds. This multiplies out to `$(3 \times 2) \times 2 = 12$` different models being trained.
In realistic settings, it is common to try many more parameters and use more folds (`$k=3$` and `$k=10$` are common).
In other words, using `CrossValidator` can be very expensive.
However, it is also a well-established method for choosing parameters, and it is more statistically sound than heuristic hand-tuning.
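
The cost arithmetic above generalizes: a grid built from per-parameter value lists has as many `ParamMap`s as the product of the list lengths, and cross-validation fits one model per `ParamMap` per fold. The helper below is a hypothetical illustration, not part of MLlib; it also counts the single re-fit on the entire dataset that `CrossValidator` performs once the best `ParamMap` is known.

```python
from math import prod
from typing import Sequence


def cross_validation_fits(values_per_param: Sequence[int], num_folds: int) -> int:
    """Count Estimator fits: one per (ParamMap, fold) pair, plus the final re-fit."""
    grid_size = prod(values_per_param)  # the grid is the Cartesian product of candidate lists
    return grid_size * num_folds + 1  # +1 for re-fitting the best ParamMap on all data


# 3 numFeatures values x 2 regParam values, evaluated with 2 folds.
print(cross_validation_fits([3, 2], num_folds=2))  # 13 = 12 during selection + 1 re-fit
```

Raising the fold count to 10 for the same grid gives 61 fits, which shows how quickly the cost grows with `$k$`.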

<div class="codetabs">

<div data-lang="scala" markdown="1">

Refer to the [`CrossValidator` Scala docs](api/scala/org/apache/spark/ml/tuning/CrossValidator.html) for details on the API.

{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala %}
</div>

<div data-lang="java" markdown="1">

Refer to the [`CrossValidator` Java docs](api/java/org/apache/spark/ml/tuning/CrossValidator.html) for details on the API.

{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java %}
</div>

<div data-lang="python" markdown="1">

Refer to the [`CrossValidator` Python docs](api/python/pyspark.ml.html#pyspark.ml.tuning.CrossValidator) for more details on the API.

{% include_example python/ml/cross_validator.py %}
</div>

</div>

# Train-Validation Split

In addition to `CrossValidator`, Spark also offers `TrainValidationSplit` for hyperparameter tuning.
`TrainValidationSplit` only evaluates each combination of parameters once, as opposed to `$k$` times in
the case of `CrossValidator`. It is, therefore, less expensive,
but its results are less reliable when the training dataset is not sufficiently large.

Unlike `CrossValidator`, `TrainValidationSplit` creates a single (training, test) dataset pair.
It splits the dataset into these two parts using the `trainRatio` parameter. For example, with `trainRatio` set to `0.75`,
`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation.

Like `CrossValidator`, `TrainValidationSplit` finally re-fits the `Estimator` using the best `ParamMap` and the entire dataset.
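
A plain-Python sketch of the single split and the selection that follows, under stated assumptions: `fit` and `evaluate` are hypothetical callables, and the data is a plain list. The sketch cuts a shuffled copy at exactly `train_ratio` of its length, whereas Spark performs a random split, so the 75%/25% proportions are only approximate there. Each candidate is fitted once on the training part, and the winner is then re-fitted on every row.

```python
import random
from typing import Any, Callable, Dict, List, Sequence, Tuple


def train_validation_split(
    rows: Sequence[Any], train_ratio: float = 0.75, seed: int = 42
) -> Tuple[List[Any], List[Any]]:
    """Return one (training, validation) pair from a shuffled copy of rows."""
    if not 0.0 < train_ratio < 1.0:
        raise ValueError("train_ratio must be in (0, 1)")
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_ratio)
    return shuffled[:cut], shuffled[cut:]


def select_with_tvs(
    rows: Sequence[Any],
    param_maps: Sequence[Dict[str, Any]],
    fit: Callable[[Sequence[Any], Dict[str, Any]], Any],
    evaluate: Callable[[Any, Sequence[Any]], float],
    train_ratio: float = 0.75,
    larger_is_better: bool = True,
) -> Tuple[Dict[str, Any], Any]:
    """Score each ParamMap once on the validation part, then re-fit the winner on all rows."""
    train, validation = train_validation_split(rows, train_ratio)
    metrics = [evaluate(fit(train, params), validation) for params in param_maps]
    pick = max if larger_is_better else min
    best = param_maps[pick(range(len(param_maps)), key=metrics.__getitem__)]
    return best, fit(rows, best)  # the final model uses the entire dataset
```

With 100 rows and `train_ratio=0.75`, the sketch trains each candidate on 75 rows and validates it on the other 25.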

**Examples: model selection via train-validation split**

<div class="codetabs">

<div data-lang="scala" markdown="1">

Refer to the [`TrainValidationSplit` Scala docs](api/scala/org/apache/spark/ml/tuning/TrainValidationSplit.html) for details on the API.

{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala %}
</div>

<div data-lang="java" markdown="1">

Refer to the [`TrainValidationSplit` Java docs](api/java/org/apache/spark/ml/tuning/TrainValidationSplit.html) for details on the API.

{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java %}
</div>

<div data-lang="python" markdown="1">

Refer to the [`TrainValidationSplit` Python docs](api/python/pyspark.ml.html#pyspark.ml.tuning.TrainValidationSplit) for more details on the API.

{% include_example python/ml/train_validation_split.py %}
</div>

</div>