from __future__ import absolute_import

import sys
import random

from pyspark import RDD, since
from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper
from pyspark.mllib.linalg import _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.util import JavaLoader, JavaSaveable

__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel',
           'RandomForest', 'GradientBoostedTreesModel', 'GradientBoostedTrees']


class TreeEnsembleModel(JavaModelWrapper, JavaSaveable):
    """TreeEnsembleModel

    .. versionadded:: 1.3.0
    """
    @since("1.3.0")
    def predict(self, x):
        """
        Predict values for a single data point or an RDD of points using
        the trained model.

        .. note:: In Python, predict cannot currently be used within an RDD
            transformation or action.
            Call predict directly on the RDD instead.
        """
        if isinstance(x, RDD):
            return self.call("predict", x.map(_convert_to_vector))
        else:
            return self.call("predict", _convert_to_vector(x))
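
    # Usage sketch for predict() (illustrative only; assumes an active
    # SparkContext ``sc`` and a trained model ``model``):
    #
    #   model.predict([2.0])                            # single feature vector
    #   model.predict(sc.parallelize([[2.0], [0.0]]))   # RDD of vectors
    #
    # The RDD form returns an RDD of predictions; call ``collect()`` on it to
    # materialize the values locally.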

    @since("1.3.0")
    def numTrees(self):
        """
        Get number of trees in ensemble.
        """
        return self.call("numTrees")

    @since("1.3.0")
    def totalNumNodes(self):
        """
        Get total number of nodes, summed over all trees in the ensemble.
        """
        return self.call("totalNumNodes")

    def __repr__(self):
        """ Summary of model """
        return self._java_model.toString()

    @since("1.3.0")
    def toDebugString(self):
        """ Full model """
        return self._java_model.toDebugString()
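
    # Inspection sketch (not part of the API): given a trained ensemble model
    # ``m`` -- for example a RandomForestModel or GradientBoostedTreesModel --
    # its structure can be examined with:
    #
    #   m.numTrees()              # number of trees in the ensemble
    #   m.totalNumNodes()         # nodes summed over all trees
    #   print(m)                  # one-line summary (__repr__)
    #   print(m.toDebugString())  # full if/else structure of every tree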


class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader):
    """
    A decision tree model for classification or regression.

    .. versionadded:: 1.1.0
    """
    @since("1.1.0")
    def predict(self, x):
        """
        Predict the label of one or more examples.

        .. note:: In Python, predict cannot currently be used within an RDD
            transformation or action.
            Call predict directly on the RDD instead.

        :param x:
          Data point (feature vector), or an RDD of data points (feature
          vectors).
        """
        if isinstance(x, RDD):
            return self.call("predict", x.map(_convert_to_vector))
        else:
            return self.call("predict", _convert_to_vector(x))

    @since("1.1.0")
    def numNodes(self):
        """Get number of nodes in tree, including leaf nodes."""
        return self._java_model.numNodes()

    @since("1.1.0")
    def depth(self):
        """
        Get depth of tree (e.g. depth 0 means 1 leaf node, depth 1
        means 1 internal node + 2 leaf nodes).
        """
        return self._java_model.depth()

    def __repr__(self):
        """ Summary of model. """
        return self._java_model.toString()

    @since("1.2.0")
    def toDebugString(self):
        """ Full model. """
        return self._java_model.toDebugString()

    @classmethod
    def _java_loader_class(cls):
        return "org.apache.spark.mllib.tree.model.DecisionTreeModel"
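
    # Persistence sketch (``sc`` is an active SparkContext and ``path`` is a
    # writable location; both names are placeholders): the model inherits
    # save/load from JavaSaveable and JavaLoader, so a trained tree can be
    # round-tripped with:
    #
    #   model.save(sc, path)
    #   sameModel = DecisionTreeModel.load(sc, path)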


class DecisionTree(object):
    """
    Learning algorithm for a decision tree model for classification or
    regression.

    .. versionadded:: 1.1.0
    """

    @classmethod
    def _train(cls, data, type, numClasses, features, impurity="gini", maxDepth=5, maxBins=32,
               minInstancesPerNode=1, minInfoGain=0.0):
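        # Both public train* methods below funnel into this helper: it checks
        # that ``data`` is an RDD of LabeledPoint, then delegates the actual
        # training to the JVM implementation via callMLlibFunc.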
        first = data.first()
        assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
        model = callMLlibFunc("trainDecisionTreeModel", data, type, numClasses, features,
                              impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
        return DecisionTreeModel(model)

    @classmethod
    @since("1.1.0")
    def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo,
                        impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
                        minInfoGain=0.0):
        """
        Train a decision tree model for classification.

        :param data:
          Training data: RDD of LabeledPoint. Labels should take values
          {0, 1, ..., numClasses-1}.
        :param numClasses:
          Number of classes for classification.
        :param categoricalFeaturesInfo:
          Map storing arity of categorical features. An entry (n -> k)
          indicates that feature n is categorical with k categories
          indexed from 0: {0, 1, ..., k-1}.
        :param impurity:
          Criterion used for information gain calculation.
          Supported values: "gini" or "entropy".
          (default: "gini")
        :param maxDepth:
          Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1
          means 1 internal node + 2 leaf nodes).
          (default: 5)
        :param maxBins:
          Number of bins used for finding splits at each node.
          (default: 32)
        :param minInstancesPerNode:
          Minimum number of instances required at child nodes to create
          the parent split.
          (default: 1)
        :param minInfoGain:
          Minimum info gain required to create a split.
          (default: 0.0)
        :return:
          DecisionTreeModel.

        Example usage:

        >>> from numpy import array
        >>> from pyspark.mllib.regression import LabeledPoint
        >>> from pyspark.mllib.tree import DecisionTree
        >>>
        >>> data = [
        ...     LabeledPoint(0.0, [0.0]),
        ...     LabeledPoint(1.0, [1.0]),
        ...     LabeledPoint(1.0, [2.0]),
        ...     LabeledPoint(1.0, [3.0])
        ... ]
        >>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
        >>> print(model)
        DecisionTreeModel classifier of depth 1 with 3 nodes

        >>> print(model.toDebugString())
        DecisionTreeModel classifier of depth 1 with 3 nodes
          If (feature 0 <= 0.5)
           Predict: 0.0
          Else (feature 0 > 0.5)
           Predict: 1.0
        <BLANKLINE>
        >>> model.predict(array([1.0]))
        1.0
        >>> model.predict(array([0.0]))
        0.0
        >>> rdd = sc.parallelize([[1.0], [0.0]])
        >>> model.predict(rdd).collect()
        [1.0, 0.0]
        """
        return cls._train(data, "classification", numClasses, categoricalFeaturesInfo,
                          impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)

    @classmethod
    @since("1.1.0")
    def trainRegressor(cls, data, categoricalFeaturesInfo,
                       impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
                       minInfoGain=0.0):
        """
        Train a decision tree model for regression.

        :param data:
          Training data: RDD of LabeledPoint. Labels are real numbers.
        :param categoricalFeaturesInfo:
          Map storing arity of categorical features. An entry (n -> k)
          indicates that feature n is categorical with k categories
          indexed from 0: {0, 1, ..., k-1}.
        :param impurity:
          Criterion used for information gain calculation.
          The only supported value for regression is "variance".
          (default: "variance")
        :param maxDepth:
          Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1
          means 1 internal node + 2 leaf nodes).
          (default: 5)
        :param maxBins:
          Number of bins used for finding splits at each node.
          (default: 32)
        :param minInstancesPerNode:
          Minimum number of instances required at child nodes to create
          the parent split.
          (default: 1)
        :param minInfoGain:
          Minimum info gain required to create a split.
          (default: 0.0)
        :return:
          DecisionTreeModel.

        Example usage:

        >>> from pyspark.mllib.regression import LabeledPoint
        >>> from pyspark.mllib.tree import DecisionTree
        >>> from pyspark.mllib.linalg import SparseVector
        >>>
        >>> sparse_data = [
        ...     LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
        ...     LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
        ...     LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
        ...     LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
        ... ]
        >>>
        >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {})
        >>> model.predict(SparseVector(2, {1: 1.0}))
        1.0
        >>> model.predict(SparseVector(2, {1: 0.0}))
        0.0
        >>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]])
        >>> model.predict(rdd).collect()
        [1.0, 0.0]
        """
        return cls._train(data, "regression", 0, categoricalFeaturesInfo,
                          impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)


@inherit_doc
class RandomForestModel(TreeEnsembleModel, JavaLoader):
    """
    Represents a random forest model.

    .. versionadded:: 1.2.0
    """

    @classmethod
    def _java_loader_class(cls):
        return "org.apache.spark.mllib.tree.model.RandomForestModel"


class RandomForest(object):
    """
    Learning algorithm for a random forest model for classification or
    regression.

    .. versionadded:: 1.2.0
    """

    supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird")
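
    # Any of the strategies above can also be passed explicitly, e.g. (sketch,
    # assuming ``data`` is an RDD of LabeledPoint):
    #
    #   RandomForest.trainClassifier(data, 2, {}, numTrees=10,
    #                                featureSubsetStrategy="log2")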

    @classmethod
    def _train(cls, data, algo, numClasses, categoricalFeaturesInfo, numTrees,
               featureSubsetStrategy, impurity, maxDepth, maxBins, seed):
        first = data.first()
        assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
        if featureSubsetStrategy not in cls.supportedFeatureSubsetStrategies:
            raise ValueError("unsupported featureSubsetStrategy: %s" % featureSubsetStrategy)
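        # No seed supplied: draw a random 30-bit seed for bootstrapping and
        # feature subset selection.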
        if seed is None:
            seed = random.randint(0, 1 << 30)
        model = callMLlibFunc("trainRandomForestModel", data, algo, numClasses,
                              categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity,
                              maxDepth, maxBins, seed)
        return RandomForestModel(model)

    @classmethod
    @since("1.2.0")
    def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees,
                        featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32,
                        seed=None):
        """
        Train a random forest model for binary or multiclass
        classification.

        :param data:
          Training dataset: RDD of LabeledPoint. Labels should take values
          {0, 1, ..., numClasses-1}.
        :param numClasses:
          Number of classes for classification.
        :param categoricalFeaturesInfo:
          Map storing arity of categorical features. An entry (n -> k)
          indicates that feature n is categorical with k categories
          indexed from 0: {0, 1, ..., k-1}.
        :param numTrees:
          Number of trees in the random forest.
        :param featureSubsetStrategy:
          Number of features to consider for splits at each node.
          Supported values: "auto", "all", "sqrt", "log2", "onethird".
          If "auto" is set, this parameter is set based on numTrees:
          if numTrees == 1, set to "all";
          if numTrees > 1 (forest), set to "sqrt".
          (default: "auto")
        :param impurity:
          Criterion used for information gain calculation.
          Supported values: "gini" or "entropy".
          (default: "gini")
        :param maxDepth:
          Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1
          means 1 internal node + 2 leaf nodes).
          (default: 4)
        :param maxBins:
          Maximum number of bins used for splitting features.
          (default: 32)
        :param seed:
          Random seed for bootstrapping and choosing feature subsets.
          Set as None to generate seed based on system time.
          (default: None)
        :return:
          RandomForestModel that can be used for prediction.

        Example usage:

        >>> from pyspark.mllib.regression import LabeledPoint
        >>> from pyspark.mllib.tree import RandomForest
        >>>
        >>> data = [
        ...     LabeledPoint(0.0, [0.0]),
        ...     LabeledPoint(0.0, [1.0]),
        ...     LabeledPoint(1.0, [2.0]),
        ...     LabeledPoint(1.0, [3.0])
        ... ]
        >>> model = RandomForest.trainClassifier(sc.parallelize(data), 2, {}, 3, seed=42)
        >>> model.numTrees()
        3
        >>> model.totalNumNodes()
        7
        >>> print(model)
        TreeEnsembleModel classifier with 3 trees
        <BLANKLINE>
        >>> print(model.toDebugString())
        TreeEnsembleModel classifier with 3 trees
        <BLANKLINE>
          Tree 0:
            Predict: 1.0
          Tree 1:
            If (feature 0 <= 1.5)
             Predict: 0.0
            Else (feature 0 > 1.5)
             Predict: 1.0
          Tree 2:
            If (feature 0 <= 1.5)
             Predict: 0.0
            Else (feature 0 > 1.5)
             Predict: 1.0
        <BLANKLINE>
        >>> model.predict([2.0])
        1.0
        >>> model.predict([0.0])
        0.0
        >>> rdd = sc.parallelize([[3.0], [1.0]])
        >>> model.predict(rdd).collect()
        [1.0, 0.0]
        """
        return cls._train(data, "classification", numClasses,
                          categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity,
                          maxDepth, maxBins, seed)

    @classmethod
    @since("1.2.0")
    def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto",
                       impurity="variance", maxDepth=4, maxBins=32, seed=None):
        """
        Train a random forest model for regression.

        :param data:
          Training dataset: RDD of LabeledPoint. Labels are real numbers.
        :param categoricalFeaturesInfo:
          Map storing arity of categorical features. An entry (n -> k)
          indicates that feature n is categorical with k categories
          indexed from 0: {0, 1, ..., k-1}.
        :param numTrees:
          Number of trees in the random forest.
        :param featureSubsetStrategy:
          Number of features to consider for splits at each node.
          Supported values: "auto", "all", "sqrt", "log2", "onethird".
          If "auto" is set, this parameter is set based on numTrees:
          if numTrees == 1, set to "all";
          if numTrees > 1 (forest), set to "onethird" for regression.
          (default: "auto")
        :param impurity:
          Criterion used for information gain calculation.
          The only supported value for regression is "variance".
          (default: "variance")
        :param maxDepth:
          Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1
          means 1 internal node + 2 leaf nodes).
          (default: 4)
        :param maxBins:
          Maximum number of bins used for splitting features.
          (default: 32)
        :param seed:
          Random seed for bootstrapping and choosing feature subsets.
          Set as None to generate seed based on system time.
          (default: None)
        :return:
          RandomForestModel that can be used for prediction.

        Example usage:

        >>> from pyspark.mllib.regression import LabeledPoint
        >>> from pyspark.mllib.tree import RandomForest
        >>> from pyspark.mllib.linalg import SparseVector
        >>>
        >>> sparse_data = [
        ...     LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
        ...     LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
        ...     LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
        ...     LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
        ... ]
        >>>
        >>> model = RandomForest.trainRegressor(sc.parallelize(sparse_data), {}, 2, seed=42)
        >>> model.numTrees()
        2
        >>> model.totalNumNodes()
        4
        >>> model.predict(SparseVector(2, {1: 1.0}))
        1.0
        >>> model.predict(SparseVector(2, {0: 1.0}))
        0.5
        >>> rdd = sc.parallelize([[0.0, 1.0], [1.0, 0.0]])
        >>> model.predict(rdd).collect()
        [1.0, 0.5]
        """
        return cls._train(data, "regression", 0, categoricalFeaturesInfo, numTrees,
                          featureSubsetStrategy, impurity, maxDepth, maxBins, seed)


@inherit_doc
class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader):
    """
    Represents a gradient-boosted tree model.

    .. versionadded:: 1.3.0
    """

    @classmethod
    def _java_loader_class(cls):
        return "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"


class GradientBoostedTrees(object):
    """
    Learning algorithm for a gradient boosted trees model for
    classification or regression.

    .. versionadded:: 1.3.0
    """

    @classmethod
    def _train(cls, data, algo, categoricalFeaturesInfo,
               loss, numIterations, learningRate, maxDepth, maxBins):
        first = data.first()
        assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
        model = callMLlibFunc("trainGradientBoostedTreesModel", data, algo, categoricalFeaturesInfo,
                              loss, numIterations, learningRate, maxDepth, maxBins)
        return GradientBoostedTreesModel(model)

    @classmethod
    @since("1.3.0")
    def trainClassifier(cls, data, categoricalFeaturesInfo,
                        loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3,
                        maxBins=32):
        """
        Train a gradient-boosted trees model for classification.

        :param data:
          Training dataset: RDD of LabeledPoint. Labels should take values
          {0, 1}.
        :param categoricalFeaturesInfo:
          Map storing arity of categorical features. An entry (n -> k)
          indicates that feature n is categorical with k categories
          indexed from 0: {0, 1, ..., k-1}.
        :param loss:
          Loss function used for minimization during gradient boosting.
          Supported values: "logLoss", "leastSquaresError",
          "leastAbsoluteError".
          (default: "logLoss")
        :param numIterations:
          Number of iterations of boosting.
          (default: 100)
        :param learningRate:
          Learning rate for shrinking the contribution of each estimator.
          The learning rate should be in the interval (0, 1].
          (default: 0.1)
        :param maxDepth:
          Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1
          means 1 internal node + 2 leaf nodes).
          (default: 3)
        :param maxBins:
          Maximum number of bins used for splitting features. DecisionTree
          requires maxBins >= max categories.
          (default: 32)
        :return:
          GradientBoostedTreesModel that can be used for prediction.

        Example usage:

        >>> from pyspark.mllib.regression import LabeledPoint
        >>> from pyspark.mllib.tree import GradientBoostedTrees
        >>>
        >>> data = [
        ...     LabeledPoint(0.0, [0.0]),
        ...     LabeledPoint(0.0, [1.0]),
        ...     LabeledPoint(1.0, [2.0]),
        ...     LabeledPoint(1.0, [3.0])
        ... ]
        >>>
        >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}, numIterations=10)
        >>> model.numTrees()
        10
        >>> model.totalNumNodes()
        30
        >>> print(model)  # it already has newline
        TreeEnsembleModel classifier with 10 trees
        <BLANKLINE>
        >>> model.predict([2.0])
        1.0
        >>> model.predict([0.0])
        0.0
        >>> rdd = sc.parallelize([[2.0], [0.0]])
        >>> model.predict(rdd).collect()
        [1.0, 0.0]
        """
        return cls._train(data, "classification", categoricalFeaturesInfo,
                          loss, numIterations, learningRate, maxDepth, maxBins)

    @classmethod
    @since("1.3.0")
    def trainRegressor(cls, data, categoricalFeaturesInfo,
                       loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3,
                       maxBins=32):
        """
        Train a gradient-boosted trees model for regression.

        :param data:
          Training dataset: RDD of LabeledPoint. Labels are real numbers.
        :param categoricalFeaturesInfo:
          Map storing arity of categorical features. An entry (n -> k)
          indicates that feature n is categorical with k categories
          indexed from 0: {0, 1, ..., k-1}.
        :param loss:
          Loss function used for minimization during gradient boosting.
          Supported values: "logLoss", "leastSquaresError",
          "leastAbsoluteError".
          (default: "leastSquaresError")
        :param numIterations:
          Number of iterations of boosting.
          (default: 100)
        :param learningRate:
          Learning rate for shrinking the contribution of each estimator.
          The learning rate should be in the interval (0, 1].
          (default: 0.1)
        :param maxDepth:
          Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1
          means 1 internal node + 2 leaf nodes).
          (default: 3)
        :param maxBins:
          Maximum number of bins used for splitting features. DecisionTree
          requires maxBins >= max categories.
          (default: 32)
        :return:
          GradientBoostedTreesModel that can be used for prediction.

        Example usage:

        >>> from pyspark.mllib.regression import LabeledPoint
        >>> from pyspark.mllib.tree import GradientBoostedTrees
        >>> from pyspark.mllib.linalg import SparseVector
        >>>
        >>> sparse_data = [
        ...     LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
        ...     LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
        ...     LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
        ...     LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
        ... ]
        >>>
        >>> data = sc.parallelize(sparse_data)
        >>> model = GradientBoostedTrees.trainRegressor(data, {}, numIterations=10)
        >>> model.numTrees()
        10
        >>> model.totalNumNodes()
        12
        >>> model.predict(SparseVector(2, {1: 1.0}))
        1.0
        >>> model.predict(SparseVector(2, {0: 1.0}))
        0.0
        >>> rdd = sc.parallelize([[0.0, 1.0], [1.0, 0.0]])
        >>> model.predict(rdd).collect()
        [1.0, 0.0]
        """
        return cls._train(data, "regression", categoricalFeaturesInfo,
                          loss, numIterations, learningRate, maxDepth, maxBins)


def _test():
    import doctest
    globs = globals().copy()
    from pyspark.sql import SparkSession
    spark = SparkSession.builder \
        .master("local[4]") \
        .appName("mllib.tree tests") \
        .getOrCreate()
    globs['sc'] = spark.sparkContext
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    spark.stop()
    if failure_count:
        sys.exit(-1)
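
# Running this module directly executes the doctests above against a local
# SparkSession created in _test().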

if __name__ == "__main__":
    _test()