#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import operator
import sys
from multiprocessing.pool import ThreadPool

from pyspark import since, keyword_only
from pyspark.ml import Estimator, Model
from pyspark.ml.param.shared import *
from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
    _TreeEnsembleModel, _RandomForestParams, _GBTParams, \
    _HasVarianceImpurity, _TreeClassifierParams, _TreeEnsembleParams
from pyspark.ml.regression import _FactorizationMachinesParams, DecisionTreeRegressionModel
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
    JavaPredictor, _JavaPredictorParams, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc, _java2py, _py2java
from pyspark.ml.linalg import Vectors
from pyspark.sql import DataFrame
from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
from pyspark.storagelevel import StorageLevel

__all__ = ['LinearSVC', 'LinearSVCModel',
           'LogisticRegression', 'LogisticRegressionModel',
           'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary',
           'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary',
           'DecisionTreeClassifier', 'DecisionTreeClassificationModel',
           'GBTClassifier', 'GBTClassificationModel',
           'RandomForestClassifier', 'RandomForestClassificationModel',
           'NaiveBayes', 'NaiveBayesModel',
           'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel',
           'OneVsRest', 'OneVsRestModel',
           'FMClassifier', 'FMClassificationModel']


class _JavaClassifierParams(HasRawPredictionCol, _JavaPredictorParams):
    """
    Java Classifier Params for classification tasks.

    .. versionadded:: 3.0.0
    """
    pass


@inherit_doc
class JavaClassifier(JavaPredictor, _JavaClassifierParams):
    """
    Java Classifier for classification tasks.
    Classes are indexed {0, 1, ..., numClasses - 1}.
    """

    @since("3.0.0")
    def setRawPredictionCol(self, value):
        """
        Sets the value of :py:attr:`rawPredictionCol`.
        """
        return self._set(rawPredictionCol=value)


@inherit_doc
class JavaClassificationModel(JavaPredictionModel, _JavaClassifierParams):
0078 """
0079 Java Model produced by a ``Classifier``.
0080 Classes are indexed {0, 1, ..., numClasses - 1}.
0081 To be mixed in with class:`pyspark.ml.JavaModel`
0082 """

    @since("3.0.0")
    def setRawPredictionCol(self, value):
        """
        Sets the value of :py:attr:`rawPredictionCol`.
        """
        return self._set(rawPredictionCol=value)

    @property
    @since("2.1.0")
    def numClasses(self):
        """
        Number of classes (values which the label can take).
        """
        return self._call_java("numClasses")

    @since("3.0.0")
    def predictRaw(self, value):
        """
        Raw prediction for each possible label.
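        The meaning of a raw prediction may vary between algorithms, but it
        intuitively gives a measure of confidence in each possible label,
        where larger means more confident.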
0103 """
0104 return self._call_java("predictRaw", value)
0105
0106
0107 class _JavaProbabilisticClassifierParams(HasProbabilityCol, HasThresholds, _JavaClassifierParams):
0108 """
0109 Params for :py:class:`JavaProbabilisticClassifier` and
0110 :py:class:`JavaProbabilisticClassificationModel`.
0111
0112 .. versionadded:: 3.0.0
0113 """
0114 pass
0115
0116
0117 @inherit_doc
0118 class JavaProbabilisticClassifier(JavaClassifier, _JavaProbabilisticClassifierParams):
0119 """
0120 Java Probabilistic Classifier for classification tasks.
0121 """
0122
0123 @since("3.0.0")
0124 def setProbabilityCol(self, value):
0125 """
0126 Sets the value of :py:attr:`probabilityCol`.
0127 """
0128 return self._set(probabilityCol=value)
0129
0130 @since("3.0.0")
0131 def setThresholds(self, value):
0132 """
0133 Sets the value of :py:attr:`thresholds`.
0134 """
0135 return self._set(thresholds=value)
0136
0137
0138 @inherit_doc
0139 class JavaProbabilisticClassificationModel(JavaClassificationModel,
0140 _JavaProbabilisticClassifierParams):
0141 """
0142 Java Model produced by a ``ProbabilisticClassifier``.
0143 """
0144
0145 @since("3.0.0")
0146 def setProbabilityCol(self, value):
0147 """
0148 Sets the value of :py:attr:`probabilityCol`.
0149 """
0150 return self._set(probabilityCol=value)
0151
0152 @since("3.0.0")
0153 def setThresholds(self, value):
0154 """
0155 Sets the value of :py:attr:`thresholds`.
0156 """
0157 return self._set(thresholds=value)
0158
0159 @since("3.0.0")
0160 def predictProbability(self, value):
0161 """
0162 Predict the probability of each class given the features.
0163 """
0164 return self._call_java("predictProbability", value)
0165
0166
0167 class _LinearSVCParams(_JavaClassifierParams, HasRegParam, HasMaxIter, HasFitIntercept, HasTol,
0168 HasStandardization, HasWeightCol, HasAggregationDepth, HasThreshold):
0169 """
0170 Params for :py:class:`LinearSVC` and :py:class:`LinearSVCModel`.
0171
0172 .. versionadded:: 3.0.0
0173 """
0174
0175 threshold = Param(Params._dummy(), "threshold",
0176 "The threshold in binary classification applied to the linear model"
0177 " prediction. This threshold can be any real number, where Inf will make"
0178 " all predictions 0.0 and -Inf will make all predictions 1.0.",
0179 typeConverter=TypeConverters.toFloat)
0180
0181
0182 @inherit_doc
0183 class LinearSVC(JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadable):
0184 """
0185 `Linear SVM Classifier <https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM>`_
0186
0187 This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.
0188 Only supports L2 regularization currently.
0189
0190 >>> from pyspark.sql import Row
0191 >>> from pyspark.ml.linalg import Vectors
0192 >>> df = sc.parallelize([
0193 ... Row(label=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
0194 ... Row(label=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
0195 >>> svm = LinearSVC()
0196 >>> svm.getMaxIter()
0197 100
0198 >>> svm.setMaxIter(5)
0199 LinearSVC...
0200 >>> svm.getMaxIter()
0201 5
0202 >>> svm.getRegParam()
0203 0.0
0204 >>> svm.setRegParam(0.01)
0205 LinearSVC...
0206 >>> svm.getRegParam()
0207 0.01
0208 >>> model = svm.fit(df)
0209 >>> model.setPredictionCol("newPrediction")
0210 LinearSVCModel...
0211 >>> model.getPredictionCol()
0212 'newPrediction'
0213 >>> model.setThreshold(0.5)
0214 LinearSVCModel...
0215 >>> model.getThreshold()
0216 0.5
0217 >>> model.coefficients
0218 DenseVector([0.0, -0.2792, -0.1833])
0219 >>> model.intercept
0220 1.0206118982229047
0221 >>> model.numClasses
0222 2
0223 >>> model.numFeatures
0224 3
0225 >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, -1.0, -1.0))]).toDF()
0226 >>> model.predict(test0.head().features)
0227 1.0
0228 >>> model.predictRaw(test0.head().features)
0229 DenseVector([-1.4831, 1.4831])
0230 >>> result = model.transform(test0).head()
0231 >>> result.newPrediction
0232 1.0
0233 >>> result.rawPrediction
0234 DenseVector([-1.4831, 1.4831])
0235 >>> svm_path = temp_path + "/svm"
0236 >>> svm.save(svm_path)
0237 >>> svm2 = LinearSVC.load(svm_path)
0238 >>> svm2.getMaxIter()
0239 5
0240 >>> model_path = temp_path + "/svm_model"
0241 >>> model.save(model_path)
0242 >>> model2 = LinearSVCModel.load(model_path)
0243 >>> model.coefficients[0] == model2.coefficients[0]
0244 True
0245 >>> model.intercept == model2.intercept
0246 True
0247
0248 .. versionadded:: 2.2.0
0249 """
0250
0251 @keyword_only
0252 def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
0253 maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction",
0254 fitIntercept=True, standardization=True, threshold=0.0, weightCol=None,
0255 aggregationDepth=2):
0256 """
0257 __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
0258 maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \
0259 fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \
0260 aggregationDepth=2):
0261 """
0262 super(LinearSVC, self).__init__()
0263 self._java_obj = self._new_java_obj(
0264 "org.apache.spark.ml.classification.LinearSVC", self.uid)
0265 self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, fitIntercept=True,
0266 standardization=True, threshold=0.0, aggregationDepth=2)
0267 kwargs = self._input_kwargs
0268 self.setParams(**kwargs)
0269
0270 @keyword_only
0271 @since("2.2.0")
0272 def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
0273 maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction",
0274 fitIntercept=True, standardization=True, threshold=0.0, weightCol=None,
0275 aggregationDepth=2):
0276 """
0277 setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
0278 maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \
0279 fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \
0280 aggregationDepth=2):
0281 Sets params for Linear SVM Classifier.
0282 """
0283 kwargs = self._input_kwargs
0284 return self._set(**kwargs)
0285
0286 def _create_model(self, java_model):
0287 return LinearSVCModel(java_model)
0288
0289 @since("2.2.0")
0290 def setMaxIter(self, value):
0291 """
0292 Sets the value of :py:attr:`maxIter`.
0293 """
0294 return self._set(maxIter=value)
0295
0296 @since("2.2.0")
0297 def setRegParam(self, value):
0298 """
0299 Sets the value of :py:attr:`regParam`.
0300 """
0301 return self._set(regParam=value)
0302
0303 @since("2.2.0")
0304 def setTol(self, value):
0305 """
0306 Sets the value of :py:attr:`tol`.
0307 """
0308 return self._set(tol=value)
0309
0310 @since("2.2.0")
0311 def setFitIntercept(self, value):
0312 """
0313 Sets the value of :py:attr:`fitIntercept`.
0314 """
0315 return self._set(fitIntercept=value)
0316
0317 @since("2.2.0")
0318 def setStandardization(self, value):
0319 """
0320 Sets the value of :py:attr:`standardization`.
0321 """
0322 return self._set(standardization=value)
0323
0324 @since("2.2.0")
0325 def setThreshold(self, value):
0326 """
0327 Sets the value of :py:attr:`threshold`.
0328 """
0329 return self._set(threshold=value)
0330
0331 @since("2.2.0")
0332 def setWeightCol(self, value):
0333 """
0334 Sets the value of :py:attr:`weightCol`.
0335 """
0336 return self._set(weightCol=value)
0337
0338 @since("2.2.0")
0339 def setAggregationDepth(self, value):
0340 """
0341 Sets the value of :py:attr:`aggregationDepth`.
0342 """
0343 return self._set(aggregationDepth=value)
0344
0345
0346 class LinearSVCModel(JavaClassificationModel, _LinearSVCParams, JavaMLWritable, JavaMLReadable):
0347 """
0348 Model fitted by LinearSVC.
0349
0350 .. versionadded:: 2.2.0
0351 """
0352
0353 @since("3.0.0")
0354 def setThreshold(self, value):
0355 """
0356 Sets the value of :py:attr:`threshold`.
0357 """
0358 return self._set(threshold=value)
0359
0360 @property
0361 @since("2.2.0")
0362 def coefficients(self):
0363 """
0364 Model coefficients of Linear SVM Classifier.
0365 """
0366 return self._call_java("coefficients")
0367
0368 @property
0369 @since("2.2.0")
0370 def intercept(self):
0371 """
0372 Model intercept of Linear SVM Classifier.
0373 """
0374 return self._call_java("intercept")
0375
0376
0377 class _LogisticRegressionParams(_JavaProbabilisticClassifierParams, HasRegParam,
0378 HasElasticNetParam, HasMaxIter, HasFitIntercept, HasTol,
0379 HasStandardization, HasWeightCol, HasAggregationDepth,
0380 HasThreshold):
0381 """
0382 Params for :py:class:`LogisticRegression` and :py:class:`LogisticRegressionModel`.
0383
0384 .. versionadded:: 3.0.0
0385 """
0386
    threshold = Param(Params._dummy(), "threshold",
                      "Threshold in binary classification prediction, in range [0, 1]." +
                      " If threshold and thresholds are both set, they must match." +
                      " e.g. if threshold is p, then thresholds must be equal to [1-p, p].",
                      typeConverter=TypeConverters.toFloat)

    family = Param(Params._dummy(), "family",
                   "The name of family which is a description of the label distribution to " +
                   "be used in the model. Supported options: auto, binomial, multinomial",
                   typeConverter=TypeConverters.toString)

    lowerBoundsOnCoefficients = Param(Params._dummy(), "lowerBoundsOnCoefficients",
                                      "The lower bounds on coefficients if fitting under bound "
                                      "constrained optimization. The bound matrix must be "
                                      "compatible with the shape "
                                      "(1, number of features) for binomial regression, or "
                                      "(number of classes, number of features) "
                                      "for multinomial regression.",
                                      typeConverter=TypeConverters.toMatrix)

    upperBoundsOnCoefficients = Param(Params._dummy(), "upperBoundsOnCoefficients",
                                      "The upper bounds on coefficients if fitting under bound "
                                      "constrained optimization. The bound matrix must be "
                                      "compatible with the shape "
                                      "(1, number of features) for binomial regression, or "
                                      "(number of classes, number of features) "
                                      "for multinomial regression.",
                                      typeConverter=TypeConverters.toMatrix)

    lowerBoundsOnIntercepts = Param(Params._dummy(), "lowerBoundsOnIntercepts",
                                    "The lower bounds on intercepts if fitting under bound "
                                    "constrained optimization. The bounds vector size must be "
                                    "equal to 1 for binomial regression, or the number of "
                                    "classes for multinomial regression.",
                                    typeConverter=TypeConverters.toVector)

    upperBoundsOnIntercepts = Param(Params._dummy(), "upperBoundsOnIntercepts",
                                    "The upper bounds on intercepts if fitting under bound "
                                    "constrained optimization. The bound vector size must be "
                                    "equal to 1 for binomial regression, or the number of "
                                    "classes for multinomial regression.",
                                    typeConverter=TypeConverters.toVector)

    @since("1.4.0")
    def setThreshold(self, value):
        """
        Sets the value of :py:attr:`threshold`.
        Clears value of :py:attr:`thresholds` if it has been set.
        """
        self._set(threshold=value)
        self.clear(self.thresholds)
        return self

    @since("1.4.0")
    def getThreshold(self):
        """
        Get threshold for binary classification.

        If :py:attr:`thresholds` is set with length 2 (i.e., binary classification),
        this returns the equivalent threshold:
        :math:`\\frac{1}{1 + \\frac{thresholds(0)}{thresholds(1)}}`.
        Otherwise, returns :py:attr:`threshold` if set or its default value if unset.
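
        For example, if :py:attr:`thresholds` is ``[0.4, 0.6]``, the equivalent
        :py:attr:`threshold` is :math:`\\frac{1}{1 + 0.4 / 0.6} = 0.6`.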
0449 """
0450 self._checkThresholdConsistency()
0451 if self.isSet(self.thresholds):
0452 ts = self.getOrDefault(self.thresholds)
            if len(ts) != 2:
                raise ValueError("Logistic Regression getThreshold only applies to" +
                                 " binary classification, but thresholds has length != 2." +
                                 " thresholds: " + ", ".join(map(str, ts)))
            return 1.0 / (1.0 + ts[0] / ts[1])
        else:
            return self.getOrDefault(self.threshold)

    @since("1.5.0")
    def setThresholds(self, value):
        """
        Sets the value of :py:attr:`thresholds`.
        Clears value of :py:attr:`threshold` if it has been set.
        """
        self._set(thresholds=value)
        self.clear(self.threshold)
        return self

    @since("1.5.0")
    def getThresholds(self):
        """
        If :py:attr:`thresholds` is set, return its value.
        Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary
        classification: (1-threshold, threshold).
        If neither are set, throw an error.
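
        For example, ``threshold = 0.25`` corresponds to
        ``thresholds = [0.75, 0.25]``.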
0478 """
0479 self._checkThresholdConsistency()
0480 if not self.isSet(self.thresholds) and self.isSet(self.threshold):
0481 t = self.getOrDefault(self.threshold)
0482 return [1.0-t, t]
0483 else:
0484 return self.getOrDefault(self.thresholds)
0485
0486 def _checkThresholdConsistency(self):
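        """
        Raises a ``ValueError`` if both :py:attr:`threshold` and
        :py:attr:`thresholds` are set but their values are inconsistent
        with each other.
        """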
        if self.isSet(self.threshold) and self.isSet(self.thresholds):
            ts = self.getOrDefault(self.thresholds)
            if len(ts) != 2:
                raise ValueError("Logistic Regression getThreshold only applies to" +
                                 " binary classification, but thresholds has length != 2." +
                                 " thresholds: {0}".format(str(ts)))
            t = 1.0 / (1.0 + ts[0] / ts[1])
            t2 = self.getOrDefault(self.threshold)
            if abs(t2 - t) >= 1E-5:
                raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
                                 " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))

    @since("2.1.0")
    def getFamily(self):
        """
        Gets the value of :py:attr:`family` or its default value.
        """
        return self.getOrDefault(self.family)

    @since("2.3.0")
    def getLowerBoundsOnCoefficients(self):
        """
        Gets the value of :py:attr:`lowerBoundsOnCoefficients`
        """
        return self.getOrDefault(self.lowerBoundsOnCoefficients)

    @since("2.3.0")
    def getUpperBoundsOnCoefficients(self):
        """
        Gets the value of :py:attr:`upperBoundsOnCoefficients`
        """
        return self.getOrDefault(self.upperBoundsOnCoefficients)

    @since("2.3.0")
    def getLowerBoundsOnIntercepts(self):
        """
        Gets the value of :py:attr:`lowerBoundsOnIntercepts`
        """
        return self.getOrDefault(self.lowerBoundsOnIntercepts)

    @since("2.3.0")
    def getUpperBoundsOnIntercepts(self):
        """
        Gets the value of :py:attr:`upperBoundsOnIntercepts`
        """
        return self.getOrDefault(self.upperBoundsOnIntercepts)


@inherit_doc
class LogisticRegression(JavaProbabilisticClassifier, _LogisticRegressionParams, JavaMLWritable,
                         JavaMLReadable):
    """
    Logistic regression.
    This class supports multinomial logistic (softmax) and binomial logistic regression.

    >>> from pyspark.sql import Row
    >>> from pyspark.ml.linalg import Vectors
    >>> bdf = sc.parallelize([
    ...     Row(label=1.0, weight=1.0, features=Vectors.dense(0.0, 5.0)),
    ...     Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)),
    ...     Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)),
    ...     Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF()
    >>> blor = LogisticRegression(weightCol="weight")
    >>> blor.getRegParam()
    0.0
    >>> blor.setRegParam(0.01)
    LogisticRegression...
    >>> blor.getRegParam()
    0.01
    >>> blor.setMaxIter(10)
    LogisticRegression...
    >>> blor.getMaxIter()
    10
    >>> blor.clear(blor.maxIter)
    >>> blorModel = blor.fit(bdf)
    >>> blorModel.setFeaturesCol("features")
    LogisticRegressionModel...
    >>> blorModel.setProbabilityCol("newProbability")
    LogisticRegressionModel...
    >>> blorModel.getProbabilityCol()
    'newProbability'
    >>> blorModel.setThreshold(0.1)
    LogisticRegressionModel...
    >>> blorModel.getThreshold()
    0.1
    >>> blorModel.coefficients
    DenseVector([-1.080..., -0.646...])
    >>> blorModel.intercept
    3.112...
    >>> data_path = "data/mllib/sample_multiclass_classification_data.txt"
    >>> mdf = spark.read.format("libsvm").load(data_path)
    >>> mlor = LogisticRegression(regParam=0.1, elasticNetParam=1.0, family="multinomial")
    >>> mlorModel = mlor.fit(mdf)
    >>> mlorModel.coefficientMatrix
    SparseMatrix(3, 4, [0, 1, 2, 3], [3, 2, 1], [1.87..., -2.75..., -0.50...], 1)
    >>> mlorModel.interceptVector
    DenseVector([0.04..., -0.42..., 0.37...])
    >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF()
    >>> blorModel.predict(test0.head().features)
    1.0
    >>> blorModel.predictRaw(test0.head().features)
    DenseVector([-3.54..., 3.54...])
    >>> blorModel.predictProbability(test0.head().features)
    DenseVector([0.02..., 0.97...])
    >>> result = blorModel.transform(test0).head()
    >>> result.prediction
    1.0
    >>> result.newProbability
    DenseVector([0.02..., 0.97...])
    >>> result.rawPrediction
    DenseVector([-3.54..., 3.54...])
    >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
    >>> blorModel.transform(test1).head().prediction
    1.0
    >>> blor.setParams("vector")
    Traceback (most recent call last):
        ...
    TypeError: Method setParams forces keyword arguments.
    >>> lr_path = temp_path + "/lr"
    >>> blor.save(lr_path)
    >>> lr2 = LogisticRegression.load(lr_path)
    >>> lr2.getRegParam()
    0.01
    >>> model_path = temp_path + "/lr_model"
    >>> blorModel.save(model_path)
    >>> model2 = LogisticRegressionModel.load(model_path)
    >>> blorModel.coefficients[0] == model2.coefficients[0]
    True
    >>> blorModel.intercept == model2.intercept
    True
    >>> model2
    LogisticRegressionModel: uid=..., numClasses=2, numFeatures=2

    .. versionadded:: 1.3.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                 threshold=0.5, thresholds=None, probabilityCol="probability",
                 rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
                 aggregationDepth=2, family="auto",
                 lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None,
                 lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None):

        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                 threshold=0.5, thresholds=None, probabilityCol="probability", \
                 rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
                 aggregationDepth=2, family="auto", \
                 lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \
                 lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None):
        If the threshold and thresholds Params are both set, they must be equivalent.
        """
        super(LogisticRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.LogisticRegression", self.uid)
        self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5, family="auto")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)
        self._checkThresholdConsistency()

    @keyword_only
    @since("1.3.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                  threshold=0.5, thresholds=None, probabilityCol="probability",
                  rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
                  aggregationDepth=2, family="auto",
                  lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None,
                  lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                  threshold=0.5, thresholds=None, probabilityCol="probability", \
                  rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
                  aggregationDepth=2, family="auto", \
                  lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \
                  lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None):
        Sets params for logistic regression.
        If the threshold and thresholds Params are both set, they must be equivalent.
        """
        kwargs = self._input_kwargs
        self._set(**kwargs)
        self._checkThresholdConsistency()
        return self

    def _create_model(self, java_model):
        return LogisticRegressionModel(java_model)

    @since("2.1.0")
    def setFamily(self, value):
        """
        Sets the value of :py:attr:`family`.
        """
        return self._set(family=value)

    @since("2.3.0")
    def setLowerBoundsOnCoefficients(self, value):
        """
        Sets the value of :py:attr:`lowerBoundsOnCoefficients`
        """
        return self._set(lowerBoundsOnCoefficients=value)

    @since("2.3.0")
    def setUpperBoundsOnCoefficients(self, value):
        """
        Sets the value of :py:attr:`upperBoundsOnCoefficients`
        """
        return self._set(upperBoundsOnCoefficients=value)

    @since("2.3.0")
    def setLowerBoundsOnIntercepts(self, value):
        """
        Sets the value of :py:attr:`lowerBoundsOnIntercepts`
        """
        return self._set(lowerBoundsOnIntercepts=value)

    @since("2.3.0")
    def setUpperBoundsOnIntercepts(self, value):
        """
        Sets the value of :py:attr:`upperBoundsOnIntercepts`
        """
        return self._set(upperBoundsOnIntercepts=value)

    def setMaxIter(self, value):
        """
        Sets the value of :py:attr:`maxIter`.
        """
        return self._set(maxIter=value)

    def setRegParam(self, value):
        """
        Sets the value of :py:attr:`regParam`.
        """
        return self._set(regParam=value)

    def setTol(self, value):
        """
        Sets the value of :py:attr:`tol`.
        """
        return self._set(tol=value)

    def setElasticNetParam(self, value):
        """
        Sets the value of :py:attr:`elasticNetParam`.
        """
        return self._set(elasticNetParam=value)

    def setFitIntercept(self, value):
        """
        Sets the value of :py:attr:`fitIntercept`.
        """
        return self._set(fitIntercept=value)

    def setStandardization(self, value):
        """
        Sets the value of :py:attr:`standardization`.
        """
        return self._set(standardization=value)

    def setWeightCol(self, value):
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)

    def setAggregationDepth(self, value):
        """
        Sets the value of :py:attr:`aggregationDepth`.
        """
        return self._set(aggregationDepth=value)


class LogisticRegressionModel(JavaProbabilisticClassificationModel, _LogisticRegressionParams,
                              JavaMLWritable, JavaMLReadable, HasTrainingSummary):
    """
    Model fitted by LogisticRegression.

    .. versionadded:: 1.3.0
    """

    @property
    @since("2.0.0")
    def coefficients(self):
        """
        Model coefficients of binomial logistic regression.
        An exception is thrown in the case of multinomial logistic regression.
        """
        return self._call_java("coefficients")

    @property
    @since("1.4.0")
    def intercept(self):
        """
        Model intercept of binomial logistic regression.
        An exception is thrown in the case of multinomial logistic regression.
        """
        return self._call_java("intercept")

    @property
    @since("2.1.0")
    def coefficientMatrix(self):
        """
        Model coefficients.
        """
        return self._call_java("coefficientMatrix")

    @property
    @since("2.1.0")
    def interceptVector(self):
        """
        Model intercept.
        """
        return self._call_java("interceptVector")

    @property
    @since("2.0.0")
    def summary(self):
        """
        Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model
        trained on the training set. An exception is thrown if `trainingSummary is None`.
        """
        if self.hasSummary:
            if self.numClasses <= 2:
                return BinaryLogisticRegressionTrainingSummary(super(LogisticRegressionModel,
                                                                     self).summary)
            else:
                return LogisticRegressionTrainingSummary(super(LogisticRegressionModel,
                                                               self).summary)
        else:
            raise RuntimeError("No training summary available for this %s" %
                               self.__class__.__name__)

    @since("2.0.0")
    def evaluate(self, dataset):
        """
        Evaluates the model on a test dataset.

        :param dataset:
          Test dataset to evaluate model on, where dataset is an
          instance of :py:class:`pyspark.sql.DataFrame`
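
        A minimal usage sketch (assuming ``blorModel`` and ``bdf`` from the
        class doctest above):

        >>> summary = blorModel.evaluate(bdf)  # doctest: +SKIP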
0830 """
0831 if not isinstance(dataset, DataFrame):
0832 raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
0833 java_blr_summary = self._call_java("evaluate", dataset)
0834 if self.numClasses <= 2:
0835 return BinaryLogisticRegressionSummary(java_blr_summary)
0836 else:
0837 return LogisticRegressionSummary(java_blr_summary)
0838
0839
0840 class LogisticRegressionSummary(JavaWrapper):
0841 """
0842 Abstraction for Logistic Regression Results for a given model.
0843
0844 .. versionadded:: 2.0.0
0845 """
0846
0847 @property
0848 @since("2.0.0")
0849 def predictions(self):
0850 """
        DataFrame output by the model's `transform` method.
0852 """
0853 return self._call_java("predictions")
0854
0855 @property
0856 @since("2.0.0")
0857 def probabilityCol(self):
0858 """
0859 Field in "predictions" which gives the probability
0860 of each class as a vector.
0861 """
0862 return self._call_java("probabilityCol")
0863
0864 @property
0865 @since("2.3.0")
0866 def predictionCol(self):
0867 """
0868 Field in "predictions" which gives the prediction of each class.
0869 """
0870 return self._call_java("predictionCol")
0871
0872 @property
0873 @since("2.0.0")
0874 def labelCol(self):
0875 """
0876 Field in "predictions" which gives the true label of each
0877 instance.
0878 """
0879 return self._call_java("labelCol")
0880
0881 @property
0882 @since("2.0.0")
0883 def featuresCol(self):
0884 """
0885 Field in "predictions" which gives the features of each instance
0886 as a vector.
0887 """
0888 return self._call_java("featuresCol")
0889
0890 @property
0891 @since("2.3.0")
0892 def labels(self):
0893 """
        Returns the sequence of labels in ascending order. This order matches the order used
        in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

        Note: In most cases, this will be the values {0.0, 1.0, ..., numClasses - 1}. However,
        if the training set is missing a label, then all of the arrays over labels
        (e.g., from truePositiveRateByLabel) will be of length numClasses - 1 instead of the
        expected numClasses.
        """
0902 return self._call_java("labels")
0903
0904 @property
0905 @since("2.3.0")
0906 def truePositiveRateByLabel(self):
0907 """
0908 Returns true positive rate for each label (category).
0909 """
0910 return self._call_java("truePositiveRateByLabel")
0911
0912 @property
0913 @since("2.3.0")
0914 def falsePositiveRateByLabel(self):
0915 """
0916 Returns false positive rate for each label (category).
0917 """
0918 return self._call_java("falsePositiveRateByLabel")
0919
0920 @property
0921 @since("2.3.0")
0922 def precisionByLabel(self):
0923 """
0924 Returns precision for each label (category).
0925 """
0926 return self._call_java("precisionByLabel")
0927
0928 @property
0929 @since("2.3.0")
0930 def recallByLabel(self):
0931 """
0932 Returns recall for each label (category).
0933 """
0934 return self._call_java("recallByLabel")
0935
0936 @since("2.3.0")
0937 def fMeasureByLabel(self, beta=1.0):
0938 """
0939 Returns f-measure for each label (category).
0940 """
0941 return self._call_java("fMeasureByLabel", beta)
0942
0943 @property
0944 @since("2.3.0")
0945 def accuracy(self):
0946 """
        Returns accuracy
        (equal to the number of correctly classified instances
        divided by the total number of instances).
0950 """
0951 return self._call_java("accuracy")
0952
0953 @property
0954 @since("2.3.0")
0955 def weightedTruePositiveRate(self):
0956 """
0957 Returns weighted true positive rate.
        (equal to weighted recall, since the per-label true positive rate is the recall)
0959 """
0960 return self._call_java("weightedTruePositiveRate")
0961
0962 @property
0963 @since("2.3.0")
0964 def weightedFalsePositiveRate(self):
0965 """
0966 Returns weighted false positive rate.
0967 """
0968 return self._call_java("weightedFalsePositiveRate")
0969
0970 @property
0971 @since("2.3.0")
0972 def weightedRecall(self):
0973 """
0974 Returns weighted averaged recall.
        (equivalent to accuracy, since each label's recall is weighted by its frequency)
0976 """
0977 return self._call_java("weightedRecall")
0978
0979 @property
0980 @since("2.3.0")
0981 def weightedPrecision(self):
0982 """
0983 Returns weighted averaged precision.
0984 """
0985 return self._call_java("weightedPrecision")
0986
0987 @since("2.3.0")
0988 def weightedFMeasure(self, beta=1.0):
0989 """
0990 Returns weighted averaged f-measure.
0991 """
0992 return self._call_java("weightedFMeasure", beta)
0993
0994
0995 @inherit_doc
0996 class LogisticRegressionTrainingSummary(LogisticRegressionSummary):
0997 """
0998 Abstraction for multinomial Logistic Regression Training results.
0999 Currently, the training summary ignores the training weights except
1000 for the objective trace.
1001
1002 .. versionadded:: 2.0.0
1003 """
1004
1005 @property
1006 @since("2.0.0")
1007 def objectiveHistory(self):
1008 """
1009 Objective function (scaled loss + regularization) at each
1010 iteration.
1011 """
1012 return self._call_java("objectiveHistory")
1013
1014 @property
1015 @since("2.0.0")
1016 def totalIterations(self):
1017 """
1018 Number of training iterations until termination.
1019 """
1020 return self._call_java("totalIterations")
1021
1022
1023 @inherit_doc
1024 class BinaryLogisticRegressionSummary(LogisticRegressionSummary):
1025 """
1026 Binary Logistic regression results for a given model.
1027
1028 .. versionadded:: 2.0.0
1029 """
1030
1031 @property
1032 @since("2.0.0")
1033 def roc(self):
1034 """
1035 Returns the receiver operating characteristic (ROC) curve,
1036 which is a Dataframe having two fields (FPR, TPR) with
1037 (0.0, 0.0) prepended and (1.0, 1.0) appended to it.

        .. seealso:: `Wikipedia reference
            <http://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("roc")

    @property
    @since("2.0.0")
    def areaUnderROC(self):
        """
        Computes the area under the receiver operating characteristic
        (ROC) curve.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("areaUnderROC")

    @property
    @since("2.0.0")
    def pr(self):
        """
        Returns the precision-recall curve, which is a DataFrame
        containing two fields (recall, precision) with (0.0, 1.0)
        prepended to it.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("pr")

    @property
    @since("2.0.0")
    def fMeasureByThreshold(self):
1078 """
1079 Returns a dataframe with two fields (threshold, F-Measure) curve
1080 with beta = 1.0.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("fMeasureByThreshold")

    @property
    @since("2.0.0")
    def precisionByThreshold(self):
        """
        Returns a DataFrame with two fields (threshold, precision).
        Every possible probability obtained in transforming the dataset
        is used as a threshold when calculating the precision.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("precisionByThreshold")

    @property
    @since("2.0.0")
    def recallByThreshold(self):
        """
        Returns a DataFrame with two fields (threshold, recall).
        Every possible probability obtained in transforming the dataset
        is used as a threshold when calculating the recall.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("recallByThreshold")


@inherit_doc
class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary,
                                              LogisticRegressionTrainingSummary):
    """
    Binary Logistic regression training results for a given model.

    .. versionadded:: 2.0.0
    """
    pass


@inherit_doc
class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams):
    """
    Params for :py:class:`DecisionTreeClassifier` and :py:class:`DecisionTreeClassificationModel`.
    """
    pass


@inherit_doc
class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifierParams,
                             JavaMLWritable, JavaMLReadable):
    """
    `Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
    learning algorithm for classification.
    It supports both binary and multiclass labels, as well as both continuous and categorical
    features.

    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.feature import StringIndexer
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
    >>> si_model = stringIndexer.fit(df)
    >>> td = si_model.transform(df)
    >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed", leafCol="leafId")
    >>> model = dt.fit(td)
    >>> model.getLabelCol()
    'indexed'
    >>> model.setFeaturesCol("features")
    DecisionTreeClassificationModel...
    >>> model.numNodes
    3
    >>> model.depth
    1
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> model.numFeatures
    1
    >>> model.numClasses
    2
    >>> print(model.toDebugString)
    DecisionTreeClassificationModel...depth=1, numNodes=3...
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.predict(test0.head().features)
    0.0
    >>> model.predictRaw(test0.head().features)
    DenseVector([1.0, 0.0])
    >>> model.predictProbability(test0.head().features)
    DenseVector([1.0, 0.0])
    >>> result = model.transform(test0).head()
    >>> result.prediction
    0.0
    >>> result.probability
    DenseVector([1.0, 0.0])
    >>> result.rawPrediction
    DenseVector([1.0, 0.0])
    >>> result.leafId
    0.0
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0
    >>> dtc_path = temp_path + "/dtc"
    >>> dt.save(dtc_path)
    >>> dt2 = DecisionTreeClassifier.load(dtc_path)
    >>> dt2.getMaxDepth()
    2
    >>> model_path = temp_path + "/dtc_model"
    >>> model.save(model_path)
    >>> model2 = DecisionTreeClassificationModel.load(model_path)
    >>> model.featureImportances == model2.featureImportances
    True

    >>> df3 = spark.createDataFrame([
    ...     (1.0, 0.2, Vectors.dense(1.0)),
    ...     (1.0, 0.8, Vectors.dense(1.0)),
    ...     (0.0, 1.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
    >>> si3 = StringIndexer(inputCol="label", outputCol="indexed")
    >>> si_model3 = si3.fit(df3)
    >>> td3 = si_model3.transform(df3)
    >>> dt3 = DecisionTreeClassifier(maxDepth=2, weightCol="weight", labelCol="indexed")
    >>> model3 = dt3.fit(td3)
    >>> print(model3.toDebugString)
    DecisionTreeClassificationModel...depth=1, numNodes=3...

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 probabilityCol="probability", rawPredictionCol="rawPrediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
                 seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 probabilityCol="probability", rawPredictionCol="rawPrediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
                 seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
        """
        super(DecisionTreeClassifier, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="gini", leafCol="", minWeightFractionPerNode=0.0)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  probabilityCol="probability", rawPredictionCol="rawPrediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  impurity="gini", seed=None, weightCol=None, leafCol="",
                  minWeightFractionPerNode=0.0):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
                  seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
        Sets params for the DecisionTreeClassifier.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return DecisionTreeClassificationModel(java_model)

    def setMaxDepth(self, value):
        """
        Sets the value of :py:attr:`maxDepth`.
        """
        return self._set(maxDepth=value)

    def setMaxBins(self, value):
        """
        Sets the value of :py:attr:`maxBins`.
        """
        return self._set(maxBins=value)

    def setMinInstancesPerNode(self, value):
        """
        Sets the value of :py:attr:`minInstancesPerNode`.
        """
        return self._set(minInstancesPerNode=value)

    @since("3.0.0")
    def setMinWeightFractionPerNode(self, value):
        """
        Sets the value of :py:attr:`minWeightFractionPerNode`.
        """
        return self._set(minWeightFractionPerNode=value)

    def setMinInfoGain(self, value):
        """
        Sets the value of :py:attr:`minInfoGain`.
        """
        return self._set(minInfoGain=value)

    def setMaxMemoryInMB(self, value):
        """
        Sets the value of :py:attr:`maxMemoryInMB`.
        """
        return self._set(maxMemoryInMB=value)

    def setCacheNodeIds(self, value):
        """
        Sets the value of :py:attr:`cacheNodeIds`.
        """
        return self._set(cacheNodeIds=value)

    @since("1.4.0")
    def setImpurity(self, value):
        """
        Sets the value of :py:attr:`impurity`.
        """
        return self._set(impurity=value)

    @since("1.4.0")
    def setCheckpointInterval(self, value):
        """
        Sets the value of :py:attr:`checkpointInterval`.
        """
        return self._set(checkpointInterval=value)

    def setSeed(self, value):
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)

    @since("3.0.0")
    def setWeightCol(self, value):
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)


@inherit_doc
class DecisionTreeClassificationModel(_DecisionTreeModel, JavaProbabilisticClassificationModel,
                                      _DecisionTreeClassifierParams, JavaMLWritable,
                                      JavaMLReadable):
    """
    Model fitted by DecisionTreeClassifier.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        This generalizes the idea of "Gini" importance to other losses,
        following the explanation of Gini importance from "Random Forests" documentation
        by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.

        This feature importance is calculated as follows:
          - importance(feature j) = sum (over nodes which split on feature j) of the gain,
            where gain is scaled by the number of instances passing through node
          - Normalize importances for tree to sum to 1.

        .. note:: Feature importance for single decision trees can have high variance due to
            correlated predictor variables. Consider using a :py:class:`RandomForestClassifier`
            to determine feature importance instead.
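
        For example, in the class doctest above the tree splits only on feature 0,
        so that feature receives all of the gain and ``model.featureImportances``
        is ``SparseVector(1, {0: 1.0})``.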
1359 """
1360 return self._call_java("featureImportances")
1361
1362
1363 @inherit_doc
1364 class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams):
1365 """
1366 Params for :py:class:`RandomForestClassifier` and :py:class:`RandomForestClassificationModel`.
1367 """
1368 pass
1369
1370
1371 @inherit_doc
1372 class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifierParams,
1373 JavaMLWritable, JavaMLReadable):
1374 """
1375 `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
1376 learning algorithm for classification.
1377 It supports both binary and multiclass labels, as well as both continuous and categorical
1378 features.
1379
1380 >>> import numpy
1381 >>> from numpy import allclose
1382 >>> from pyspark.ml.linalg import Vectors
1383 >>> from pyspark.ml.feature import StringIndexer
1384 >>> df = spark.createDataFrame([
1385 ... (1.0, Vectors.dense(1.0)),
1386 ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
1387 >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
1388 >>> si_model = stringIndexer.fit(df)
1389 >>> td = si_model.transform(df)
1390 >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42,
1391 ... leafCol="leafId")
1392 >>> rf.getMinWeightFractionPerNode()
1393 0.0
1394 >>> model = rf.fit(td)
1395 >>> model.getLabelCol()
1396 'indexed'
1397 >>> model.setFeaturesCol("features")
1398 RandomForestClassificationModel...
1399 >>> model.setRawPredictionCol("newRawPrediction")
1400 RandomForestClassificationModel...
1401 >>> model.getBootstrap()
1402 True
1403 >>> model.getRawPredictionCol()
1404 'newRawPrediction'
1405 >>> model.featureImportances
1406 SparseVector(1, {0: 1.0})
1407 >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
1408 True
1409 >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
1410 >>> model.predict(test0.head().features)
1411 0.0
1412 >>> model.predictRaw(test0.head().features)
1413 DenseVector([2.0, 0.0])
1414 >>> model.predictProbability(test0.head().features)
1415 DenseVector([1.0, 0.0])
1416 >>> result = model.transform(test0).head()
1417 >>> result.prediction
1418 0.0
1419 >>> numpy.argmax(result.probability)
1420 0
1421 >>> numpy.argmax(result.newRawPrediction)
1422 0
1423 >>> result.leafId
1424 DenseVector([0.0, 0.0, 0.0])
1425 >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
1426 >>> model.transform(test1).head().prediction
1427 1.0
1428 >>> model.trees
1429 [DecisionTreeClassificationModel...depth=..., DecisionTreeClassificationModel...]
1430 >>> rfc_path = temp_path + "/rfc"
1431 >>> rf.save(rfc_path)
1432 >>> rf2 = RandomForestClassifier.load(rfc_path)
1433 >>> rf2.getNumTrees()
1434 3
1435 >>> model_path = temp_path + "/rfc_model"
1436 >>> model.save(model_path)
1437 >>> model2 = RandomForestClassificationModel.load(model_path)
1438 >>> model.featureImportances == model2.featureImportances
1439 True
1440
1441 .. versionadded:: 1.4.0
1442 """
1443
1444 @keyword_only
1445 def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
1446 probabilityCol="probability", rawPredictionCol="rawPrediction",
1447 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
1448 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
1449 numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0,
1450 leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True):
1451 """
1452 __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
1453 probabilityCol="probability", rawPredictionCol="rawPrediction", \
1454 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
1455 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
1456 numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \
1457 leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)
1458 """
1459 super(RandomForestClassifier, self).__init__()
1460 self._java_obj = self._new_java_obj(
1461 "org.apache.spark.ml.classification.RandomForestClassifier", self.uid)
1462 self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
1463 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
1464 impurity="gini", numTrees=20, featureSubsetStrategy="auto",
1465 subsamplingRate=1.0, leafCol="", minWeightFractionPerNode=0.0,
1466 bootstrap=True)
1467 kwargs = self._input_kwargs
1468 self.setParams(**kwargs)
1469
1470 @keyword_only
1471 @since("1.4.0")
1472 def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
1473 probabilityCol="probability", rawPredictionCol="rawPrediction",
1474 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
1475 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
1476 impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0,
1477 leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True):
1478 """
1479 setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
1480 probabilityCol="probability", rawPredictionCol="rawPrediction", \
1481 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
1482 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
1483 impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, \
1484 leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)
        Sets params for the RandomForestClassifier.
1486 """
1487 kwargs = self._input_kwargs
1488 return self._set(**kwargs)
1489
1490 def _create_model(self, java_model):
1491 return RandomForestClassificationModel(java_model)
1492
1493 def setMaxDepth(self, value):
1494 """
1495 Sets the value of :py:attr:`maxDepth`.
1496 """
1497 return self._set(maxDepth=value)
1498
1499 def setMaxBins(self, value):
1500 """
1501 Sets the value of :py:attr:`maxBins`.
1502 """
1503 return self._set(maxBins=value)
1504
1505 def setMinInstancesPerNode(self, value):
1506 """
1507 Sets the value of :py:attr:`minInstancesPerNode`.
1508 """
1509 return self._set(minInstancesPerNode=value)
1510
1511 def setMinInfoGain(self, value):
1512 """
1513 Sets the value of :py:attr:`minInfoGain`.
1514 """
1515 return self._set(minInfoGain=value)
1516
1517 def setMaxMemoryInMB(self, value):
1518 """
1519 Sets the value of :py:attr:`maxMemoryInMB`.
1520 """
1521 return self._set(maxMemoryInMB=value)
1522
1523 def setCacheNodeIds(self, value):
1524 """
1525 Sets the value of :py:attr:`cacheNodeIds`.
1526 """
1527 return self._set(cacheNodeIds=value)
1528
1529 @since("1.4.0")
1530 def setImpurity(self, value):
1531 """
1532 Sets the value of :py:attr:`impurity`.
1533 """
1534 return self._set(impurity=value)
1535
1536 @since("1.4.0")
1537 def setNumTrees(self, value):
1538 """
1539 Sets the value of :py:attr:`numTrees`.
1540 """
1541 return self._set(numTrees=value)
1542
1543 @since("3.0.0")
1544 def setBootstrap(self, value):
1545 """
1546 Sets the value of :py:attr:`bootstrap`.
1547 """
1548 return self._set(bootstrap=value)
1549
1550 @since("1.4.0")
1551 def setSubsamplingRate(self, value):
1552 """
1553 Sets the value of :py:attr:`subsamplingRate`.
1554 """
1555 return self._set(subsamplingRate=value)
1556
1557 @since("2.4.0")
1558 def setFeatureSubsetStrategy(self, value):
1559 """
1560 Sets the value of :py:attr:`featureSubsetStrategy`.
1561 """
1562 return self._set(featureSubsetStrategy=value)
1563
1564 def setSeed(self, value):
1565 """
1566 Sets the value of :py:attr:`seed`.
1567 """
1568 return self._set(seed=value)
1569
1570 def setCheckpointInterval(self, value):
1571 """
1572 Sets the value of :py:attr:`checkpointInterval`.
1573 """
1574 return self._set(checkpointInterval=value)
1575
1576 @since("3.0.0")
1577 def setWeightCol(self, value):
1578 """
1579 Sets the value of :py:attr:`weightCol`.
1580 """
1581 return self._set(weightCol=value)
1582
1583 @since("3.0.0")
1584 def setMinWeightFractionPerNode(self, value):
1585 """
1586 Sets the value of :py:attr:`minWeightFractionPerNode`.
1587 """
1588 return self._set(minWeightFractionPerNode=value)
1589
1590
1591 class RandomForestClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
1592 _RandomForestClassifierParams, JavaMLWritable,
1593 JavaMLReadable):
1594 """
1595 Model fitted by RandomForestClassifier.
1596
1597 .. versionadded:: 1.4.0
1598 """
1599
1600 @property
1601 @since("2.0.0")
1602 def featureImportances(self):
1603 """
1604 Estimate of the importance of each feature.
1605
        Each feature's importance is the average of its importance across all trees
        in the ensemble. The importance vector is normalized to sum to 1. This method
        is suggested by Hastie et al. (Hastie, Tibshirani, Friedman. "The Elements of
        Statistical Learning, 2nd Edition." 2001.) and follows the implementation from
        scikit-learn.

        .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances`
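
        For example, in the class doctest above each of the three trees can only
        split on the single feature 0, so its averaged, normalized importance is
        again ``SparseVector(1, {0: 1.0})``.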
1612 """
1613 return self._call_java("featureImportances")
1614
1615 @property
1616 @since("2.0.0")
1617 def trees(self):
1618 """Trees in this ensemble. Warning: These have null parent Estimators."""
1619 return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
1620
1621
1622 class _GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
1623 """
1624 Params for :py:class:`GBTClassifier` and :py:class:`GBTClassifierModel`.
1625
1626 .. versionadded:: 3.0.0
1627 """
1628
1629 supportedLossTypes = ["logistic"]
1630
1631 lossType = Param(Params._dummy(), "lossType",
1632 "Loss function which GBT tries to minimize (case-insensitive). " +
1633 "Supported options: " + ", ".join(supportedLossTypes),
1634 typeConverter=TypeConverters.toString)
1635
1636 @since("1.4.0")
1637 def getLossType(self):
1638 """
1639 Gets the value of lossType or its default value.
1640 """
1641 return self.getOrDefault(self.lossType)
1642
1643
1644 @inherit_doc
1645 class GBTClassifier(JavaProbabilisticClassifier, _GBTClassifierParams,
1646 JavaMLWritable, JavaMLReadable):
1647 """
1648 `Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
1649 learning algorithm for classification.
1650 It supports binary labels, as well as both continuous and categorical features.
1651
1652 The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
1653
1654 Notes on Gradient Boosting vs. TreeBoost:
1655 - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
1656 - Both algorithms learn tree ensembles by minimizing loss functions.
1657 - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
1658 based on the loss function, whereas the original gradient boosting method does not.
1659 - We expect to implement TreeBoost in the future:
1660 `SPARK-4240 <https://issues.apache.org/jira/browse/SPARK-4240>`_
1661
1662 .. note:: Multiclass labels are not currently supported.
1663
1664 >>> from numpy import allclose
1665 >>> from pyspark.ml.linalg import Vectors
1666 >>> from pyspark.ml.feature import StringIndexer
1667 >>> df = spark.createDataFrame([
1668 ... (1.0, Vectors.dense(1.0)),
1669 ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
1670 >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
1671 >>> si_model = stringIndexer.fit(df)
1672 >>> td = si_model.transform(df)
1673 >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42,
1674 ... leafCol="leafId")
1675 >>> gbt.setMaxIter(5)
1676 GBTClassifier...
1677 >>> gbt.setMinWeightFractionPerNode(0.049)
1678 GBTClassifier...
1679 >>> gbt.getMaxIter()
1680 5
1681 >>> gbt.getFeatureSubsetStrategy()
1682 'all'
1683 >>> model = gbt.fit(td)
1684 >>> model.getLabelCol()
1685 'indexed'
1686 >>> model.setFeaturesCol("features")
1687 GBTClassificationModel...
1688 >>> model.setThresholds([0.3, 0.7])
1689 GBTClassificationModel...
1690 >>> model.getThresholds()
1691 [0.3, 0.7]
1692 >>> model.featureImportances
1693 SparseVector(1, {0: 1.0})
1694 >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
1695 True
1696 >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
1697 >>> model.predict(test0.head().features)
1698 0.0
1699 >>> model.predictRaw(test0.head().features)
1700 DenseVector([1.1697, -1.1697])
1701 >>> model.predictProbability(test0.head().features)
1702 DenseVector([0.9121, 0.0879])
1703 >>> result = model.transform(test0).head()
1704 >>> result.prediction
1705 0.0
1706 >>> result.leafId
1707 DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
1708 >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
1709 >>> model.transform(test1).head().prediction
1710 1.0
1711 >>> model.totalNumNodes
1712 15
1713 >>> print(model.toDebugString)
1714 GBTClassificationModel...numTrees=5...
    >>> gbtc_path = temp_path + "/gbtc"
1716 >>> gbt.save(gbtc_path)
1717 >>> gbt2 = GBTClassifier.load(gbtc_path)
1718 >>> gbt2.getMaxDepth()
1719 2
    >>> model_path = temp_path + "/gbtc_model"
1721 >>> model.save(model_path)
1722 >>> model2 = GBTClassificationModel.load(model_path)
1723 >>> model.featureImportances == model2.featureImportances
1724 True
1725 >>> model.treeWeights == model2.treeWeights
1726 True
1727 >>> model.trees
1728 [DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
1729 >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
1730 ... ["indexed", "features"])
1731 >>> model.evaluateEachIteration(validation)
1732 [0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
1733 >>> model.numClasses
1734 2
1735 >>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
1736 >>> gbt.getValidationIndicatorCol()
1737 'validationIndicator'
1738 >>> gbt.getValidationTol()
1739 0.01
1740
1741 .. versionadded:: 1.4.0
1742 """
1743
1744 @keyword_only
1745 def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
1746 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
1747 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
1748 maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance",
1749 featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None,
1750 leafCol="", minWeightFractionPerNode=0.0, weightCol=None):
1751 """
1752 __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
1753 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
1754 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
1755 lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
1756 impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
1757 validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \
1758 weightCol=None)
1759 """
1760 super(GBTClassifier, self).__init__()
1761 self._java_obj = self._new_java_obj(
1762 "org.apache.spark.ml.classification.GBTClassifier", self.uid)
1763 self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
1764 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
1765 lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0,
1766 impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
1767 leafCol="", minWeightFractionPerNode=0.0)
1768 kwargs = self._input_kwargs
1769 self.setParams(**kwargs)
1770
1771 @keyword_only
1772 @since("1.4.0")
1773 def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
1774 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
1775 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
1776 lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
1777 impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
1778 validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0,
1779 weightCol=None):
1780 """
1781 setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
1782 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
1783 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
1784 lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
1785 impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
1786 validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \
1787 weightCol=None)
1788 Sets params for Gradient Boosted Tree Classification.
1789 """
1790 kwargs = self._input_kwargs
1791 return self._set(**kwargs)
1792
1793 def _create_model(self, java_model):
1794 return GBTClassificationModel(java_model)
1795
1796 def setMaxDepth(self, value):
1797 """
1798 Sets the value of :py:attr:`maxDepth`.
1799 """
1800 return self._set(maxDepth=value)
1801
1802 def setMaxBins(self, value):
1803 """
1804 Sets the value of :py:attr:`maxBins`.
1805 """
1806 return self._set(maxBins=value)
1807
1808 def setMinInstancesPerNode(self, value):
1809 """
1810 Sets the value of :py:attr:`minInstancesPerNode`.
1811 """
1812 return self._set(minInstancesPerNode=value)
1813
1814 def setMinInfoGain(self, value):
1815 """
1816 Sets the value of :py:attr:`minInfoGain`.
1817 """
1818 return self._set(minInfoGain=value)
1819
1820 def setMaxMemoryInMB(self, value):
1821 """
1822 Sets the value of :py:attr:`maxMemoryInMB`.
1823 """
1824 return self._set(maxMemoryInMB=value)
1825
1826 def setCacheNodeIds(self, value):
1827 """
1828 Sets the value of :py:attr:`cacheNodeIds`.
1829 """
1830 return self._set(cacheNodeIds=value)
1831
1832 @since("1.4.0")
1833 def setImpurity(self, value):
1834 """
1835 Sets the value of :py:attr:`impurity`.
1836 """
1837 return self._set(impurity=value)
1838
1839 @since("1.4.0")
1840 def setLossType(self, value):
1841 """
1842 Sets the value of :py:attr:`lossType`.
1843 """
1844 return self._set(lossType=value)
1845
1846 @since("1.4.0")
1847 def setSubsamplingRate(self, value):
1848 """
1849 Sets the value of :py:attr:`subsamplingRate`.
1850 """
1851 return self._set(subsamplingRate=value)
1852
1853 @since("2.4.0")
1854 def setFeatureSubsetStrategy(self, value):
1855 """
1856 Sets the value of :py:attr:`featureSubsetStrategy`.
1857 """
1858 return self._set(featureSubsetStrategy=value)
1859
1860 @since("3.0.0")
1861 def setValidationIndicatorCol(self, value):
1862 """
1863 Sets the value of :py:attr:`validationIndicatorCol`.
1864 """
1865 return self._set(validationIndicatorCol=value)
1866
1867 @since("1.4.0")
1868 def setMaxIter(self, value):
1869 """
1870 Sets the value of :py:attr:`maxIter`.
1871 """
1872 return self._set(maxIter=value)
1873
1874 @since("1.4.0")
1875 def setCheckpointInterval(self, value):
1876 """
1877 Sets the value of :py:attr:`checkpointInterval`.
1878 """
1879 return self._set(checkpointInterval=value)
1880
1881 @since("1.4.0")
1882 def setSeed(self, value):
1883 """
1884 Sets the value of :py:attr:`seed`.
1885 """
1886 return self._set(seed=value)
1887
1888 @since("1.4.0")
1889 def setStepSize(self, value):
1890 """
1891 Sets the value of :py:attr:`stepSize`.
1892 """
1893 return self._set(stepSize=value)
1894
1895 @since("3.0.0")
1896 def setWeightCol(self, value):
1897 """
1898 Sets the value of :py:attr:`weightCol`.
1899 """
1900 return self._set(weightCol=value)
1901
1902 @since("3.0.0")
1903 def setMinWeightFractionPerNode(self, value):
1904 """
1905 Sets the value of :py:attr:`minWeightFractionPerNode`.
1906 """
1907 return self._set(minWeightFractionPerNode=value)
1908
1909
1910 class GBTClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
1911 _GBTClassifierParams, JavaMLWritable, JavaMLReadable):
1912 """
1913 Model fitted by GBTClassifier.
1914
1915 .. versionadded:: 1.4.0
1916 """
1917
1918 @property
1919 @since("2.0.0")
1920 def featureImportances(self):
1921 """
1922 Estimate of the importance of each feature.
1923
        Each feature's importance is the average of its importance across all trees in the ensemble.
1925 The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
1926 (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
1927 and follows the implementation from scikit-learn.
1928
1929 .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances`
1930 """
1931 return self._call_java("featureImportances")
1932
1933 @property
1934 @since("2.0.0")
1935 def trees(self):
1936 """Trees in this ensemble. Warning: These have null parent Estimators."""
1937 return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
1938
1939 @since("2.4.0")
1940 def evaluateEachIteration(self, dataset):
1941 """
1942 Method to compute error or loss for every iteration of gradient boosting.
1943
1944 :param dataset:
1945 Test dataset to evaluate model on, where dataset is an
1946 instance of :py:class:`pyspark.sql.DataFrame`
1947 """
1948 return self._call_java("evaluateEachIteration", dataset)
1949
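
# A minimal sketch (not part of the Spark ML API) of using
# evaluateEachIteration to choose the number of boosting iterations: the
# returned list holds one loss value per tree, so the best iteration is the
# 1-based index of the smallest loss. ``model`` and ``validation`` are
# assumed to be as in the GBTClassifier doctest above.
def _best_gbt_iteration(model, validation):
    """Return the iteration count (>= 1) with the lowest validation loss."""
    losses = model.evaluateEachIteration(validation)
    return losses.index(min(losses)) + 1
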
1950
1951 class _NaiveBayesParams(_JavaPredictorParams, HasWeightCol):
1952 """
1953 Params for :py:class:`NaiveBayes` and :py:class:`NaiveBayesModel`.
1954
1955 .. versionadded:: 3.0.0
1956 """
1957
1958 smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " +
1959 "default is 1.0", typeConverter=TypeConverters.toFloat)
1960 modelType = Param(Params._dummy(), "modelType", "The model type which is a string " +
1961 "(case-sensitive). Supported options: multinomial (default), bernoulli " +
1962 "and gaussian.",
1963 typeConverter=TypeConverters.toString)
1964
1965 @since("1.5.0")
1966 def getSmoothing(self):
1967 """
1968 Gets the value of smoothing or its default value.
1969 """
1970 return self.getOrDefault(self.smoothing)
1971
1972 @since("1.5.0")
1973 def getModelType(self):
1974 """
1975 Gets the value of modelType or its default value.
1976 """
1977 return self.getOrDefault(self.modelType)
1978
1979
1980 @inherit_doc
1981 class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds, HasWeightCol,
1982 JavaMLWritable, JavaMLReadable):
1983 """
1984 Naive Bayes Classifiers.
1985 It supports both Multinomial and Bernoulli NB. `Multinomial NB
1986 <http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html>`_
1987 can handle finitely supported discrete data. For example, by converting documents into
1988 TF-IDF vectors, it can be used for document classification. By making every vector a
    binary (0/1) vector, it can also be used as `Bernoulli NB
1990 <http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html>`_.
1991 The input feature values for Multinomial NB and Bernoulli NB must be nonnegative.
1992 Since 3.0.0, it supports Complement NB which is an adaptation of the Multinomial NB.
1993 Specifically, Complement NB uses statistics from the complement of each class to compute
1994 the model's coefficients. The inventors of Complement NB show empirically that the parameter
1995 estimates for CNB are more stable than those for Multinomial NB. Like Multinomial NB, the
1996 input feature values for Complement NB must be nonnegative.
    Since 3.0.0, it also supports `Gaussian NB
    <https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Gaussian_naive_Bayes>`_,
    which can handle continuous data.
2000
2001 >>> from pyspark.sql import Row
2002 >>> from pyspark.ml.linalg import Vectors
2003 >>> df = spark.createDataFrame([
2004 ... Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
2005 ... Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
2006 ... Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0]))])
2007 >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight")
2008 >>> model = nb.fit(df)
2009 >>> model.setFeaturesCol("features")
2010 NaiveBayesModel...
2011 >>> model.getSmoothing()
2012 1.0
2013 >>> model.pi
2014 DenseVector([-0.81..., -0.58...])
2015 >>> model.theta
2016 DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1)
2017 >>> model.sigma
2018 DenseMatrix(0, 0, [...], ...)
2019 >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
2020 >>> model.predict(test0.head().features)
2021 1.0
2022 >>> model.predictRaw(test0.head().features)
2023 DenseVector([-1.72..., -0.99...])
2024 >>> model.predictProbability(test0.head().features)
2025 DenseVector([0.32..., 0.67...])
2026 >>> result = model.transform(test0).head()
2027 >>> result.prediction
2028 1.0
2029 >>> result.probability
2030 DenseVector([0.32..., 0.67...])
2031 >>> result.rawPrediction
2032 DenseVector([-1.72..., -0.99...])
2033 >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
2034 >>> model.transform(test1).head().prediction
2035 1.0
2036 >>> nb_path = temp_path + "/nb"
2037 >>> nb.save(nb_path)
2038 >>> nb2 = NaiveBayes.load(nb_path)
2039 >>> nb2.getSmoothing()
2040 1.0
2041 >>> model_path = temp_path + "/nb_model"
2042 >>> model.save(model_path)
2043 >>> model2 = NaiveBayesModel.load(model_path)
2044 >>> model.pi == model2.pi
2045 True
2046 >>> model.theta == model2.theta
2047 True
2048 >>> nb = nb.setThresholds([0.01, 10.00])
2049 >>> model3 = nb.fit(df)
2050 >>> result = model3.transform(test0).head()
2051 >>> result.prediction
2052 0.0
2053 >>> nb3 = NaiveBayes().setModelType("gaussian")
2054 >>> model4 = nb3.fit(df)
2055 >>> model4.getModelType()
2056 'gaussian'
2057 >>> model4.sigma
2058 DenseMatrix(2, 2, [0.0, 0.25, 0.0, 0.0], 1)
2059 >>> nb5 = NaiveBayes(smoothing=1.0, modelType="complement", weightCol="weight")
2060 >>> model5 = nb5.fit(df)
2061 >>> model5.getModelType()
2062 'complement'
2063 >>> model5.theta
2064 DenseMatrix(2, 2, [...], 1)
2065 >>> model5.sigma
2066 DenseMatrix(0, 0, [...], ...)
2067
2068 .. versionadded:: 1.5.0
2069 """
2070
2071 @keyword_only
2072 def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2073 probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
2074 modelType="multinomial", thresholds=None, weightCol=None):
2075 """
2076 __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2077 probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
2078 modelType="multinomial", thresholds=None, weightCol=None)
2079 """
2080 super(NaiveBayes, self).__init__()
2081 self._java_obj = self._new_java_obj(
2082 "org.apache.spark.ml.classification.NaiveBayes", self.uid)
2083 self._setDefault(smoothing=1.0, modelType="multinomial")
2084 kwargs = self._input_kwargs
2085 self.setParams(**kwargs)
2086
2087 @keyword_only
2088 @since("1.5.0")
2089 def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2090 probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
2091 modelType="multinomial", thresholds=None, weightCol=None):
2092 """
2093 setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2094 probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
2095 modelType="multinomial", thresholds=None, weightCol=None)
2096 Sets params for Naive Bayes.
2097 """
2098 kwargs = self._input_kwargs
2099 return self._set(**kwargs)
2100
2101 def _create_model(self, java_model):
2102 return NaiveBayesModel(java_model)
2103
2104 @since("1.5.0")
2105 def setSmoothing(self, value):
2106 """
2107 Sets the value of :py:attr:`smoothing`.
2108 """
2109 return self._set(smoothing=value)
2110
2111 @since("1.5.0")
2112 def setModelType(self, value):
2113 """
2114 Sets the value of :py:attr:`modelType`.
2115 """
2116 return self._set(modelType=value)
2117
2118 def setWeightCol(self, value):
2119 """
2120 Sets the value of :py:attr:`weightCol`.
2121 """
2122 return self._set(weightCol=value)
2123
2124
2125 class NaiveBayesModel(JavaProbabilisticClassificationModel, _NaiveBayesParams, JavaMLWritable,
2126 JavaMLReadable):
2127 """
2128 Model fitted by NaiveBayes.
2129
2130 .. versionadded:: 1.5.0
2131 """
2132
2133 @property
2134 @since("2.0.0")
2135 def pi(self):
2136 """
2137 log of class priors.
2138 """
2139 return self._call_java("pi")
2140
2141 @property
2142 @since("2.0.0")
2143 def theta(self):
2144 """
2145 log of class conditional probabilities.
2146 """
2147 return self._call_java("theta")
2148
2149 @property
2150 @since("3.0.0")
2151 def sigma(self):
2152 """
2153 variance of each feature.
2154 """
2155 return self._call_java("sigma")
2156
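
# A minimal sketch (not part of the Spark ML API): NaiveBayesModel exposes pi
# and theta in log space, so exponentiating recovers the underlying
# probabilities. ``model`` is assumed to be a multinomial or Bernoulli model
# as in the doctest above.
def _nb_class_priors(model):
    """Return the class prior probabilities P(class) as plain floats."""
    import math
    return [math.exp(log_prior) for log_prior in model.pi]
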
2157
2158 class _MultilayerPerceptronParams(_JavaProbabilisticClassifierParams, HasSeed, HasMaxIter,
2159 HasTol, HasStepSize, HasSolver, HasBlockSize):
2160 """
2161 Params for :py:class:`MultilayerPerceptronClassifier`.
2162
2163 .. versionadded:: 3.0.0
2164 """
2165
    layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output " +
                   "layer. E.g., [780, 100, 10] means 780 inputs, one hidden layer with 100 " +
                   "neurons and an output layer of 10 neurons.",
2169 typeConverter=TypeConverters.toListInt)
2170 solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
2171 "options: l-bfgs, gd.", typeConverter=TypeConverters.toString)
2172 initialWeights = Param(Params._dummy(), "initialWeights", "The initial weights of the model.",
2173 typeConverter=TypeConverters.toVector)
2174
2175 @since("1.6.0")
2176 def getLayers(self):
2177 """
2178 Gets the value of layers or its default value.
2179 """
2180 return self.getOrDefault(self.layers)
2181
2182 @since("2.0.0")
2183 def getInitialWeights(self):
2184 """
2185 Gets the value of initialWeights or its default value.
2186 """
2187 return self.getOrDefault(self.initialWeights)
2188
2189
2190 @inherit_doc
2191 class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPerceptronParams,
2192 JavaMLWritable, JavaMLReadable):
2193 """
2194 Classifier trainer based on the Multilayer Perceptron.
    Each hidden layer uses the sigmoid activation function; the output layer uses softmax.
    The number of inputs has to equal the size of the feature vectors, and the number of
    outputs has to equal the total number of labels.
2198
2199 >>> from pyspark.ml.linalg import Vectors
2200 >>> df = spark.createDataFrame([
2201 ... (0.0, Vectors.dense([0.0, 0.0])),
2202 ... (1.0, Vectors.dense([0.0, 1.0])),
2203 ... (1.0, Vectors.dense([1.0, 0.0])),
2204 ... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"])
2205 >>> mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], seed=123)
2206 >>> mlp.setMaxIter(100)
2207 MultilayerPerceptronClassifier...
2208 >>> mlp.getMaxIter()
2209 100
2210 >>> mlp.getBlockSize()
2211 128
2212 >>> mlp.setBlockSize(1)
2213 MultilayerPerceptronClassifier...
2214 >>> mlp.getBlockSize()
2215 1
2216 >>> model = mlp.fit(df)
2217 >>> model.setFeaturesCol("features")
2218 MultilayerPerceptronClassificationModel...
2219 >>> model.getMaxIter()
2220 100
2221 >>> model.getLayers()
2222 [2, 2, 2]
2223 >>> model.weights.size
2224 12
2225 >>> testDF = spark.createDataFrame([
2226 ... (Vectors.dense([1.0, 0.0]),),
2227 ... (Vectors.dense([0.0, 0.0]),)], ["features"])
2228 >>> model.predict(testDF.head().features)
2229 1.0
2230 >>> model.predictRaw(testDF.head().features)
2231 DenseVector([-16.208, 16.344])
2232 >>> model.predictProbability(testDF.head().features)
2233 DenseVector([0.0, 1.0])
2234 >>> model.transform(testDF).select("features", "prediction").show()
2235 +---------+----------+
2236 | features|prediction|
2237 +---------+----------+
2238 |[1.0,0.0]| 1.0|
2239 |[0.0,0.0]| 0.0|
2240 +---------+----------+
2241 ...
2242 >>> mlp_path = temp_path + "/mlp"
2243 >>> mlp.save(mlp_path)
2244 >>> mlp2 = MultilayerPerceptronClassifier.load(mlp_path)
2245 >>> mlp2.getBlockSize()
2246 1
2247 >>> model_path = temp_path + "/mlp_model"
2248 >>> model.save(model_path)
2249 >>> model2 = MultilayerPerceptronClassificationModel.load(model_path)
2250 >>> model.getLayers() == model2.getLayers()
2251 True
2252 >>> model.weights == model2.weights
2253 True
2254 >>> mlp2 = mlp2.setInitialWeights(list(range(0, 12)))
2255 >>> model3 = mlp2.fit(df)
2256 >>> model3.weights != model2.weights
2257 True
2258 >>> model3.getLayers() == model.getLayers()
2259 True
2260
2261 .. versionadded:: 1.6.0
2262 """
2263
2264 @keyword_only
2265 def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2266 maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03,
2267 solver="l-bfgs", initialWeights=None, probabilityCol="probability",
2268 rawPredictionCol="rawPrediction"):
2269 """
2270 __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2271 maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
2272 solver="l-bfgs", initialWeights=None, probabilityCol="probability", \
2273 rawPredictionCol="rawPrediction")
2274 """
2275 super(MultilayerPerceptronClassifier, self).__init__()
2276 self._java_obj = self._new_java_obj(
2277 "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
2278 self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
2279 kwargs = self._input_kwargs
2280 self.setParams(**kwargs)
2281
2282 @keyword_only
2283 @since("1.6.0")
2284 def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2285 maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03,
2286 solver="l-bfgs", initialWeights=None, probabilityCol="probability",
2287 rawPredictionCol="rawPrediction"):
2288 """
2289 setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2290 maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
2291 solver="l-bfgs", initialWeights=None, probabilityCol="probability", \
                  rawPredictionCol="rawPrediction")
2293 Sets params for MultilayerPerceptronClassifier.
2294 """
2295 kwargs = self._input_kwargs
2296 return self._set(**kwargs)
2297
2298 def _create_model(self, java_model):
2299 return MultilayerPerceptronClassificationModel(java_model)
2300
2301 @since("1.6.0")
2302 def setLayers(self, value):
2303 """
2304 Sets the value of :py:attr:`layers`.
2305 """
2306 return self._set(layers=value)
2307
2308 @since("1.6.0")
2309 def setBlockSize(self, value):
2310 """
2311 Sets the value of :py:attr:`blockSize`.
2312 """
2313 return self._set(blockSize=value)
2314
2315 @since("2.0.0")
2316 def setInitialWeights(self, value):
2317 """
2318 Sets the value of :py:attr:`initialWeights`.
2319 """
2320 return self._set(initialWeights=value)
2321
2322 def setMaxIter(self, value):
2323 """
2324 Sets the value of :py:attr:`maxIter`.
2325 """
2326 return self._set(maxIter=value)
2327
2328 def setSeed(self, value):
2329 """
2330 Sets the value of :py:attr:`seed`.
2331 """
2332 return self._set(seed=value)
2333
2334 def setTol(self, value):
2335 """
2336 Sets the value of :py:attr:`tol`.
2337 """
2338 return self._set(tol=value)
2339
2340 @since("2.0.0")
2341 def setStepSize(self, value):
2342 """
2343 Sets the value of :py:attr:`stepSize`.
2344 """
2345 return self._set(stepSize=value)
2346
2347 def setSolver(self, value):
2348 """
2349 Sets the value of :py:attr:`solver`.
2350 """
2351 return self._set(solver=value)
2352
2353
2354 class MultilayerPerceptronClassificationModel(JavaProbabilisticClassificationModel,
2355 _MultilayerPerceptronParams, JavaMLWritable,
2356 JavaMLReadable):
2357 """
2358 Model fitted by MultilayerPerceptronClassifier.
2359
2360 .. versionadded:: 1.6.0
2361 """
2362
2363 @property
2364 @since("2.0.0")
2365 def weights(self):
2366 """
2367 the weights of layers.
2368 """
2369 return self._call_java("weights")
2370
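
# A minimal sketch (not part of the Spark ML API): every layer of the network
# is an affine transform with a bias term, so for layer sizes [n_0, ..., n_k]
# the flattened weights vector has sum((n_i + 1) * n_{i+1}) entries. For
# layers=[2, 2, 2] that is (2 + 1) * 2 + (2 + 1) * 2 = 12, matching
# model.weights.size in the doctest above.
def _mlp_weight_count(layers):
    """Expected size of MultilayerPerceptronClassificationModel.weights."""
    return sum((n_in + 1) * n_out for n_in, n_out in zip(layers, layers[1:]))
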
2371
2372 class _OneVsRestParams(_JavaClassifierParams, HasWeightCol):
2373 """
    Params for :py:class:`OneVsRest` and :py:class:`OneVsRestModel`.
2375 """
2376
2377 classifier = Param(Params._dummy(), "classifier", "base binary classifier")
2378
2379 @since("2.0.0")
2380 def getClassifier(self):
2381 """
2382 Gets the value of classifier or its default value.
2383 """
2384 return self.getOrDefault(self.classifier)
2385
2386
2387 @inherit_doc
2388 class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, JavaMLWritable):
2389 """
2390 Reduction of Multiclass Classification to Binary Classification.
    Performs reduction using the one-vs-all strategy.
    For multiclass classification with k classes, this trains k models (one per class).
    Each example is scored against all k models, and the model with the highest score
    is picked to label the example.
2395
2396 >>> from pyspark.sql import Row
2397 >>> from pyspark.ml.linalg import Vectors
2398 >>> data_path = "data/mllib/sample_multiclass_classification_data.txt"
2399 >>> df = spark.read.format("libsvm").load(data_path)
2400 >>> lr = LogisticRegression(regParam=0.01)
2401 >>> ovr = OneVsRest(classifier=lr)
2402 >>> ovr.getRawPredictionCol()
2403 'rawPrediction'
2404 >>> ovr.setPredictionCol("newPrediction")
2405 OneVsRest...
2406 >>> model = ovr.fit(df)
2407 >>> model.models[0].coefficients
2408 DenseVector([0.5..., -1.0..., 3.4..., 4.2...])
2409 >>> model.models[1].coefficients
2410 DenseVector([-2.1..., 3.1..., -2.6..., -2.3...])
2411 >>> model.models[2].coefficients
2412 DenseVector([0.3..., -3.4..., 1.0..., -1.1...])
2413 >>> [x.intercept for x in model.models]
2414 [-2.7..., -2.5..., -1.3...]
2415 >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0, 1.0, 1.0))]).toDF()
2416 >>> model.transform(test0).head().newPrediction
2417 0.0
2418 >>> test1 = sc.parallelize([Row(features=Vectors.sparse(4, [0], [1.0]))]).toDF()
2419 >>> model.transform(test1).head().newPrediction
2420 2.0
2421 >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4, 0.3, 0.2))]).toDF()
2422 >>> model.transform(test2).head().newPrediction
2423 0.0
2424 >>> model_path = temp_path + "/ovr_model"
2425 >>> model.save(model_path)
2426 >>> model2 = OneVsRestModel.load(model_path)
2427 >>> model2.transform(test0).head().newPrediction
2428 0.0
2429 >>> model.transform(test2).columns
2430 ['features', 'rawPrediction', 'newPrediction']
2431
2432 .. versionadded:: 2.0.0
2433 """
2434
2435 @keyword_only
2436 def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2437 rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
2438 """
2439 __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1)
2441 """
2442 super(OneVsRest, self).__init__()
2443 self._setDefault(parallelism=1)
2444 kwargs = self._input_kwargs
2445 self._set(**kwargs)
2446
2447 @keyword_only
2448 @since("2.0.0")
2449 def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2450 rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
2451 """
2452 setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1)
2454 Sets params for OneVsRest.
2455 """
2456 kwargs = self._input_kwargs
2457 return self._set(**kwargs)
2458
2459 @since("2.0.0")
2460 def setClassifier(self, value):
2461 """
2462 Sets the value of :py:attr:`classifier`.
2463 """
2464 return self._set(classifier=value)
2465
2466 def setLabelCol(self, value):
2467 """
2468 Sets the value of :py:attr:`labelCol`.
2469 """
2470 return self._set(labelCol=value)
2471
2472 def setFeaturesCol(self, value):
2473 """
2474 Sets the value of :py:attr:`featuresCol`.
2475 """
2476 return self._set(featuresCol=value)
2477
2478 def setPredictionCol(self, value):
2479 """
2480 Sets the value of :py:attr:`predictionCol`.
2481 """
2482 return self._set(predictionCol=value)
2483
2484 def setRawPredictionCol(self, value):
2485 """
2486 Sets the value of :py:attr:`rawPredictionCol`.
2487 """
2488 return self._set(rawPredictionCol=value)
2489
2490 def setWeightCol(self, value):
2491 """
2492 Sets the value of :py:attr:`weightCol`.
2493 """
2494 return self._set(weightCol=value)
2495
2496 def setParallelism(self, value):
2497 """
2498 Sets the value of :py:attr:`parallelism`.
2499 """
2500 return self._set(parallelism=value)
2501
2502 def _fit(self, dataset):
2503 labelCol = self.getLabelCol()
2504 featuresCol = self.getFeaturesCol()
2505 predictionCol = self.getPredictionCol()
2506 classifier = self.getClassifier()
2507
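        # infer the number of classes from the largest label value; labels are
        # assumed to be consecutive integers starting at 0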
2508 numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1
2509
2510 weightCol = None
2511 if (self.isDefined(self.weightCol) and self.getWeightCol()):
2512 if isinstance(classifier, HasWeightCol):
2513 weightCol = self.getWeightCol()
            else:
                import warnings  # imported locally; absent from the module imports
                warnings.warn("weightCol is ignored, "
                              "as it is not supported by {} now.".format(classifier))
2517
2518 if weightCol:
2519 multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
2520 else:
2521 multiclassLabeled = dataset.select(labelCol, featuresCol)

        # persist the selected columns if the input dataset is not persistent
        handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
2525 if handlePersistence:
2526 multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
2527
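        # train one binary classifier per class, relabeling the current class
        # to 1.0 and all other classes to 0.0 (see the sketch after this class)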
2528 def trainSingleClass(index):
2529 binaryLabelCol = "mc2b$" + str(index)
2530 trainingDataset = multiclassLabeled.withColumn(
2531 binaryLabelCol,
2532 when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0))
2533 paramMap = dict([(classifier.labelCol, binaryLabelCol),
2534 (classifier.featuresCol, featuresCol),
2535 (classifier.predictionCol, predictionCol)])
2536 if weightCol:
2537 paramMap[classifier.weightCol] = weightCol
2538 return classifier.fit(trainingDataset, paramMap)
2539
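        # fit the k binary models on a thread pool, bounded by the parallelism
        # param and the number of classes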
2540 pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
2541
2542 models = pool.map(trainSingleClass, range(numClasses))
2543
2544 if handlePersistence:
2545 multiclassLabeled.unpersist()
2546
2547 return self._copyValues(OneVsRestModel(models=models))
2548
2549 @since("2.0.0")
2550 def copy(self, extra=None):
2551 """
2552 Creates a copy of this instance with a randomly generated uid
2553 and some extra params. This creates a deep copy of the embedded paramMap,
2554 and copies the embedded and extra parameters over.
2555
2556 :param extra: Extra parameters to copy to the new instance
2557 :return: Copy of this instance
2558 """
2559 if extra is None:
2560 extra = dict()
2561 newOvr = Params.copy(self, extra)
2562 if self.isSet(self.classifier):
2563 newOvr.setClassifier(self.getClassifier().copy(extra))
2564 return newOvr
2565
2566 @classmethod
2567 def _from_java(cls, java_stage):
2568 """
2569 Given a Java OneVsRest, create and return a Python wrapper of it.
2570 Used for ML persistence.
2571 """
2572 featuresCol = java_stage.getFeaturesCol()
2573 labelCol = java_stage.getLabelCol()
2574 predictionCol = java_stage.getPredictionCol()
2575 rawPredictionCol = java_stage.getRawPredictionCol()
2576 classifier = JavaParams._from_java(java_stage.getClassifier())
2577 parallelism = java_stage.getParallelism()
2578 py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol,
2579 rawPredictionCol=rawPredictionCol, classifier=classifier,
2580 parallelism=parallelism)
2581 if java_stage.isDefined(java_stage.getParam("weightCol")):
2582 py_stage.setWeightCol(java_stage.getWeightCol())
2583 py_stage._resetUid(java_stage.uid())
2584 return py_stage
2585
2586 def _to_java(self):
2587 """
2588 Transfer this instance to a Java OneVsRest. Used for ML persistence.
2589
2590 :return: Java object equivalent to this instance.
2591 """
2592 _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
2593 self.uid)
2594 _java_obj.setClassifier(self.getClassifier()._to_java())
2595 _java_obj.setParallelism(self.getParallelism())
2596 _java_obj.setFeaturesCol(self.getFeaturesCol())
2597 _java_obj.setLabelCol(self.getLabelCol())
2598 _java_obj.setPredictionCol(self.getPredictionCol())
2599 if (self.isDefined(self.weightCol) and self.getWeightCol()):
2600 _java_obj.setWeightCol(self.getWeightCol())
2601 _java_obj.setRawPredictionCol(self.getRawPredictionCol())
2602 return _java_obj
2603
2604 def _make_java_param_pair(self, param, value):
2605 """
2606 Makes a Java param pair.
2607 """
2608 sc = SparkContext._active_spark_context
2609 param = self._resolveParam(param)
2610 _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
2611 self.uid)
2612 java_param = _java_obj.getParam(param.name)
        if isinstance(value, JavaParams):
            # used when the value is itself a JavaParams stage, e.g. the
            # estimator stored in the classifier param
            java_value = value._to_java()
2618 else:
2619 java_value = _py2java(sc, value)
2620 return java_param.w(java_value)
2621
2622 def _transfer_param_map_to_java(self, pyParamMap):
2623 """
2624 Transforms a Python ParamMap into a Java ParamMap.
2625 """
2626 paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
2627 for param in self.params:
2628 if param in pyParamMap:
2629 pair = self._make_java_param_pair(param, pyParamMap[param])
2630 paramMap.put([pair])
2631 return paramMap
2632
2633 def _transfer_param_map_from_java(self, javaParamMap):
2634 """
2635 Transforms a Java ParamMap into a Python ParamMap.
2636 """
2637 sc = SparkContext._active_spark_context
2638 paramMap = dict()
2639 for pair in javaParamMap.toList():
2640 param = pair.param()
2641 if self.hasParam(str(param.name())):
2642 if param.name() == "classifier":
2643 paramMap[self.getParam(param.name())] = JavaParams._from_java(pair.value())
2644 else:
2645 paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
2646 return paramMap
2647
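
# A minimal sketch (not part of the Spark ML API) of the per-class relabeling
# that OneVsRest._fit applies: the target class maps to 1.0 and every other
# label to 0.0, yielding one binary training problem per class. The column
# name "binary_label" is illustrative only.
def _one_vs_all_view(dataset, label_col, index):
    """Return ``dataset`` with an added 0/1 label column for class ``index``."""
    return dataset.withColumn(
        "binary_label",
        when(dataset[label_col] == float(index), 1.0).otherwise(0.0))
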
2648
2649 class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
2650 """
2651 Model fitted by OneVsRest.
2652 This stores the models resulting from training k binary classifiers: one for each class.
2653 Each example is scored against all k models, and the model with the highest score
2654 is picked to label the example.
2655
2656 .. versionadded:: 2.0.0
2657 """
2658
2659 def setFeaturesCol(self, value):
2660 """
2661 Sets the value of :py:attr:`featuresCol`.
2662 """
2663 return self._set(featuresCol=value)
2664
2665 def setPredictionCol(self, value):
2666 """
2667 Sets the value of :py:attr:`predictionCol`.
2668 """
2669 return self._set(predictionCol=value)
2670
2671 def setRawPredictionCol(self, value):
2672 """
2673 Sets the value of :py:attr:`rawPredictionCol`.
2674 """
2675 return self._set(rawPredictionCol=value)
2676
2677 def __init__(self, models):
2678 super(OneVsRestModel, self).__init__()
2679 self.models = models
2680 java_models = [model._to_java() for model in self.models]
2681 sc = SparkContext._active_spark_context
2682 java_models_array = JavaWrapper._new_java_array(java_models,
2683 sc._gateway.jvm.org.apache.spark.ml
2684 .classification.ClassificationModel)
2685
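        # the Java constructor also expects label metadata; none is tracked on
        # the Python side, so pass empty metadata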
2686 metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
2687 self._java_obj = \
2688 JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
2689 self.uid, metadata.empty(), java_models_array)
2690
2691 def _transform(self, dataset):
        import uuid  # imported locally; absent from the module imports

        # determine the input columns: these need to be passed through
        origCols = dataset.columns

        # add an accumulator column to collect one score per model
        accColName = "mbc$acc" + str(uuid.uuid4())
        initUDF = udf(lambda _: [], ArrayType(DoubleType()))
        newDataset = dataset.withColumn(accColName, initUDF(dataset[origCols[0]]))

        # persist the working dataset if the input is not already persistent
2701 handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
2702 if handlePersistence:
2703 newDataset.persist(StorageLevel.MEMORY_AND_DISK)

        # run each model in turn, appending its raw score to the accumulator
        aggregatedDataset = newDataset
2707 for index, model in enumerate(self.models):
2708 rawPredictionCol = self.getRawPredictionCol()

            columns = origCols + [rawPredictionCol, accColName]

            # add a temporary column to hold the updated accumulator
            tmpColName = "mbc$tmp" + str(uuid.uuid4())
2714 updateUDF = udf(
2715 lambda predictions, prediction: predictions + [prediction.tolist()[1]],
2716 ArrayType(DoubleType()))
2717 transformedDataset = model.transform(aggregatedDataset).select(*columns)
2718 updatedDataset = transformedDataset.withColumn(
2719 tmpColName,
2720 updateUDF(transformedDataset[accColName], transformedDataset[rawPredictionCol]))
2721 newColumns = origCols + [tmpColName]

            # swap the temporary column back in as the accumulator column
2724 aggregatedDataset = updatedDataset\
2725 .select(*newColumns).withColumnRenamed(tmpColName, accColName)
2726
2727 if handlePersistence:
2728 newDataset.unpersist()
2729
2730 if self.getRawPredictionCol():
            def func(predictions):
                return Vectors.dense(list(predictions))

            # declare the vector return type; a bare udf() would default to
            # StringType and stringify the raw prediction vector
            from pyspark.ml.linalg import VectorUDT
            rawPredictionUDF = udf(func, VectorUDT())
2738 aggregatedDataset = aggregatedDataset.withColumn(
2739 self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName]))
2740
2741 if self.getPredictionCol():
2742
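            # the prediction is the index of the model with the highest
            # positive-class score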
2743 labelUDF = udf(lambda predictions: float(max(enumerate(predictions),
2744 key=operator.itemgetter(1))[0]), DoubleType())
2745 aggregatedDataset = aggregatedDataset.withColumn(
2746 self.getPredictionCol(), labelUDF(aggregatedDataset[accColName]))
2747 return aggregatedDataset.drop(accColName)
2748
2749 @since("2.0.0")
2750 def copy(self, extra=None):
2751 """
2752 Creates a copy of this instance with a randomly generated uid
2753 and some extra params. This creates a deep copy of the embedded paramMap,
2754 and copies the embedded and extra parameters over.
2755
2756 :param extra: Extra parameters to copy to the new instance
2757 :return: Copy of this instance
2758 """
2759 if extra is None:
2760 extra = dict()
2761 newModel = Params.copy(self, extra)
2762 newModel.models = [model.copy(extra) for model in self.models]
2763 return newModel
2764
2765 @classmethod
2766 def _from_java(cls, java_stage):
2767 """
2768 Given a Java OneVsRestModel, create and return a Python wrapper of it.
2769 Used for ML persistence.
2770 """
2771 featuresCol = java_stage.getFeaturesCol()
2772 labelCol = java_stage.getLabelCol()
2773 predictionCol = java_stage.getPredictionCol()
2774 classifier = JavaParams._from_java(java_stage.getClassifier())
2775 models = [JavaParams._from_java(model) for model in java_stage.models()]
2776 py_stage = cls(models=models).setPredictionCol(predictionCol)\
2777 .setFeaturesCol(featuresCol)
2778 py_stage._set(labelCol=labelCol)
2779 if java_stage.isDefined(java_stage.getParam("weightCol")):
2780 py_stage._set(weightCol=java_stage.getWeightCol())
2781 py_stage._set(classifier=classifier)
2782 py_stage._resetUid(java_stage.uid())
2783 return py_stage
2784
2785 def _to_java(self):
2786 """
2787 Transfer this instance to a Java OneVsRestModel. Used for ML persistence.
2788
2789 :return: Java object equivalent to this instance.
2790 """
2791 sc = SparkContext._active_spark_context
2792 java_models = [model._to_java() for model in self.models]
2793 java_models_array = JavaWrapper._new_java_array(
2794 java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel)
2795 metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
2796 _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
2797 self.uid, metadata.empty(), java_models_array)
2798 _java_obj.set("classifier", self.getClassifier()._to_java())
2799 _java_obj.set("featuresCol", self.getFeaturesCol())
2800 _java_obj.set("labelCol", self.getLabelCol())
2801 _java_obj.set("predictionCol", self.getPredictionCol())
2802 if (self.isDefined(self.weightCol) and self.getWeightCol()):
2803 _java_obj.set("weightCol", self.getWeightCol())
2804 return _java_obj
2805
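
# A minimal sketch (not part of the Spark ML API) of the per-row aggregation
# that OneVsRestModel._transform performs: each binary model contributes its
# positive-class raw score, and the index of the largest score becomes the
# predicted label.
def _ovr_predicted_index(scores):
    """Return, as a float label, the index of the highest one-vs-all score."""
    return float(max(enumerate(scores), key=operator.itemgetter(1))[0])
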
2806
2807 @inherit_doc
2808 class FMClassifier(JavaProbabilisticClassifier, _FactorizationMachinesParams, JavaMLWritable,
2809 JavaMLReadable):
2810 """
2811 Factorization Machines learning algorithm for classification.
2812
    The solver param supports:
2814
2815 * gd (normal mini-batch gradient descent)
2816 * adamW (default)
2817
2818 >>> from pyspark.ml.linalg import Vectors
2819 >>> from pyspark.ml.classification import FMClassifier
2820 >>> df = spark.createDataFrame([
2821 ... (1.0, Vectors.dense(1.0)),
2822 ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
2823 >>> fm = FMClassifier(factorSize=2)
2824 >>> fm.setSeed(11)
2825 FMClassifier...
2826 >>> model = fm.fit(df)
2827 >>> model.getMaxIter()
2828 100
2829 >>> test0 = spark.createDataFrame([
2830 ... (Vectors.dense(-1.0),),
2831 ... (Vectors.dense(0.5),),
2832 ... (Vectors.dense(1.0),),
2833 ... (Vectors.dense(2.0),)], ["features"])
2834 >>> model.predictRaw(test0.head().features)
2835 DenseVector([22.13..., -22.13...])
2836 >>> model.predictProbability(test0.head().features)
2837 DenseVector([1.0, 0.0])
2838 >>> model.transform(test0).select("features", "probability").show(10, False)
2839 +--------+------------------------------------------+
2840 |features|probability |
2841 +--------+------------------------------------------+
2842 |[-1.0] |[0.9999999997574736,2.425264676902229E-10]|
2843 |[0.5] |[0.47627851732981163,0.5237214826701884] |
2844 |[1.0] |[5.491554426243495E-4,0.9994508445573757] |
2845 |[2.0] |[2.005766663870645E-10,0.9999999997994233]|
2846 +--------+------------------------------------------+
2847 ...
2848 >>> model.intercept
2849 -7.316665276826291
2850 >>> model.linear
2851 DenseVector([14.8232])
2852 >>> model.factors
2853 DenseMatrix(1, 2, [0.0163, -0.0051], 1)
2854
2855 .. versionadded:: 3.0.0
2856 """
2857
2858 @keyword_only
2859 def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2860 probabilityCol="probability", rawPredictionCol="rawPrediction",
2861 factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
2862 miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
2863 tol=1e-6, solver="adamW", thresholds=None, seed=None):
2864 """
2865 __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2866 probabilityCol="probability", rawPredictionCol="rawPrediction", \
2867 factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \
2868 miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \
2869 tol=1e-6, solver="adamW", thresholds=None, seed=None)
2870 """
2871 super(FMClassifier, self).__init__()
2872 self._java_obj = self._new_java_obj(
2873 "org.apache.spark.ml.classification.FMClassifier", self.uid)
2874 self._setDefault(factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
2875 miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
2876 tol=1e-6, solver="adamW")
2877 kwargs = self._input_kwargs
2878 self.setParams(**kwargs)
2879
2880 @keyword_only
2881 @since("3.0.0")
2882 def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2883 probabilityCol="probability", rawPredictionCol="rawPrediction",
2884 factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
2885 miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
2886 tol=1e-6, solver="adamW", thresholds=None, seed=None):
2887 """
2888 setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2889 probabilityCol="probability", rawPredictionCol="rawPrediction", \
2890 factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \
2891 miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \
2892 tol=1e-6, solver="adamW", thresholds=None, seed=None)
2893 Sets Params for FMClassifier.
2894 """
2895 kwargs = self._input_kwargs
2896 return self._set(**kwargs)
2897
2898 def _create_model(self, java_model):
2899 return FMClassificationModel(java_model)
2900
2901 @since("3.0.0")
2902 def setFactorSize(self, value):
2903 """
2904 Sets the value of :py:attr:`factorSize`.
2905 """
2906 return self._set(factorSize=value)
2907
2908 @since("3.0.0")
2909 def setFitLinear(self, value):
2910 """
2911 Sets the value of :py:attr:`fitLinear`.
2912 """
2913 return self._set(fitLinear=value)
2914
2915 @since("3.0.0")
2916 def setMiniBatchFraction(self, value):
2917 """
2918 Sets the value of :py:attr:`miniBatchFraction`.
2919 """
2920 return self._set(miniBatchFraction=value)
2921
2922 @since("3.0.0")
2923 def setInitStd(self, value):
2924 """
2925 Sets the value of :py:attr:`initStd`.
2926 """
2927 return self._set(initStd=value)
2928
2929 @since("3.0.0")
2930 def setMaxIter(self, value):
2931 """
2932 Sets the value of :py:attr:`maxIter`.
2933 """
2934 return self._set(maxIter=value)
2935
2936 @since("3.0.0")
2937 def setStepSize(self, value):
2938 """
2939 Sets the value of :py:attr:`stepSize`.
2940 """
2941 return self._set(stepSize=value)
2942
2943 @since("3.0.0")
2944 def setTol(self, value):
2945 """
2946 Sets the value of :py:attr:`tol`.
2947 """
2948 return self._set(tol=value)
2949
2950 @since("3.0.0")
2951 def setSolver(self, value):
2952 """
2953 Sets the value of :py:attr:`solver`.
2954 """
2955 return self._set(solver=value)
2956
2957 @since("3.0.0")
2958 def setSeed(self, value):
2959 """
2960 Sets the value of :py:attr:`seed`.
2961 """
2962 return self._set(seed=value)
2963
2964 @since("3.0.0")
2965 def setFitIntercept(self, value):
2966 """
2967 Sets the value of :py:attr:`fitIntercept`.
2968 """
2969 return self._set(fitIntercept=value)
2970
2971 @since("3.0.0")
2972 def setRegParam(self, value):
2973 """
2974 Sets the value of :py:attr:`regParam`.
2975 """
2976 return self._set(regParam=value)
2977
2978
2979 class FMClassificationModel(JavaProbabilisticClassificationModel, _FactorizationMachinesParams,
2980 JavaMLWritable, JavaMLReadable):
2981 """
2982 Model fitted by :class:`FMClassifier`.
2983
2984 .. versionadded:: 3.0.0
2985 """
2986
2987 @property
2988 @since("3.0.0")
2989 def intercept(self):
2990 """
2991 Model intercept.
2992 """
2993 return self._call_java("intercept")
2994
2995 @property
2996 @since("3.0.0")
2997 def linear(self):
2998 """
2999 Model linear term.
3000 """
3001 return self._call_java("linear")
3002
3003 @property
3004 @since("3.0.0")
3005 def factors(self):
3006 """
3007 Model factor term.
3008 """
3009 return self._call_java("factors")
3010
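
# A minimal sketch (not part of the Spark ML API) of the factorization
# machines score that intercept, linear and factors parameterize:
#
#   s(x) = intercept + <linear, x> + sum_{i < j} <factors[i], factors[j]> * x_i * x_j
#
# FMClassifier maps s(x) to a class-1 probability with a sigmoid. The helper
# below evaluates the pairwise term with the standard O(n * k) identity;
# ``linear``, ``factors`` and ``x`` are assumed to be dense array-likes
# (e.g. model.linear.toArray() and model.factors.toArray()).
def _fm_score(intercept, linear, factors, x):
    """Compute the raw FM score for a single dense feature vector ``x``."""
    import numpy as np
    V = np.asarray(factors)          # shape: (numFeatures, factorSize)
    x = np.asarray(x, dtype=float)
    linear_term = float(np.dot(np.asarray(linear), x))
    pairwise = 0.5 * float(np.sum((V.T.dot(x)) ** 2 - (V.T ** 2).dot(x ** 2)))
    return float(intercept) + linear_term + pairwise
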
3011
3012 if __name__ == "__main__":
3013 import doctest
3014 import pyspark.ml.classification
3015 from pyspark.sql import SparkSession
3016 globs = pyspark.ml.classification.__dict__.copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
3019 spark = SparkSession.builder\
3020 .master("local[2]")\
3021 .appName("ml.classification tests")\
3022 .getOrCreate()
3023 sc = spark.sparkContext
3024 globs['sc'] = sc
3025 globs['spark'] = spark
3026 import tempfile
3027 temp_path = tempfile.mkdtemp()
3028 globs['temp_path'] = temp_path
3029 try:
3030 (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
3031 spark.stop()
3032 finally:
3033 from shutil import rmtree
3034 try:
3035 rmtree(temp_path)
3036 except OSError:
3037 pass
3038 if failure_count:
3039 sys.exit(-1)