0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 """
0019 A simple example demonstrating model selection using CrossValidator.
0020 This example also demonstrates how Pipelines are Estimators.
0021 Run with:
0022
0023 bin/spark-submit examples/src/main/python/ml/cross_validator.py
0024 """
0025 from __future__ import print_function
0026
0027
0028 from pyspark.ml import Pipeline
0029 from pyspark.ml.classification import LogisticRegression
0030 from pyspark.ml.evaluation import BinaryClassificationEvaluator
0031 from pyspark.ml.feature import HashingTF, Tokenizer
0032 from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
0033
0034 from pyspark.sql import SparkSession
0035
if __name__ == "__main__":
    spark = SparkSession\
        .builder\
        .appName("CrossValidatorExample")\
        .getOrCreate()

    # Ensure the Spark session (and its backing JVM) is released even if
    # any step below raises — the original fell through without cleanup.
    try:
        # Prepare training documents, which are labeled (id, text, label).
        training = spark.createDataFrame([
            (0, "a b c d e spark", 1.0),
            (1, "b d", 0.0),
            (2, "spark f g h", 1.0),
            (3, "hadoop mapreduce", 0.0),
            (4, "b spark who", 1.0),
            (5, "g d a y", 0.0),
            (6, "spark fly", 1.0),
            (7, "was mapreduce", 0.0),
            (8, "e spark program", 1.0),
            (9, "a e c l", 0.0),
            (10, "spark compile", 1.0),
            (11, "hadoop software", 0.0)
        ], ["id", "text", "label"])

        # Configure an ML pipeline, which consists of three stages:
        # tokenizer -> hashingTF -> logistic regression.
        tokenizer = Tokenizer(inputCol="text", outputCol="words")
        hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
        lr = LogisticRegression(maxIter=10)
        pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

        # Treat the Pipeline as an Estimator, wrapping it in a CrossValidator
        # instance. This allows us to jointly choose parameters for all
        # Pipeline stages. A CrossValidator requires an Estimator, a set of
        # Estimator ParamMaps, and an Evaluator.
        # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
        # this grid has 3 x 2 = 6 parameter settings for CrossValidator to
        # choose from.
        paramGrid = ParamGridBuilder() \
            .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \
            .addGrid(lr.regParam, [0.1, 0.01]) \
            .build()

        crossval = CrossValidator(estimator=pipeline,
                                  estimatorParamMaps=paramGrid,
                                  evaluator=BinaryClassificationEvaluator(),
                                  numFolds=2)  # use 3+ folds in practice

        # Run cross-validation, and choose the best set of parameters.
        cvModel = crossval.fit(training)

        # Prepare test documents, which are unlabeled (id, text).
        test = spark.createDataFrame([
            (4, "spark i j k"),
            (5, "l m n"),
            (6, "mapreduce spark"),
            (7, "apache hadoop")
        ], ["id", "text"])

        # Make predictions on the test documents. cvModel uses the best
        # model found (lrModel) during the parameter search.
        prediction = cvModel.transform(test)
        selected = prediction.select("id", "text", "probability", "prediction")
        for row in selected.collect():
            print(row)
    finally:
        spark.stop()