0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
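"""
Decision Tree Regression Example.

Trains a regression tree with the RDD-based ``pyspark.mllib`` API on
``data/mllib/sample_libsvm_data.txt``, reports the mean squared error on a
held-out test split, prints the learned tree, and then saves and reloads the
model.
"""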
0021 from __future__ import print_function
0022
0023 from pyspark import SparkContext
0024
0025 from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
0026 from pyspark.mllib.util import MLUtils
0027
0028
def squared_error(label_and_prediction):
    """Return the squared difference between a label and its prediction.

    :param label_and_prediction: a ``(label, prediction)`` pair of numbers,
        as produced by zipping the test labels with the model predictions.
    :return: ``(label - prediction) ** 2``.
    """
    label, prediction = label_and_prediction
    diff = label - prediction
    return diff * diff


if __name__ == "__main__":
    # Train a regression tree, report its test MSE, then demonstrate saving
    # and reloading the trained model.
    sc = SparkContext(appName="PythonDecisionTreeRegressionExample")
    try:
        # Load and parse the data file into an RDD of LabeledPoint.
        data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
        # Split the data into training and test sets (30% held out for testing).
        # No seed is passed, so the split is not reproducible across runs.
        (trainingData, testData) = data.randomSplit([0.7, 0.3])

        # Train a DecisionTree model.
        # An empty categoricalFeaturesInfo marks every feature as continuous.
        model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={},
                                            impurity='variance', maxDepth=5, maxBins=32)

        # Evaluate the model on the test instances and compute the test error.
        predictions = model.predict(testData.map(lambda x: x.features))
        labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
        testCount = testData.count()
        if testCount == 0:
            # A random split can leave the test set empty on small inputs;
            # report NaN rather than failing with ZeroDivisionError.
            testMSE = float('nan')
        else:
            testMSE = labelsAndPredictions.map(squared_error).sum() / float(testCount)
        print('Test Mean Squared Error = ' + str(testMSE))
        print('Learned regression tree model:')
        print(model.toDebugString())

        # Save and reload the model. Spark refuses to overwrite an existing
        # output directory, so remove it before re-running this script.
        model.save(sc, "target/tmp/myDecisionTreeRegressionModel")
        # Loaded only to demonstrate the save/load round trip.
        sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel")
    finally:
        # Always release the SparkContext, even if loading, training or saving fails.
        sc.stop()
0056