Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 """
0019 A simple example demonstrating model selection using CrossValidator.
0020 This example also demonstrates how Pipelines are Estimators.
0021 Run with:
0022 
0023   bin/spark-submit examples/src/main/python/ml/cross_validator.py
0024 """
0025 from __future__ import print_function
0026 
0027 # $example on$
0028 from pyspark.ml import Pipeline
0029 from pyspark.ml.classification import LogisticRegression
0030 from pyspark.ml.evaluation import BinaryClassificationEvaluator
0031 from pyspark.ml.feature import HashingTF, Tokenizer
0032 from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
0033 # $example off$
0034 from pyspark.sql import SparkSession
0035 
if __name__ == "__main__":
    spark = SparkSession\
        .builder\
        .appName("CrossValidatorExample")\
        .getOrCreate()

    # $example on$
    # Prepare training documents, which are labeled.
    training = spark.createDataFrame([
        (0, "a b c d e spark", 1.0),
        (1, "b d", 0.0),
        (2, "spark f g h", 1.0),
        (3, "hadoop mapreduce", 0.0),
        (4, "b spark who", 1.0),
        (5, "g d a y", 0.0),
        (6, "spark fly", 1.0),
        (7, "was mapreduce", 0.0),
        (8, "e spark program", 1.0),
        (9, "a e c l", 0.0),
        (10, "spark compile", 1.0),
        (11, "hadoop software", 0.0)
    ], ["id", "text", "label"])

    # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    tokenizer = Tokenizer(inputCol="text", outputCol="words")
    hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
    lr = LogisticRegression(maxIter=10)
    pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

    # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
    # This will allow us to jointly choose parameters for all Pipeline stages.
    # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
    # We use a ParamGridBuilder to construct a grid of parameters to search over.
    # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
    # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
    paramGrid = ParamGridBuilder() \
        .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \
        .addGrid(lr.regParam, [0.1, 0.01]) \
        .build()

    crossval = CrossValidator(estimator=pipeline,
                              estimatorParamMaps=paramGrid,
                              evaluator=BinaryClassificationEvaluator(),
                              numFolds=2)  # use 3+ folds in practice

    # Run cross-validation, and choose the best set of parameters.
    cvModel = crossval.fit(training)

    # Prepare test documents, which are unlabeled.
    test = spark.createDataFrame([
        (4, "spark i j k"),
        (5, "l m n"),
        (6, "mapreduce spark"),
        (7, "apache hadoop")
    ], ["id", "text"])

    # Make predictions on test documents. cvModel uses the best fitted pipeline found
    # (tokenizer + hashingTF + the tuned logistic regression model).
    prediction = cvModel.transform(test)
    selected = prediction.select("id", "text", "probability", "prediction")
    for row in selected.collect():
        print(row)
    # $example off$

    spark.stop()