#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark import since, keyword_only
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
    JavaPredictor, JavaPredictionModel
from pyspark.ml.common import inherit_doc, _java2py, _py2java
0025 
@inherit_doc
class _DecisionTreeModel(JavaPredictionModel):
    """
    Common abstraction shared by every decision tree model.

    .. versionadded:: 1.5.0
    """

    @property
    @since("1.5.0")
    def numNodes(self):
        """Number of nodes in the decision tree."""
        node_count = self._call_java("numNodes")
        return node_count

    @property
    @since("1.5.0")
    def depth(self):
        """Depth of the decision tree."""
        tree_depth = self._call_java("depth")
        return tree_depth

    @property
    @since("2.0.0")
    def toDebugString(self):
        """Full, human-readable description of the model."""
        return self._call_java("toDebugString")

    @since("3.0.0")
    def predictLeaf(self, value):
        """
        Map a feature vector to the index of the leaf it lands in.
        """
        # Delegates to the wrapped JVM model for the actual traversal.
        return self._call_java("predictLeaf", value)
0058 
0059 
class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol):
    """
    Mixin holding the parameters shared by all decision tree algorithms.
    """

    leafCol = Param(Params._dummy(), "leafCol",
                    "Leaf indices column name. Predicted leaf index of each "
                    "instance in each tree by preorder.",
                    typeConverter=TypeConverters.toString)

    maxDepth = Param(Params._dummy(), "maxDepth",
                     "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 "
                     "leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
                     typeConverter=TypeConverters.toInt)

    maxBins = Param(Params._dummy(), "maxBins",
                    "Max number of bins for discretizing continuous features.  "
                    "Must be >=2 and >= number of categories for any "
                    "categorical feature.",
                    typeConverter=TypeConverters.toInt)

    minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode",
                                "Minimum number of instances each child must "
                                "have after split. If a split causes the left "
                                "or right child to have fewer than "
                                "minInstancesPerNode, the split will be "
                                "discarded as invalid. Should be >= 1.",
                                typeConverter=TypeConverters.toInt)

    minWeightFractionPerNode = Param(Params._dummy(), "minWeightFractionPerNode",
                                     "Minimum fraction of the weighted sample "
                                     "count that each child must have after "
                                     "split. If a split causes the fraction of "
                                     "the total weight in the left or right "
                                     "child to be less than "
                                     "minWeightFractionPerNode, the split will "
                                     "be discarded as invalid. Should be in "
                                     "interval [0.0, 0.5).",
                                     typeConverter=TypeConverters.toFloat)

    minInfoGain = Param(Params._dummy(), "minInfoGain",
                        "Minimum information gain for a split to be considered "
                        "at a tree node.",
                        typeConverter=TypeConverters.toFloat)

    maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB",
                          "Maximum memory in MB allocated to histogram "
                          "aggregation. If too small, then 1 node will be "
                          "split per iteration, and its aggregates may exceed "
                          "this size.",
                          typeConverter=TypeConverters.toInt)

    cacheNodeIds = Param(Params._dummy(), "cacheNodeIds",
                         "If false, the algorithm will pass trees to executors "
                         "to match instances with nodes. If true, the "
                         "algorithm will cache node IDs for each instance. "
                         "Caching can speed up training of deeper trees. Users "
                         "can set how often should the cache be checkpointed "
                         "or disable it by setting checkpointInterval.",
                         typeConverter=TypeConverters.toBoolean)

    def __init__(self):
        super(_DecisionTreeParams, self).__init__()

    def setLeafCol(self, value):
        """
        Set the value of :py:attr:`leafCol`.
        """
        return self._set(leafCol=value)

    def getLeafCol(self):
        """
        Return the current value of leafCol, or its default if unset.
        """
        return self.getOrDefault(self.leafCol)

    def getMaxDepth(self):
        """
        Return the current value of maxDepth, or its default if unset.
        """
        return self.getOrDefault(self.maxDepth)

    def getMaxBins(self):
        """
        Return the current value of maxBins, or its default if unset.
        """
        return self.getOrDefault(self.maxBins)

    def getMinInstancesPerNode(self):
        """
        Return the current value of minInstancesPerNode, or its default if unset.
        """
        return self.getOrDefault(self.minInstancesPerNode)

    def getMinWeightFractionPerNode(self):
        """
        Return the current value of minWeightFractionPerNode, or its default if unset.
        """
        return self.getOrDefault(self.minWeightFractionPerNode)

    def getMinInfoGain(self):
        """
        Return the current value of minInfoGain, or its default if unset.
        """
        return self.getOrDefault(self.minInfoGain)

    def getMaxMemoryInMB(self):
        """
        Return the current value of maxMemoryInMB, or its default if unset.
        """
        return self.getOrDefault(self.maxMemoryInMB)

    def getCacheNodeIds(self):
        """
        Return the current value of cacheNodeIds, or its default if unset.
        """
        return self.getOrDefault(self.cacheNodeIds)
0162 
0163 
@inherit_doc
class _TreeEnsembleModel(JavaPredictionModel):
    """
    (private abstraction)
    Base representation of a model made up of an ensemble of trees.
    """

    @property
    @since("2.0.0")
    def trees(self):
        """Trees in this ensemble. Warning: These have null parent Estimators."""
        java_trees = list(self._call_java("trees"))
        return [_DecisionTreeModel(java_tree) for java_tree in java_trees]

    @property
    @since("2.0.0")
    def getNumTrees(self):
        """Number of trees in ensemble."""
        return self._call_java("getNumTrees")

    @property
    @since("1.5.0")
    def treeWeights(self):
        """Return the weights for each tree"""
        java_weights = self._call_java("javaTreeWeights")
        return list(java_weights)

    @property
    @since("2.0.0")
    def totalNumNodes(self):
        """Total number of nodes, summed over all trees in the ensemble."""
        return self._call_java("totalNumNodes")

    @property
    @since("2.0.0")
    def toDebugString(self):
        """Full, human-readable description of the model."""
        return self._call_java("toDebugString")

    @since("3.0.0")
    def predictLeaf(self, value):
        """
        Map a feature vector to the indices of the leaves it lands in,
        one per tree in the ensemble.
        """
        # Delegates to the wrapped JVM model for the actual traversal.
        return self._call_java("predictLeaf", value)
