|
||||
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Estimator Transformer Param Example.
"""
from __future__ import print_function

# $example on$
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression
# $example off$
from pyspark.sql import SparkSession

if __name__ == "__main__":
    # Build (or reuse) the SparkSession that drives this example.
    spark = (SparkSession
             .builder
             .appName("EstimatorTransformerParamExample")
             .getOrCreate())

    # $example on$
    # Training data: a DataFrame of (label, features) rows.
    train_df = spark.createDataFrame([
        (1.0, Vectors.dense([0.0, 1.1, 0.1])),
        (0.0, Vectors.dense([2.0, 1.0, -1.0])),
        (0.0, Vectors.dense([2.0, 1.3, 1.0])),
        (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"])

    # A LogisticRegression instance is an Estimator: fitting it yields a Model.
    logreg = LogisticRegression(maxIter=10, regParam=0.01)
    # Show every parameter, its documentation, and its default value.
    print("LogisticRegression parameters:\n" + logreg.explainParams() + "\n")

    # Fit a model using the parameters currently stored on the estimator.
    first_model = logreg.fit(train_df)

    # first_model is a Model (a Transformer produced by an Estimator), so we
    # can inspect the parameters it used during fit().  The printed names are
    # unique IDs tied to this LogisticRegression instance.
    print("Model 1 was fit using parameters: ")
    print(first_model.extractParamMap())

    # Parameters can also be supplied as a plain Python dict (a "paramMap").
    overrides = {logreg.maxIter: 20}
    # Re-assigning the same key overwrites the earlier value (maxIter -> 30).
    overrides[logreg.maxIter] = 30
    # dict.update merges several Params at once.
    overrides.update({logreg.regParam: 0.1, logreg.threshold: 0.55})

    # paramMaps are ordinary dicts, so they combine with copy()/update().
    extra_overrides = {logreg.probabilityCol: "myProbability"}  # rename output column
    combined = overrides.copy()
    combined.update(extra_overrides)

    # Fit again with the combined map; these values take precedence over
    # anything previously set via logreg.set* methods.
    second_model = logreg.fit(train_df, combined)
    print("Model 2 was fit using parameters: ")
    print(second_model.extractParamMap())

    # Held-out data for prediction.
    test_df = spark.createDataFrame([
        (1.0, Vectors.dense([-1.0, 1.5, 1.3])),
        (0.0, Vectors.dense([3.0, 2.0, -0.1])),
        (1.0, Vectors.dense([0.0, 2.2, -1.5]))], ["label", "features"])

    # Transformer.transform() only reads the 'features' column.  Because we
    # overrode probabilityCol above, the output carries a "myProbability"
    # column in place of the usual 'probability'.
    predicted = second_model.transform(test_df)
    rows = predicted.select("features", "label", "myProbability", "prediction") \
        .collect()

    for r in rows:
        print("features=%s, label=%s -> prob=%s, prediction=%s"
              % (r.features, r.label, r.myProbability, r.prediction))
    # $example off$

    spark.stop()
[ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
This page was automatically generated by the 2.1.0 LXR engine. The LXR team |