0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 """
0019 Gradient Boosted Trees Classification Example.
0020 """
0021 from __future__ import print_function
0022
0023 from pyspark import SparkContext
0024
0025 from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel
0026 from pyspark.mllib.util import MLUtils
0027
0028
if __name__ == "__main__":
    # Script entry point: train a gradient-boosted-trees classifier on the
    # sample LibSVM data set, report its test error, print the learned model,
    # then save the model and load it back.
    sc = SparkContext(appName="PythonGradientBoostedTreesClassificationExample")
    try:
        # Load the LibSVM-formatted data file.
        data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

        # Randomly split the data with weights 0.7 (training) and 0.3 (test).
        (trainingData, testData) = data.randomSplit([0.7, 0.3])

        # Train the classifier. An empty categoricalFeaturesInfo declares no
        # categorical features; only 3 boosting iterations are run here.
        model = GradientBoostedTrees.trainClassifier(
            trainingData, categoricalFeaturesInfo={}, numIterations=3)

        # Predict on the test features, then pair each true label with its
        # prediction (both RDDs are derived from testData).
        predictions = model.predict(testData.map(lambda x: x.features))
        labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)

        # Test error = fraction of test rows whose label differs from the
        # prediction; float() keeps the division fractional on Python 2.
        testErr = labelsAndPredictions.filter(
            lambda lp: lp[0] != lp[1]).count() / float(testData.count())
        print('Test Error = ' + str(testErr))
        print('Learned classification GBT model:')
        print(model.toDebugString())

        # Save the model, then load it back from the same path.
        model.save(sc, "target/tmp/myGradientBoostingClassificationModel")
        sameModel = GradientBoostedTreesModel.load(
            sc, "target/tmp/myGradientBoostingClassificationModel")
    finally:
        # Always shut the SparkContext down, even if a step above raised, so
        # the application does not leave the context running.
        sc.stop()
0056