0207 
0208 
class _TreeEnsembleParams(_DecisionTreeParams):
    """
    Mixin holding the parameters shared by tree-ensemble algorithms.
    """

    subsamplingRate = Param(Params._dummy(), "subsamplingRate",
                            "Fraction of the training data used for learning "
                            "each decision tree, in range (0, 1].",
                            typeConverter=TypeConverters.toFloat)

    supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"]

    featureSubsetStrategy = Param(
        Params._dummy(), "featureSubsetStrategy",
        "The number of features to consider for splits at each tree node. "
        "Supported options: 'auto' (choose automatically for task: If "
        "numTrees == 1, set to 'all'. If numTrees > 1 (forest), set to 'sqrt' "
        "for classification and to 'onethird' for regression), 'all' (use all "
        "features), 'onethird' (use 1/3 of the features), 'sqrt' (use "
        "sqrt(number of features)), 'log2' (use log2(number of features)), "
        "'n' (when n is in the range (0, 1.0], use n * number of features. "
        "When n is in the range (1, number of features), use n features). "
        "default = 'auto'",
        typeConverter=TypeConverters.toString)

    def __init__(self):
        super(_TreeEnsembleParams, self).__init__()

    @since("1.4.0")
    def getSubsamplingRate(self):
        """
        Return the current value of subsamplingRate, or its default if unset.
        """
        return self.getOrDefault(self.subsamplingRate)

    @since("1.4.0")
    def getFeatureSubsetStrategy(self):
        """
        Return the current value of featureSubsetStrategy, or its default if unset.
        """
        return self.getOrDefault(self.featureSubsetStrategy)
0247 
0248 
class _RandomForestParams(_TreeEnsembleParams):
    """
    Private mixin tracking the parameters supported by random forests.
    """

    numTrees = Param(Params._dummy(), "numTrees",
                     "Number of trees to train (>= 1).",
                     typeConverter=TypeConverters.toInt)

    bootstrap = Param(Params._dummy(), "bootstrap",
                      "Whether bootstrap samples are used when building trees.",
                      typeConverter=TypeConverters.toBoolean)

    def __init__(self):
        super(_RandomForestParams, self).__init__()

    @since("1.4.0")
    def getNumTrees(self):
        """
        Return the current value of numTrees, or its default if unset.
        """
        return self.getOrDefault(self.numTrees)

    @since("3.0.0")
    def getBootstrap(self):
        """
        Return the current value of bootstrap, or its default if unset.
        """
        return self.getOrDefault(self.bootstrap)
0276 
0277 
class _GBTParams(_TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
    """
    Private mixin tracking the parameters supported by gradient-boosted trees.
    """

    stepSize = Param(Params._dummy(), "stepSize",
                     "Step size (a.k.a. learning rate) in interval (0, 1] for "
                     "shrinking the contribution of each estimator.",
                     typeConverter=TypeConverters.toFloat)

    validationTol = Param(Params._dummy(), "validationTol",
                          "Threshold for stopping early when fit with "
                          "validation is used. If the error rate on the "
                          "validation input changes by less than the "
                          "validationTol, then learning will stop early "
                          "(before `maxIter`). This parameter is ignored when "
                          "fit without validation is used.",
                          typeConverter=TypeConverters.toFloat)

    @since("3.0.0")
    def getValidationTol(self):
        """
        Return the current value of validationTol, or its default if unset.
        """
        return self.getOrDefault(self.validationTol)
0301 
0302 
class _HasVarianceImpurity(Params):
    """
    Private mixin tracking the impurity measures supported by tree regression.
    """

    # Only variance-based impurity is available for regression trees.
    supportedImpurities = ["variance"]

    impurity = Param(Params._dummy(), "impurity",
                     "Criterion used for information gain calculation "
                     "(case-insensitive). Supported options: "
                     + ", ".join(supportedImpurities),
                     typeConverter=TypeConverters.toString)

    def __init__(self):
        super(_HasVarianceImpurity, self).__init__()

    @since("1.4.0")
    def getImpurity(self):
        """
        Return the current value of impurity, or its default if unset.
        """
        return self.getOrDefault(self.impurity)
0324 
0325 
class _TreeClassifierParams(Params):
    """
    Private mixin tracking the impurity measures supported by tree classification.

    .. versionadded:: 1.4.0
    """

    supportedImpurities = ["entropy", "gini"]

    impurity = Param(Params._dummy(), "impurity",
                     "Criterion used for information gain calculation "
                     "(case-insensitive). Supported options: "
                     + ", ".join(supportedImpurities),
                     typeConverter=TypeConverters.toString)

    def __init__(self):
        super(_TreeClassifierParams, self).__init__()

    @since("1.6.0")
    def getImpurity(self):
        """
        Return the current value of impurity, or its default if unset.
        """
        return self.getOrDefault(self.impurity)
0349 
0350 
class _TreeRegressorParams(_HasVarianceImpurity):
    """
    Private mixin tracking the impurity measures supported by tree regressors;
    everything is inherited from :py:class:`_HasVarianceImpurity`.
    """
    pass