0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 """
0019 Decision Tree Classification Example.
0020 """
0021 from __future__ import print_function
0022
0023 from pyspark import SparkContext
0024
0025 from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
0026 from pyspark.mllib.util import MLUtils
0027
0028
if __name__ == "__main__":

    sc = SparkContext(appName="PythonDecisionTreeClassificationExample")
    try:
        # Load and parse the LIBSVM-format data file into an RDD of LabeledPoint.
        data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')

        # Split the data into training and test sets (30% held out for testing).
        (trainingData, testData) = data.randomSplit([0.7, 0.3])

        # Train a DecisionTree classifier.
        # An empty categoricalFeaturesInfo treats every feature as continuous.
        model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                             impurity='gini', maxDepth=5, maxBins=32)

        # Evaluate the model on the test instances: pair each true label with
        # its prediction and measure the fraction of mismatches.
        predictions = model.predict(testData.map(lambda x: x.features))
        labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
        testCount = testData.count()
        if testCount == 0:
            # randomSplit assigns rows at random, so a very small dataset can
            # leave the test split empty; avoid dividing by zero.
            print('Test set is empty; cannot compute test error.')
        else:
            testErr = labelsAndPredictions.filter(
                lambda lp: lp[0] != lp[1]).count() / float(testCount)
            print('Test Error = ' + str(testErr))
        print('Learned classification tree model:')
        print(model.toDebugString())

        # Save the model, then load it back to demonstrate the round trip.
        # NOTE(review): saving presumably fails if this directory already exists
        # from a previous run — confirm, and clear it before re-running.
        model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
        sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
    finally:
        # Stop the SparkContext even if any step above raised, so its
        # resources are released.
        sc.stop()
0056