0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 """
0019 Gradient Boosted Tree Classifier Example.
0020 """
0021 from __future__ import print_function
0022
0023
0024 from pyspark.ml import Pipeline
0025 from pyspark.ml.classification import GBTClassifier
0026 from pyspark.ml.feature import StringIndexer, VectorIndexer
0027 from pyspark.ml.evaluation import MulticlassClassificationEvaluator
0028
0029 from pyspark.sql import SparkSession
0030
if __name__ == "__main__":
    # Script entry point: trains a gradient-boosted tree classifier on the
    # sample libsvm dataset, prints a few predictions, the test error and the
    # fitted GBT model, then stops the Spark session.
    spark = SparkSession\
        .builder\
        .appName("GradientBoostedTreeClassifierExample")\
        .getOrCreate()

    # Everything after session creation runs inside try/finally so that the
    # session is stopped even when loading, fitting or evaluation raises.
    try:
        # Load the data file (libsvm format) into a DataFrame.
        data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

        # Index the "label" column into "indexedLabel".
        # The indexer is fitted on the full dataset, before the train/test split.
        labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)

        # Index the "features" vector column into "indexedFeatures".
        # Per the VectorIndexer documentation, features with more than
        # maxCategories (4) distinct values are treated as continuous.
        # Also fitted on the full dataset.
        featureIndexer =\
            VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)

        # Randomly split the rows with weights 0.7 (training) and 0.3 (test).
        # No seed is passed, so the split is not fixed by this script.
        (trainingData, testData) = data.randomSplit([0.7, 0.3])

        # Configure the GBT classifier to read the indexed columns; maxIter=10.
        gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10)

        # Chain the two fitted indexers and the GBT estimator in a Pipeline.
        pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt])

        # Fit the pipeline on the training split only.
        model = pipeline.fit(trainingData)

        # Apply the fitted pipeline to the held-out test split.
        predictions = model.transform(testData)

        # Display the first 5 rows of selected columns.
        predictions.select("prediction", "indexedLabel", "features").show(5)

        # Compute accuracy of "prediction" against "indexedLabel" and report
        # the test error as 1 - accuracy.
        evaluator = MulticlassClassificationEvaluator(
            labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
        accuracy = evaluator.evaluate(predictions)
        print("Test Error = %g" % (1.0 - accuracy))

        # stages[2] is the third pipeline stage, i.e. the fitted GBT model.
        gbtModel = model.stages[2]
        print(gbtModel)
    finally:
        # Always stop the session, on success or failure.
        spark.stop()