#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark import since, keyword_only
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
    JavaPredictor, JavaPredictionModel
from pyspark.ml.common import inherit_doc, _java2py, _py2java


@inherit_doc
class _DecisionTreeModel(JavaPredictionModel):
    """
    Abstraction for Decision Tree models.

    .. versionadded:: 1.5.0
    """

    @property
    @since("1.5.0")
    def numNodes(self):
        """Return number of nodes of the decision tree."""
        return self._call_java("numNodes")

    @property
    @since("1.5.0")
    def depth(self):
        """Return depth of the decision tree."""
        return self._call_java("depth")

    @property
    @since("2.0.0")
    def toDebugString(self):
        """Full description of model."""
        return self._call_java("toDebugString")

    @since("3.0.0")
    def predictLeaf(self, value):
        """
        Predict the indices of the leaves corresponding to the feature vector.
        """
        return self._call_java("predictLeaf", value)


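# Usage sketch for the abstraction above: any fitted decision-tree model (for
# example one produced by DecisionTreeClassifier) exposes these read-only
# properties. The DataFrame is assumed to have the default "features" and
# "label" columns; names and values below are illustrative only.
def _example_inspect_decision_tree(train_df):
    """Illustrative sketch: fit a small tree and read its structure."""
    from pyspark.ml.classification import DecisionTreeClassifier

    dt = DecisionTreeClassifier(maxDepth=3, leafCol="leafId")
    model = dt.fit(train_df)
    # Structural summaries exposed by _DecisionTreeModel.
    print(model.depth, model.numNodes)
    print(model.toDebugString)
    # transform() fills the prediction column and, because leafCol is set,
    # also a predicted leaf index for every row.
    return model.transform(train_df)

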
class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol):
    """
    Mixin for Decision Tree parameters.
    """

    leafCol = Param(Params._dummy(), "leafCol", "Leaf indices column name. Predicted leaf " +
                    "index of each instance in each tree by preorder.",
                    typeConverter=TypeConverters.toString)

    maxDepth = Param(Params._dummy(), "maxDepth", "Maximum depth of the tree. (>= 0) E.g., " +
                     "depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
                     typeConverter=TypeConverters.toInt)

    maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous " +
                    "features. Must be >= 2 and >= number of categories for any categorical " +
                    "feature.", typeConverter=TypeConverters.toInt)

    minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of " +
                                "instances each child must have after split. If a split causes " +
                                "the left or right child to have fewer than " +
                                "minInstancesPerNode, the split will be discarded as invalid. " +
                                "Should be >= 1.", typeConverter=TypeConverters.toInt)

    minWeightFractionPerNode = Param(Params._dummy(), "minWeightFractionPerNode", "Minimum "
                                     "fraction of the weighted sample count that each child "
                                     "must have after split. If a split causes the fraction "
                                     "of the total weight in the left or right child to be "
                                     "less than minWeightFractionPerNode, the split will be "
                                     "discarded as invalid. Should be in interval [0.0, 0.5).",
                                     typeConverter=TypeConverters.toFloat)

    minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split " +
                        "to be considered at a tree node.", typeConverter=TypeConverters.toFloat)

    maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to " +
                          "histogram aggregation. If too small, then 1 node will be split per " +
                          "iteration, and its aggregates may exceed this size.",
                          typeConverter=TypeConverters.toInt)

    cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass " +
                         "trees to executors to match instances with nodes. If true, the " +
                         "algorithm will cache node IDs for each instance. Caching can speed " +
                         "up training of deeper trees. Users can set how often the cache " +
                         "should be checkpointed or disable it by setting checkpointInterval.",
                         typeConverter=TypeConverters.toBoolean)

    def __init__(self):
        super(_DecisionTreeParams, self).__init__()

    def setLeafCol(self, value):
        """
        Sets the value of :py:attr:`leafCol`.
        """
        return self._set(leafCol=value)

    def getLeafCol(self):
        """
        Gets the value of leafCol or its default value.
        """
        return self.getOrDefault(self.leafCol)

    def getMaxDepth(self):
        """
        Gets the value of maxDepth or its default value.
        """
        return self.getOrDefault(self.maxDepth)

    def getMaxBins(self):
        """
        Gets the value of maxBins or its default value.
        """
        return self.getOrDefault(self.maxBins)

    def getMinInstancesPerNode(self):
        """
        Gets the value of minInstancesPerNode or its default value.
        """
        return self.getOrDefault(self.minInstancesPerNode)

    def getMinWeightFractionPerNode(self):
        """
        Gets the value of minWeightFractionPerNode or its default value.
        """
        return self.getOrDefault(self.minWeightFractionPerNode)

    def getMinInfoGain(self):
        """
        Gets the value of minInfoGain or its default value.
        """
        return self.getOrDefault(self.minInfoGain)

    def getMaxMemoryInMB(self):
        """
        Gets the value of maxMemoryInMB or its default value.
        """
        return self.getOrDefault(self.maxMemoryInMB)

    def getCacheNodeIds(self):
        """
        Gets the value of cacheNodeIds or its default value.
        """
        return self.getOrDefault(self.cacheNodeIds)


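# Usage sketch for the mixin above: concrete estimators such as
# DecisionTreeRegressor accept these parameters as constructor keywords, and
# the getters inherited from _DecisionTreeParams return the effective values.
# The DataFrame is assumed to have "features" and "label" columns; the chosen
# parameter values are illustrative assumptions.
def _example_decision_tree_params(train_df):
    """Illustrative sketch: configure and inspect core tree parameters."""
    from pyspark.ml.regression import DecisionTreeRegressor

    dt = DecisionTreeRegressor(maxDepth=5, maxBins=64, minInstancesPerNode=2,
                               minInfoGain=0.0, cacheNodeIds=True,
                               checkpointInterval=10)
    # Getters return either the explicitly set value or the default.
    print(dt.getMaxDepth(), dt.getMaxBins(), dt.getCacheNodeIds())
    return dt.fit(train_df)

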
@inherit_doc
class _TreeEnsembleModel(JavaPredictionModel):
    """
    (private abstraction)
    Represents a tree ensemble model.
    """

    @property
    @since("2.0.0")
    def trees(self):
        """Trees in this ensemble. Warning: These have null parent Estimators."""
        return [_DecisionTreeModel(m) for m in list(self._call_java("trees"))]

    @property
    @since("2.0.0")
    def getNumTrees(self):
        """Number of trees in ensemble."""
        return self._call_java("getNumTrees")

    @property
    @since("1.5.0")
    def treeWeights(self):
        """Return the weights for each tree."""
        return list(self._call_java("javaTreeWeights"))

    @property
    @since("2.0.0")
    def totalNumNodes(self):
        """Total number of nodes, summed over all trees in the ensemble."""
        return self._call_java("totalNumNodes")

    @property
    @since("2.0.0")
    def toDebugString(self):
        """Full description of model."""
        return self._call_java("toDebugString")

    @since("3.0.0")
    def predictLeaf(self, value):
        """
        Predict the indices of the leaves corresponding to the feature vector.
        """
        return self._call_java("predictLeaf", value)


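# Usage sketch for the ensemble abstraction above: a fitted random forest (or
# GBT) model exposes its member trees and their weights through these
# properties, and each member tree supports the single-tree API (depth,
# numNodes, ...). Column names and values are illustrative assumptions.
def _example_inspect_ensemble(train_df):
    """Illustrative sketch: inspect a fitted tree ensemble."""
    from pyspark.ml.classification import RandomForestClassifier

    model = RandomForestClassifier(numTrees=20).fit(train_df)
    print(model.getNumTrees, model.totalNumNodes)   # both are properties on the model
    print(model.treeWeights)                        # one weight per tree
    first_tree = model.trees[0]
    print(first_tree.depth, first_tree.numNodes)
    return model

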
class _TreeEnsembleParams(_DecisionTreeParams):
    """
    Mixin for parameters of Decision Tree-based ensemble algorithms.
    """

    subsamplingRate = Param(Params._dummy(), "subsamplingRate", "Fraction of the training data " +
                            "used for learning each decision tree, in range (0, 1].",
                            typeConverter=TypeConverters.toFloat)

    supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"]

    featureSubsetStrategy = \
        Param(Params._dummy(), "featureSubsetStrategy",
              "The number of features to consider for splits at each tree node. Supported " +
              "options: 'auto' (choose automatically for task: If numTrees == 1, set to " +
              "'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to " +
              "'onethird' for regression), 'all' (use all features), 'onethird' (use " +
              "1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use " +
              "log2(number of features)), 'n' (when n is in the range (0, 1.0], use " +
              "n * number of features. When n is in the range (1, number of features), " +
              "use n features). default = 'auto'", typeConverter=TypeConverters.toString)

    def __init__(self):
        super(_TreeEnsembleParams, self).__init__()

    @since("1.4.0")
    def getSubsamplingRate(self):
        """
        Gets the value of subsamplingRate or its default value.
        """
        return self.getOrDefault(self.subsamplingRate)

    @since("1.4.0")
    def getFeatureSubsetStrategy(self):
        """
        Gets the value of featureSubsetStrategy or its default value.
        """
        return self.getOrDefault(self.featureSubsetStrategy)


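# Usage sketch for the ensemble parameters above: subsamplingRate and
# featureSubsetStrategy control the per-tree randomization. The DataFrame and
# the chosen values are illustrative assumptions.
def _example_ensemble_params(train_df):
    """Illustrative sketch: tune per-tree subsampling for an ensemble."""
    from pyspark.ml.classification import RandomForestClassifier

    rf = RandomForestClassifier(subsamplingRate=0.8, featureSubsetStrategy="sqrt")
    print(rf.getSubsamplingRate(), rf.getFeatureSubsetStrategy())
    return rf.fit(train_df)

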
class _RandomForestParams(_TreeEnsembleParams):
    """
    Private class to track supported random forest parameters.
    """

    numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).",
                     typeConverter=TypeConverters.toInt)

    bootstrap = Param(Params._dummy(), "bootstrap", "Whether bootstrap samples are used "
                      "when building trees.", typeConverter=TypeConverters.toBoolean)

    def __init__(self):
        super(_RandomForestParams, self).__init__()

    @since("1.4.0")
    def getNumTrees(self):
        """
        Gets the value of numTrees or its default value.
        """
        return self.getOrDefault(self.numTrees)

    @since("3.0.0")
    def getBootstrap(self):
        """
        Gets the value of bootstrap or its default value.
        """
        return self.getOrDefault(self.bootstrap)


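# Usage sketch for the random forest parameters above: numTrees sets the
# ensemble size and bootstrap toggles sampling with replacement (available
# from Spark 3.0). The DataFrame and values are illustrative assumptions.
def _example_random_forest_params(train_df):
    """Illustrative sketch: configure ensemble size and bootstrapping."""
    from pyspark.ml.regression import RandomForestRegressor

    rf = RandomForestRegressor(numTrees=50, bootstrap=True, seed=42)
    print(rf.getNumTrees(), rf.getBootstrap())
    return rf.fit(train_df)

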
class _GBTParams(_TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
    """
    Private class to track supported GBT params.
    """

    stepSize = Param(Params._dummy(), "stepSize",
                     "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " +
                     "the contribution of each estimator.",
                     typeConverter=TypeConverters.toFloat)

    validationTol = Param(Params._dummy(), "validationTol",
                          "Threshold for stopping early when fit with validation is used. " +
                          "If the error rate on the validation input changes by less than the " +
                          "validationTol, then learning will stop early (before `maxIter`). " +
                          "This parameter is ignored when fit without validation is used.",
                          typeConverter=TypeConverters.toFloat)

    @since("3.0.0")
    def getValidationTol(self):
        """
        Gets the value of validationTol or its default value.
        """
        return self.getOrDefault(self.validationTol)


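# Usage sketch for the GBT parameters above: stepSize shrinks each tree's
# contribution, and validationTol together with validationIndicatorCol enables
# early stopping on a held-out split. The column name "isVal" and the values
# are illustrative assumptions.
def _example_gbt_params(train_df):
    """Illustrative sketch: gradient-boosted trees with validation-based early stopping."""
    from pyspark.ml.regression import GBTRegressor

    gbt = GBTRegressor(maxIter=100, stepSize=0.1,
                       validationIndicatorCol="isVal", validationTol=0.01)
    print(gbt.getStepSize(), gbt.getValidationTol())
    # Rows with isVal == True are used only to decide when to stop early.
    return gbt.fit(train_df)

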
class _HasVarianceImpurity(Params):
    """
    Private class to track supported impurity measures.
    """

    supportedImpurities = ["variance"]

    impurity = Param(Params._dummy(), "impurity",
                     "Criterion used for information gain calculation (case-insensitive). " +
                     "Supported options: " +
                     ", ".join(supportedImpurities), typeConverter=TypeConverters.toString)

    def __init__(self):
        super(_HasVarianceImpurity, self).__init__()

    @since("1.4.0")
    def getImpurity(self):
        """
        Gets the value of impurity or its default value.
        """
        return self.getOrDefault(self.impurity)


class _TreeClassifierParams(Params):
    """
    Private class to track supported impurity measures.

    .. versionadded:: 1.4.0
    """

    supportedImpurities = ["entropy", "gini"]

    impurity = Param(Params._dummy(), "impurity",
                     "Criterion used for information gain calculation (case-insensitive). " +
                     "Supported options: " +
                     ", ".join(supportedImpurities), typeConverter=TypeConverters.toString)

    def __init__(self):
        super(_TreeClassifierParams, self).__init__()

    @since("1.6.0")
    def getImpurity(self):
        """
        Gets the value of impurity or its default value.
        """
        return self.getOrDefault(self.impurity)


class _TreeRegressorParams(_HasVarianceImpurity):
    """
    Private class to track supported impurity measures.
    """
    pass
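

# Usage sketch for the impurity mixins above: classifiers accept "gini" or
# "entropy", while regressors support only "variance". Values are illustrative
# assumptions.
def _example_impurity_params():
    """Illustrative sketch: choose the split criterion per task type."""
    from pyspark.ml.classification import DecisionTreeClassifier
    from pyspark.ml.regression import DecisionTreeRegressor

    clf = DecisionTreeClassifier(impurity="entropy")   # via _TreeClassifierParams
    reg = DecisionTreeRegressor(impurity="variance")   # via _HasVarianceImpurity
    print(clf.getImpurity(), reg.getImpurity())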