|
||||
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This example demonstrates applying TrainValidationSplit to split data
and perform model selection.
Run with:

  bin/spark-submit examples/src/main/python/ml/train_validation_split.py
"""
# $example on$
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
# $example off$
from pyspark.sql import SparkSession

if __name__ == "__main__":
    spark = (
        SparkSession
        .builder
        .appName("TrainValidationSplit")
        .getOrCreate()
    )

    # $example on$
    # Prepare training and test data: load the libsvm-formatted sample file and
    # split it 90% / 10% with a fixed seed so the split is repeatable.
    data = (
        spark.read.format("libsvm")
        .load("data/mllib/sample_linear_regression_data.txt")
    )
    train, test = data.randomSplit([0.9, 0.1], seed=12345)

    lr = LinearRegression(maxIter=10)

    # We use a ParamGridBuilder to construct a grid of parameters to search over.
    # TrainValidationSplit will try all combinations of values and determine best model using
    # the evaluator.
    paramGrid = (
        ParamGridBuilder()
        .addGrid(lr.regParam, [0.1, 0.01])
        .addGrid(lr.fitIntercept, [False, True])
        .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])
        .build()
    )

    # In this case the estimator is simply the linear regression.
    # A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
    tvs = TrainValidationSplit(estimator=lr,
                               estimatorParamMaps=paramGrid,
                               evaluator=RegressionEvaluator(),
                               # 80% of the data will be used for training, 20% for validation.
                               trainRatio=0.8)

    # Run TrainValidationSplit, and choose the best set of parameters.
    model = tvs.fit(train)

    # Make predictions on test data. model is the model with the combination of
    # parameters that performed best.
    (
        model.transform(test)
        .select("features", "label", "prediction")
        .show()
    )

    # $example off$
    spark.stop()
[ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
This page was automatically generated by the 2.1.0 LXR engine. The LXR team |