0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 """
0019 Random Forest Classifier Example.
0020 """
0021 from __future__ import print_function
0022
0023
0024 from pyspark.ml import Pipeline
0025 from pyspark.ml.classification import RandomForestClassifier
0026 from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer
0027 from pyspark.ml.evaluation import MulticlassClassificationEvaluator
0028
0029 from pyspark.sql import SparkSession
0030
if __name__ == "__main__":
    # Script entry point: train a random forest on the LIBSVM sample data,
    # print a few predictions, the test error and the fitted forest.
    spark = SparkSession\
        .builder\
        .appName("RandomForestClassifierExample")\
        .getOrCreate()

    # All work runs under try/finally so the session is stopped even when
    # loading, fitting or evaluation raises.
    try:
        # Load the LIBSVM-format sample file into a DataFrame.
        data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

        # Index the "label" column into "indexedLabel".
        # Fitted on the full dataset (before the split) so the index covers
        # every label present in the data.
        labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)

        # Index categorical features inside the "features" vector.
        # Per the VectorIndexer documentation, maxCategories=4 means features
        # with more than 4 distinct values are treated as continuous.
        featureIndexer =\
            VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)

        # Hold out 30% of the rows for testing. No seed is passed, so the
        # split (and therefore the reported error) can differ between runs.
        (trainingData, testData) = data.randomSplit([0.7, 0.3])

        # Random forest trained on the indexed label/feature columns.
        rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", numTrees=10)

        # Map predicted indices back to the original label values.
        labelConverter = IndexToString(inputCol="prediction", outputCol="predictedLabel",
                                       labels=labelIndexer.labels)

        # Chain the indexers, the forest and the label converter.
        pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf, labelConverter])

        # Fit the pipeline on the training split (this trains the forest).
        model = pipeline.fit(trainingData)

        # Score the held-out split.
        predictions = model.transform(testData)

        # Show the first 5 rows of predictions next to the true label.
        predictions.select("predictedLabel", "label", "features").show(5)

        # Compute accuracy on the indexed labels and report the test error
        # as 1 - accuracy.
        evaluator = MulticlassClassificationEvaluator(
            labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
        accuracy = evaluator.evaluate(predictions)
        print("Test Error = %g" % (1.0 - accuracy))

        # Stage 2 of the fitted pipeline is the trained forest (stages are
        # labelIndexer, featureIndexer, rf, labelConverter in that order).
        rfModel = model.stages[2]
        print(rfModel)
    finally:
        spark.stop()