0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import sys
0019
0020 from pyspark import since, keyword_only
0021 from pyspark.ml.param.shared import *
0022 from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
0023 _TreeEnsembleModel, _TreeEnsembleParams, _RandomForestParams, _GBTParams, \
0024 _HasVarianceImpurity, _TreeRegressorParams
0025 from pyspark.ml.util import *
0026 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
0027 JavaPredictor, JavaPredictionModel, _JavaPredictorParams, JavaWrapper
0028 from pyspark.ml.common import inherit_doc
0029 from pyspark.sql import DataFrame
0030
0031
# Public API of this module (order preserved).
__all__ = [
    'AFTSurvivalRegression',
    'AFTSurvivalRegressionModel',
    'DecisionTreeRegressor',
    'DecisionTreeRegressionModel',
    'GBTRegressor',
    'GBTRegressionModel',
    'GeneralizedLinearRegression',
    'GeneralizedLinearRegressionModel',
    'GeneralizedLinearRegressionSummary',
    'GeneralizedLinearRegressionTrainingSummary',
    'IsotonicRegression',
    'IsotonicRegressionModel',
    'LinearRegression',
    'LinearRegressionModel',
    'LinearRegressionSummary',
    'LinearRegressionTrainingSummary',
    'RandomForestRegressor',
    'RandomForestRegressionModel',
    'FMRegressor',
    'FMRegressionModel',
]
0042
0043
class JavaRegressor(JavaPredictor, _JavaPredictorParams):
    """
    Common base for regression estimators that wrap a JVM-side estimator.

    Adds nothing beyond what :class:`JavaPredictor` and the shared
    predictor params already supply; it exists so regressors share a type.

    .. versionadded:: 3.0.0
    """
0051
0052
class JavaRegressionModel(JavaPredictionModel, _JavaPredictorParams):
    """
    Java Model produced by a :class:`JavaRegressor`.
    To be mixed in with :class:`pyspark.ml.wrapper.JavaModel`.

    .. versionadded:: 3.0.0
    """
    pass
0061
0062
class _LinearRegressionParams(_JavaPredictorParams, HasRegParam, HasElasticNetParam, HasMaxIter,
                              HasTol, HasFitIntercept, HasStandardization, HasWeightCol, HasSolver,
                              HasAggregationDepth, HasLoss):
    """
    Param definitions shared by :py:class:`LinearRegression` and
    :py:class:`LinearRegressionModel`.

    ``solver`` and ``loss`` are redeclared here with descriptions listing the
    options linear regression accepts; ``epsilon`` only matters for huber loss.

    .. versionadded:: 3.0.0
    """

    solver = Param(
        Params._dummy(), "solver",
        "The solver algorithm for optimization. "
        "Supported options: auto, normal, l-bfgs.",
        typeConverter=TypeConverters.toString)

    loss = Param(
        Params._dummy(), "loss",
        "The loss function to be optimized. "
        "Supported options: squaredError, huber.",
        typeConverter=TypeConverters.toString)

    epsilon = Param(
        Params._dummy(), "epsilon",
        "The shape parameter to control the amount of robustness. "
        "Must be > 1.0. Only valid when loss is huber",
        typeConverter=TypeConverters.toFloat)

    @since("2.3.0")
    def getEpsilon(self):
        """
        Returns the configured epsilon, or its default when it was never set.
        """
        return self.getOrDefault(self.epsilon)
0088
0089
@inherit_doc
class LinearRegression(JavaRegressor, _LinearRegressionParams, JavaMLWritable, JavaMLReadable):
    """
    Linear regression.

    The learning objective is to minimize the specified loss function, with regularization.
    This supports two kinds of loss:

    * squaredError (a.k.a. squared loss)
    * huber (a hybrid of squared error for relatively small errors and absolute error for \
    relatively large ones, and we estimate the scale parameter from training data)

    This supports multiple types of regularization:

    * none (a.k.a. ordinary least squares)
    * L2 (ridge regression)
    * L1 (Lasso)
    * L2 + L1 (elastic net)

    Note: Fitting with huber loss only supports none and L2 regularization.

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, 2.0, Vectors.dense(1.0)),
    ...     (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
    >>> lr = LinearRegression(regParam=0.0, solver="normal", weightCol="weight")
    >>> lr.setMaxIter(5)
    LinearRegression...
    >>> lr.getMaxIter()
    5
    >>> lr.setRegParam(0.1)
    LinearRegression...
    >>> lr.getRegParam()
    0.1
    >>> lr.setRegParam(0.0)
    LinearRegression...
    >>> model = lr.fit(df)
    >>> model.setFeaturesCol("features")
    LinearRegressionModel...
    >>> model.setPredictionCol("newPrediction")
    LinearRegressionModel...
    >>> model.getMaxIter()
    5
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> abs(model.predict(test0.head().features) - (-1.0)) < 0.001
    True
    >>> abs(model.transform(test0).head().newPrediction - (-1.0)) < 0.001
    True
    >>> abs(model.coefficients[0] - 1.0) < 0.001
    True
    >>> abs(model.intercept - 0.0) < 0.001
    True
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> abs(model.transform(test1).head().newPrediction - 1.0) < 0.001
    True
    >>> lr.setParams("vector")
    Traceback (most recent call last):
        ...
    TypeError: Method setParams forces keyword arguments.
    >>> lr_path = temp_path + "/lr"
    >>> lr.save(lr_path)
    >>> lr2 = LinearRegression.load(lr_path)
    >>> lr2.getMaxIter()
    5
    >>> model_path = temp_path + "/lr_model"
    >>> model.save(model_path)
    >>> model2 = LinearRegressionModel.load(model_path)
    >>> model.coefficients[0] == model2.coefficients[0]
    True
    >>> model.intercept == model2.intercept
    True
    >>> model.numFeatures
    1
    >>> model.write().format("pmml").save(model_path + "_2")

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                 standardization=True, solver="auto", weightCol=None, aggregationDepth=2,
                 loss="squaredError", epsilon=1.35):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                 standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \
                 loss="squaredError", epsilon=1.35)
        """
        super(LinearRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.LinearRegression", self.uid)
        self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, loss="squaredError", epsilon=1.35)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                  standardization=True, solver="auto", weightCol=None, aggregationDepth=2,
                  loss="squaredError", epsilon=1.35):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                  standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \
                  loss="squaredError", epsilon=1.35)
        Sets params for linear regression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return LinearRegressionModel(java_model)

    @since("2.3.0")
    def setEpsilon(self, value):
        """
        Sets the value of :py:attr:`epsilon`.
        """
        return self._set(epsilon=value)

    def setMaxIter(self, value):
        """
        Sets the value of :py:attr:`maxIter`.
        """
        return self._set(maxIter=value)

    def setRegParam(self, value):
        """
        Sets the value of :py:attr:`regParam`.
        """
        return self._set(regParam=value)

    def setTol(self, value):
        """
        Sets the value of :py:attr:`tol`.
        """
        return self._set(tol=value)

    def setElasticNetParam(self, value):
        """
        Sets the value of :py:attr:`elasticNetParam`.
        """
        return self._set(elasticNetParam=value)

    def setFitIntercept(self, value):
        """
        Sets the value of :py:attr:`fitIntercept`.
        """
        return self._set(fitIntercept=value)

    def setStandardization(self, value):
        """
        Sets the value of :py:attr:`standardization`.
        """
        return self._set(standardization=value)

    def setWeightCol(self, value):
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)

    def setSolver(self, value):
        """
        Sets the value of :py:attr:`solver`.
        """
        return self._set(solver=value)

    def setAggregationDepth(self, value):
        """
        Sets the value of :py:attr:`aggregationDepth`.
        """
        return self._set(aggregationDepth=value)

    def setLoss(self, value):
        """
        Sets the value of :py:attr:`loss`.
        """
        # The param is declared as ``loss`` in _LinearRegressionParams, so the
        # keyword passed to _set must be ``loss`` (not ``lossType``).
        return self._set(loss=value)
0271
0272
class LinearRegressionModel(JavaRegressionModel, _LinearRegressionParams, GeneralJavaMLWritable,
                            JavaMLReadable, HasTrainingSummary):
    """
    The fitted model returned by :class:`LinearRegression`.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def coefficients(self):
        """
        Learned weight for each feature.
        """
        return self._call_java("coefficients")

    @property
    @since("1.4.0")
    def intercept(self):
        """
        Learned intercept term.
        """
        return self._call_java("intercept")

    @property
    @since("2.3.0")
    def scale(self):
        r"""
        Scaling factor applied to :math:`\|y - X'w\|` under "huber" loss; 1.0 for any other loss.
        """
        return self._call_java("scale")

    @property
    @since("2.0.0")
    def summary(self):
        """
        Training-set summary of the model (residuals, mse, r-squared, ...).
        Raises RuntimeError when no training summary is available
        (i.e. `trainingSummary is None`).
        """
        if not self.hasSummary:
            raise RuntimeError("No training summary available for this %s" %
                               self.__class__.__name__)
        java_summary = super(LinearRegressionModel, self).summary
        return LinearRegressionTrainingSummary(java_summary)

    @since("2.0.0")
    def evaluate(self, dataset):
        """
        Computes a :class:`LinearRegressionSummary` of this model on a test dataset.

        :param dataset:
          The data to evaluate against; must be a
          :py:class:`pyspark.sql.DataFrame`, otherwise ValueError is raised.
        """
        if isinstance(dataset, DataFrame):
            return LinearRegressionSummary(self._call_java("evaluate", dataset))
        raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
0332
0333
class LinearRegressionSummary(JavaWrapper):
    """
    Metrics and outputs of a linear regression model evaluated on some dataset.
    Each property forwards to the same-named member of the wrapped Java summary.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def predictions(self):
        """
        The DataFrame produced by applying the model's `transform` method.
        """
        return self._call_java("predictions")

    @property
    @since("2.0.0")
    def predictionCol(self):
        """
        Name of the column in "predictions" holding each instance's
        predicted label.
        """
        return self._call_java("predictionCol")

    @property
    @since("2.0.0")
    def labelCol(self):
        """
        Name of the column in "predictions" holding each instance's
        true label.
        """
        return self._call_java("labelCol")

    @property
    @since("2.0.0")
    def featuresCol(self):
        """
        Name of the column in "predictions" holding each instance's
        feature vector.
        """
        return self._call_java("featuresCol")

    @property
    @since("2.0.0")
    def explainedVariance(self):
        r"""
        The explained variance regression score, defined as
        explainedVariance = :math:`1 - \frac{variance(y - \hat{y})}{variance(y)}`

        .. seealso:: `Wikipedia explained variation
            <http://en.wikipedia.org/wiki/Explained_variation>`_

        .. note:: Instance weights from `LinearRegression.weightCol` are
            ignored (every weight treated as 1.0). This will change in later
            Spark versions.
        """
        return self._call_java("explainedVariance")

    @property
    @since("2.0.0")
    def meanAbsoluteError(self):
        """
        The mean absolute error: a risk function equal to the expected value
        of the absolute-error (l1-norm) loss.

        .. note:: Instance weights from `LinearRegression.weightCol` are
            ignored (every weight treated as 1.0). This will change in later
            Spark versions.
        """
        return self._call_java("meanAbsoluteError")

    @property
    @since("2.0.0")
    def meanSquaredError(self):
        """
        The mean squared error: a risk function equal to the expected value
        of the squared-error (quadratic) loss.

        .. note:: Instance weights from `LinearRegression.weightCol` are
            ignored (every weight treated as 1.0). This will change in later
            Spark versions.
        """
        return self._call_java("meanSquaredError")

    @property
    @since("2.0.0")
    def rootMeanSquaredError(self):
        """
        The root mean squared error, i.e. the square root of the mean
        squared error.

        .. note:: Instance weights from `LinearRegression.weightCol` are
            ignored (every weight treated as 1.0). This will change in later
            Spark versions.
        """
        return self._call_java("rootMeanSquaredError")

    @property
    @since("2.0.0")
    def r2(self):
        """
        R^2, the coefficient of determination.

        .. seealso:: `Wikipedia coefficient of determination
            <http://en.wikipedia.org/wiki/Coefficient_of_determination>`_

        .. note:: Instance weights from `LinearRegression.weightCol` are
            ignored (every weight treated as 1.0). This will change in later
            Spark versions.
        """
        return self._call_java("r2")

    @property
    @since("2.4.0")
    def r2adj(self):
        """
        Adjusted R^2, the adjusted coefficient of determination.

        .. seealso:: `Wikipedia coefficient of determination, Adjusted R^2
            <https://en.wikipedia.org/wiki/Coefficient_of_determination#Adjusted_R2>`_

        .. note:: Instance weights from `LinearRegression.weightCol` are
            ignored (every weight treated as 1.0). This will change in later
            Spark versions.
        """
        return self._call_java("r2adj")

    @property
    @since("2.0.0")
    def residuals(self):
        """
        Residuals, computed as label minus predicted value.
        """
        return self._call_java("residuals")

    @property
    @since("2.0.0")
    def numInstances(self):
        """
        Count of instances in the "predictions" DataFrame.
        """
        return self._call_java("numInstances")

    @property
    @since("2.2.0")
    def degreesOfFreedom(self):
        """
        The degrees of freedom.
        """
        return self._call_java("degreesOfFreedom")

    @property
    @since("2.0.0")
    def devianceResiduals(self):
        """
        Weighted residuals: ordinary residuals rescaled by the square root
        of each instance's weight.
        """
        return self._call_java("devianceResiduals")

    @property
    @since("2.0.0")
    def coefficientStandardErrors(self):
        """
        Standard errors of the estimated coefficients and intercept.
        Only available with the "normal" solver.

        When :py:attr:`LinearRegression.fitIntercept` is True, the final
        element corresponds to the intercept.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("coefficientStandardErrors")

    @property
    @since("2.0.0")
    def tValues(self):
        """
        T-statistics of the estimated coefficients and intercept.
        Only available with the "normal" solver.

        When :py:attr:`LinearRegression.fitIntercept` is True, the final
        element corresponds to the intercept.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("tValues")

    @property
    @since("2.0.0")
    def pValues(self):
        """
        Two-sided p-values of the estimated coefficients and intercept.
        Only available with the "normal" solver.

        When :py:attr:`LinearRegression.fitIntercept` is True, the final
        element corresponds to the intercept.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("pValues")
0536
0537
@inherit_doc
class LinearRegressionTrainingSummary(LinearRegressionSummary):
    """
    Linear regression results on the training set. Apart from the objective
    trace, this summary currently ignores training weights.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def objectiveHistory(self):
        """
        Value of the objective (scaled loss plus regularization) after each
        iteration. Only available with the "l-bfgs" solver.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("objectiveHistory")

    @property
    @since("2.0.0")
    def totalIterations(self):
        """
        How many training iterations ran before termination.
        Only available with the "l-bfgs" solver.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("totalIterations")
0569
0570
class _IsotonicRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol):
    """
    Params for :py:class:`IsotonicRegression` and :py:class:`IsotonicRegressionModel`.

    .. versionadded:: 3.0.0
    """

    # Note the trailing space before "+": without it the two halves run
    # together as "...(true) orantitonic/...".
    isotonic = Param(
        Params._dummy(), "isotonic",
        "whether the output sequence should be isotonic/increasing (true) or " +
        "antitonic/decreasing (false).", typeConverter=TypeConverters.toBoolean)
    featureIndex = Param(
        Params._dummy(), "featureIndex",
        "The index of the feature if featuresCol is a vector column, no effect otherwise.",
        typeConverter=TypeConverters.toInt)

    def getIsotonic(self):
        """
        Gets the value of isotonic or its default value.
        """
        return self.getOrDefault(self.isotonic)

    def getFeatureIndex(self):
        """
        Gets the value of featureIndex or its default value.
        """
        return self.getOrDefault(self.featureIndex)
0598
0599
@inherit_doc
class IsotonicRegression(JavaEstimator, _IsotonicRegressionParams, HasWeightCol,
                         JavaMLWritable, JavaMLReadable):
    """
    Isotonic regression.

    Currently implemented using parallelized pool adjacent violators algorithm.
    Only univariate (single feature) algorithm supported.

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> ir = IsotonicRegression()
    >>> model = ir.fit(df)
    >>> model.setFeaturesCol("features")
    IsotonicRegressionModel...
    >>> model.numFeatures
    1
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.transform(test0).head().prediction
    0.0
    >>> model.predict(test0.head().features[model.getFeatureIndex()])
    0.0
    >>> model.boundaries
    DenseVector([0.0, 1.0])
    >>> ir_path = temp_path + "/ir"
    >>> ir.save(ir_path)
    >>> ir2 = IsotonicRegression.load(ir_path)
    >>> ir2.getIsotonic()
    True
    >>> model_path = temp_path + "/ir_model"
    >>> model.save(model_path)
    >>> model2 = IsotonicRegressionModel.load(model_path)
    >>> model.boundaries == model2.boundaries
    True
    >>> model.predictions == model2.predictions
    True

    .. versionadded:: 1.6.0
    """
    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 weightCol=None, isotonic=True, featureIndex=0):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 weightCol=None, isotonic=True, featureIndex=0)
        """
        super(IsotonicRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.IsotonicRegression", self.uid)
        self._setDefault(isotonic=True, featureIndex=0)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  weightCol=None, isotonic=True, featureIndex=0):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 weightCol=None, isotonic=True, featureIndex=0)
        Set the params for IsotonicRegression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return IsotonicRegressionModel(java_model)

    def setIsotonic(self, value):
        """
        Sets the value of :py:attr:`isotonic`.
        """
        return self._set(isotonic=value)

    def setFeatureIndex(self, value):
        """
        Sets the value of :py:attr:`featureIndex`.
        """
        return self._set(featureIndex=value)

    @since("1.6.0")
    def setFeaturesCol(self, value):
        """
        Sets the value of :py:attr:`featuresCol`.
        """
        return self._set(featuresCol=value)

    @since("1.6.0")
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self._set(predictionCol=value)

    @since("1.6.0")
    def setLabelCol(self, value):
        """
        Sets the value of :py:attr:`labelCol`.
        """
        return self._set(labelCol=value)

    @since("1.6.0")
    def setWeightCol(self, value):
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)
0706
0707
class IsotonicRegressionModel(JavaModel, _IsotonicRegressionParams, JavaMLWritable,
                              JavaMLReadable):
    """
    The fitted model returned by :class:`IsotonicRegression`.

    .. versionadded:: 1.6.0
    """

    @since("3.0.0")
    def setFeaturesCol(self, value):
        """
        Updates :py:attr:`featuresCol` and returns this model.
        """
        return self._set(featuresCol=value)

    @since("3.0.0")
    def setPredictionCol(self, value):
        """
        Updates :py:attr:`predictionCol` and returns this model.
        """
        return self._set(predictionCol=value)

    def setFeatureIndex(self, value):
        """
        Updates :py:attr:`featureIndex` and returns this model.
        """
        return self._set(featureIndex=value)

    @property
    @since("1.6.0")
    def boundaries(self):
        """
        Points, in increasing order, at which predictions are known.
        """
        return self._call_java("boundaries")

    @property
    @since("1.6.0")
    def predictions(self):
        """
        Prediction for each boundary at the same index; monotone as a
        consequence of isotonic regression.
        """
        return self._call_java("predictions")

    @property
    @since("3.0.0")
    def numFeatures(self):
        """
        Number of features seen during training, or -1 if unknown.
        """
        return self._call_java("numFeatures")

    @since("3.0.0")
    def predict(self, value):
        """
        Returns the predicted label for the given feature value.
        """
        return self._call_java("predict", value)
0767
0768
class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, HasVarianceCol):
    """
    Param mix-in shared by :py:class:`DecisionTreeRegressor` and
    :py:class:`DecisionTreeRegressionModel`; it only combines its bases.

    .. versionadded:: 3.0.0
    """
0777
0778
@inherit_doc
class DecisionTreeRegressor(JavaRegressor, _DecisionTreeRegressorParams, JavaMLWritable,
                            JavaMLReadable):
    """
    `Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
    learning algorithm for regression.
    It supports both continuous and categorical features.

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> dt = DecisionTreeRegressor(maxDepth=2)
    >>> dt.setVarianceCol("variance")
    DecisionTreeRegressor...
    >>> model = dt.fit(df)
    >>> model.getVarianceCol()
    'variance'
    >>> model.setLeafCol("leafId")
    DecisionTreeRegressionModel...
    >>> model.depth
    1
    >>> model.numNodes
    3
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> model.numFeatures
    1
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.predict(test0.head().features)
    0.0
    >>> result = model.transform(test0).head()
    >>> result.prediction
    0.0
    >>> model.predictLeaf(test0.head().features)
    0.0
    >>> result.leafId
    0.0
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0
    >>> dtr_path = temp_path + "/dtr"
    >>> dt.save(dtr_path)
    >>> dt2 = DecisionTreeRegressor.load(dtr_path)
    >>> dt2.getMaxDepth()
    2
    >>> model_path = temp_path + "/dtr_model"
    >>> model.save(model_path)
    >>> model2 = DecisionTreeRegressionModel.load(model_path)
    >>> model.numNodes == model2.numNodes
    True
    >>> model.depth == model2.depth
    True
    >>> model.transform(test1).head().variance
    0.0

    >>> df3 = spark.createDataFrame([
    ...     (1.0, 0.2, Vectors.dense(1.0)),
    ...     (1.0, 0.8, Vectors.dense(1.0)),
    ...     (0.0, 1.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
    >>> dt3 = DecisionTreeRegressor(maxDepth=2, weightCol="weight", varianceCol="variance")
    >>> model3 = dt3.fit(df3)
    >>> print(model3.toDebugString)
    DecisionTreeRegressionModel...depth=1, numNodes=3...

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
                 seed=None, varianceCol=None, weightCol=None, leafCol="",
                 minWeightFractionPerNode=0.0):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                 impurity="variance", seed=None, varianceCol=None, weightCol=None, \
                 leafCol="", minWeightFractionPerNode=0.0)
        """
        super(DecisionTreeRegressor, self).__init__()
        java_class = "org.apache.spark.ml.regression.DecisionTreeRegressor"
        self._java_obj = self._new_java_obj(java_class, self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="variance", leafCol="", minWeightFractionPerNode=0.0)
        # keyword_only captured the explicitly passed keywords into _input_kwargs.
        self.setParams(**self._input_kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  impurity="variance", seed=None, varianceCol=None, weightCol=None,
                  leafCol="", minWeightFractionPerNode=0.0):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                  impurity="variance", seed=None, varianceCol=None, weightCol=None, \
                  leafCol="", minWeightFractionPerNode=0.0)
        Sets params for the DecisionTreeRegressor.
        """
        return self._set(**self._input_kwargs)

    def _create_model(self, java_model):
        return DecisionTreeRegressionModel(java_model)

    @since("1.4.0")
    def setMaxDepth(self, value):
        """
        Updates :py:attr:`maxDepth` and returns this estimator.
        """
        return self._set(maxDepth=value)

    @since("1.4.0")
    def setMaxBins(self, value):
        """
        Updates :py:attr:`maxBins` and returns this estimator.
        """
        return self._set(maxBins=value)

    @since("1.4.0")
    def setMinInstancesPerNode(self, value):
        """
        Updates :py:attr:`minInstancesPerNode` and returns this estimator.
        """
        return self._set(minInstancesPerNode=value)

    @since("3.0.0")
    def setMinWeightFractionPerNode(self, value):
        """
        Updates :py:attr:`minWeightFractionPerNode` and returns this estimator.
        """
        return self._set(minWeightFractionPerNode=value)

    @since("1.4.0")
    def setMinInfoGain(self, value):
        """
        Updates :py:attr:`minInfoGain` and returns this estimator.
        """
        return self._set(minInfoGain=value)

    @since("1.4.0")
    def setMaxMemoryInMB(self, value):
        """
        Updates :py:attr:`maxMemoryInMB` and returns this estimator.
        """
        return self._set(maxMemoryInMB=value)

    @since("1.4.0")
    def setCacheNodeIds(self, value):
        """
        Updates :py:attr:`cacheNodeIds` and returns this estimator.
        """
        return self._set(cacheNodeIds=value)

    @since("1.4.0")
    def setImpurity(self, value):
        """
        Updates :py:attr:`impurity` and returns this estimator.
        """
        return self._set(impurity=value)

    @since("1.4.0")
    def setCheckpointInterval(self, value):
        """
        Updates :py:attr:`checkpointInterval` and returns this estimator.
        """
        return self._set(checkpointInterval=value)

    def setSeed(self, value):
        """
        Updates :py:attr:`seed` and returns this estimator.
        """
        return self._set(seed=value)

    @since("3.0.0")
    def setWeightCol(self, value):
        """
        Updates :py:attr:`weightCol` and returns this estimator.
        """
        return self._set(weightCol=value)

    @since("2.0.0")
    def setVarianceCol(self, value):
        """
        Updates :py:attr:`varianceCol` and returns this estimator.
        """
        return self._set(varianceCol=value)
0972
0973
@inherit_doc
class DecisionTreeRegressionModel(
    JavaRegressionModel, _DecisionTreeModel, _DecisionTreeRegressorParams,
    JavaMLWritable, JavaMLReadable
):
    """
    The fitted model returned by :class:`DecisionTreeRegressor`.

    .. versionadded:: 1.4.0
    """

    @since("3.0.0")
    def setVarianceCol(self, value):
        """
        Updates :py:attr:`varianceCol` and returns this model.
        """
        return self._set(varianceCol=value)

    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Per-feature importance estimate.

        Extends the notion of "Gini" importance to other losses, following the
        description of Gini importance in the "Random Forests" documentation
        by Leo Breiman and Adele Cutler and the scikit-learn implementation.

        Computation:
         - importance(feature j) = sum, over nodes that split on feature j, of the
           gain, with gain scaled by the number of instances passing through the node
         - Importances for the tree are then normalized to sum to 1.

        .. note:: Importances from a single decision tree can have high variance
            when predictors are correlated. A :py:class:`RandomForestRegressor`
            is usually a better tool for estimating feature importance.
        """
        return self._call_java("featureImportances")
1012
1013
class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams):
    """
    Param mix-in shared by :py:class:`RandomForestRegressor` and
    :py:class:`RandomForestRegressionModel`; it only combines its bases.

    .. versionadded:: 3.0.0
    """
1021
1022
@inherit_doc
class RandomForestRegressor(JavaRegressor, _RandomForestRegressorParams, JavaMLWritable,
                            JavaMLReadable):
    """
    `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
    learning algorithm for regression.
    It supports both continuous and categorical features.

    >>> from numpy import allclose
    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2)
    >>> rf.getMinWeightFractionPerNode()
    0.0
    >>> rf.setSeed(42)
    RandomForestRegressor...
    >>> model = rf.fit(df)
    >>> model.getBootstrap()
    True
    >>> model.getSeed()
    42
    >>> model.setLeafCol("leafId")
    RandomForestRegressionModel...
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> allclose(model.treeWeights, [1.0, 1.0])
    True
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.predict(test0.head().features)
    0.0
    >>> model.predictLeaf(test0.head().features)
    DenseVector([0.0, 0.0])
    >>> result = model.transform(test0).head()
    >>> result.prediction
    0.0
    >>> result.leafId
    DenseVector([0.0, 0.0])
    >>> model.numFeatures
    1
    >>> model.trees
    [DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
    >>> model.getNumTrees
    2
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    0.5
    >>> rfr_path = temp_path + "/rfr"
    >>> rf.save(rfr_path)
    >>> rf2 = RandomForestRegressor.load(rfr_path)
    >>> rf2.getNumTrees()
    2
    >>> model_path = temp_path + "/rfr_model"
    >>> model.save(model_path)
    >>> model2 = RandomForestRegressionModel.load(model_path)
    >>> model.featureImportances == model2.featureImportances
    True

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                 impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
                 featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0,
                 weightCol=None, bootstrap=True):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                 impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
                 featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0, \
                 weightCol=None, bootstrap=True)
        """
        super(RandomForestRegressor, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.RandomForestRegressor", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="variance", subsamplingRate=1.0, numTrees=20,
                         featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0,
                         bootstrap=True)
        # Forward the keyword arguments captured for this call to setParams.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
                  featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0,
                  weightCol=None, bootstrap=True):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                  impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
                  featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0, \
                  weightCol=None, bootstrap=True)
        Sets params for random forest regression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return RandomForestRegressionModel(java_model)

    def setMaxDepth(self, value):
        """
        Sets the value of :py:attr:`maxDepth`.
        """
        return self._set(maxDepth=value)

    def setMaxBins(self, value):
        """
        Sets the value of :py:attr:`maxBins`.
        """
        return self._set(maxBins=value)

    def setMinInstancesPerNode(self, value):
        """
        Sets the value of :py:attr:`minInstancesPerNode`.
        """
        return self._set(minInstancesPerNode=value)

    def setMinInfoGain(self, value):
        """
        Sets the value of :py:attr:`minInfoGain`.
        """
        return self._set(minInfoGain=value)

    def setMaxMemoryInMB(self, value):
        """
        Sets the value of :py:attr:`maxMemoryInMB`.
        """
        return self._set(maxMemoryInMB=value)

    def setCacheNodeIds(self, value):
        """
        Sets the value of :py:attr:`cacheNodeIds`.
        """
        return self._set(cacheNodeIds=value)

    @since("1.4.0")
    def setImpurity(self, value):
        """
        Sets the value of :py:attr:`impurity`.
        """
        return self._set(impurity=value)

    @since("1.4.0")
    def setNumTrees(self, value):
        """
        Sets the value of :py:attr:`numTrees`.
        """
        return self._set(numTrees=value)

    @since("3.0.0")
    def setBootstrap(self, value):
        """
        Sets the value of :py:attr:`bootstrap`.
        """
        return self._set(bootstrap=value)

    @since("1.4.0")
    def setSubsamplingRate(self, value):
        """
        Sets the value of :py:attr:`subsamplingRate`.
        """
        return self._set(subsamplingRate=value)

    @since("2.4.0")
    def setFeatureSubsetStrategy(self, value):
        """
        Sets the value of :py:attr:`featureSubsetStrategy`.
        """
        return self._set(featureSubsetStrategy=value)

    def setCheckpointInterval(self, value):
        """
        Sets the value of :py:attr:`checkpointInterval`.
        """
        return self._set(checkpointInterval=value)

    def setSeed(self, value):
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)

    @since("3.0.0")
    def setWeightCol(self, value):
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)

    @since("3.0.0")
    def setMinWeightFractionPerNode(self, value):
        """
        Sets the value of :py:attr:`minWeightFractionPerNode`.
        """
        return self._set(minWeightFractionPerNode=value)
1230
1231
class RandomForestRegressionModel(
    JavaRegressionModel, _TreeEnsembleModel, _RandomForestRegressorParams,
    JavaMLWritable, JavaMLReadable
):
    """
    Regression model produced by fitting a :class:`RandomForestRegressor`.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def trees(self):
        """Trees in this ensemble. Warning: These have null parent Estimators."""
        # Wrap every JVM-side tree model in its Python counterpart.
        java_trees = list(self._call_java("trees"))
        return list(map(DecisionTreeRegressionModel, java_trees))

    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        Each feature's importance is the average of its importance across all trees in the
        ensemble. The importance vector is normalized to sum to 1. This method is suggested by
        Hastie et al. (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning,
        2nd Edition." 2001.) and follows the implementation from scikit-learn.

        .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
        """
        java_attr = "featureImportances"
        return self._call_java(java_attr)
1262
1263
class _GBTRegressorParams(_GBTParams, _TreeRegressorParams):
    """
    Params for :py:class:`GBTRegressor` and :py:class:`GBTRegressionModel`.

    .. versionadded:: 3.0.0
    """

    # Loss names accepted by ``lossType``; also interpolated into the param's doc string below.
    supportedLossTypes = ["squared", "absolute"]

    lossType = Param(Params._dummy(), "lossType",
                     "Loss function which GBT tries to minimize (case-insensitive). " +
                     "Supported options: " + ", ".join(supportedLossTypes),
                     typeConverter=TypeConverters.toString)

    @since("1.4.0")
    def getLossType(self):
        """
        Gets the value of lossType or its default value.
        """
        return self.getOrDefault(self.lossType)
1284
1285
@inherit_doc
class GBTRegressor(JavaRegressor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
    """
    `Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
    learning algorithm for regression.
    It supports both continuous and categorical features.

    >>> from numpy import allclose
    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> gbt = GBTRegressor(maxDepth=2, seed=42, leafCol="leafId")
    >>> gbt.setMaxIter(5)
    GBTRegressor...
    >>> gbt.setMinWeightFractionPerNode(0.049)
    GBTRegressor...
    >>> gbt.getMaxIter()
    5
    >>> print(gbt.getImpurity())
    variance
    >>> print(gbt.getFeatureSubsetStrategy())
    all
    >>> model = gbt.fit(df)
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> model.numFeatures
    1
    >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
    True
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.predict(test0.head().features)
    0.0
    >>> model.predictLeaf(test0.head().features)
    DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
    >>> result = model.transform(test0).head()
    >>> result.prediction
    0.0
    >>> result.leafId
    DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0
    >>> gbtr_path = temp_path + "gbtr"
    >>> gbt.save(gbtr_path)
    >>> gbt2 = GBTRegressor.load(gbtr_path)
    >>> gbt2.getMaxDepth()
    2
    >>> model_path = temp_path + "gbtr_model"
    >>> model.save(model_path)
    >>> model2 = GBTRegressionModel.load(model_path)
    >>> model.featureImportances == model2.featureImportances
    True
    >>> model.treeWeights == model2.treeWeights
    True
    >>> model.trees
    [DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
    >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
    ...              ["label", "features"])
    >>> model.evaluateEachIteration(validation, "squared")
    [0.0, 0.0, 0.0, 0.0, 0.0]
    >>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
    >>> gbt.getValidationIndicatorCol()
    'validationIndicator'
    >>> gbt.getValidationTol()
    0.01

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
                 checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
                 impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
                 validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0,
                 weightCol=None):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
                 checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
                 impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
                 validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \
                 weightCol=None)
        """
        super(GBTRegressor, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
                         checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1,
                         impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
                         leafCol="", minWeightFractionPerNode=0.0)
        # The constructor's keyword arguments are forwarded verbatim to setParams, so every
        # keyword accepted here must also be accepted (with the same name) by setParams.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
                  checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
                  impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
                  validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0,
                  weightCol=None):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
                  checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
                  impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
                  validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \
                  weightCol=None)
        Sets params for Gradient Boosted Tree Regression.
        """
        # Keyword names are passed straight to _set, so they must match the Param names
        # exactly (previously ``impuriy`` here made ``impurity=...`` a TypeError).
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return GBTRegressionModel(java_model)

    @since("1.4.0")
    def setMaxDepth(self, value):
        """
        Sets the value of :py:attr:`maxDepth`.
        """
        return self._set(maxDepth=value)

    @since("1.4.0")
    def setMaxBins(self, value):
        """
        Sets the value of :py:attr:`maxBins`.
        """
        return self._set(maxBins=value)

    @since("1.4.0")
    def setMinInstancesPerNode(self, value):
        """
        Sets the value of :py:attr:`minInstancesPerNode`.
        """
        return self._set(minInstancesPerNode=value)

    @since("1.4.0")
    def setMinInfoGain(self, value):
        """
        Sets the value of :py:attr:`minInfoGain`.
        """
        return self._set(minInfoGain=value)

    @since("1.4.0")
    def setMaxMemoryInMB(self, value):
        """
        Sets the value of :py:attr:`maxMemoryInMB`.
        """
        return self._set(maxMemoryInMB=value)

    @since("1.4.0")
    def setCacheNodeIds(self, value):
        """
        Sets the value of :py:attr:`cacheNodeIds`.
        """
        return self._set(cacheNodeIds=value)

    @since("1.4.0")
    def setImpurity(self, value):
        """
        Sets the value of :py:attr:`impurity`.
        """
        return self._set(impurity=value)

    @since("1.4.0")
    def setLossType(self, value):
        """
        Sets the value of :py:attr:`lossType`.
        """
        return self._set(lossType=value)

    @since("1.4.0")
    def setSubsamplingRate(self, value):
        """
        Sets the value of :py:attr:`subsamplingRate`.
        """
        return self._set(subsamplingRate=value)

    @since("2.4.0")
    def setFeatureSubsetStrategy(self, value):
        """
        Sets the value of :py:attr:`featureSubsetStrategy`.
        """
        return self._set(featureSubsetStrategy=value)

    @since("3.0.0")
    def setValidationIndicatorCol(self, value):
        """
        Sets the value of :py:attr:`validationIndicatorCol`.
        """
        return self._set(validationIndicatorCol=value)

    @since("1.4.0")
    def setMaxIter(self, value):
        """
        Sets the value of :py:attr:`maxIter`.
        """
        return self._set(maxIter=value)

    @since("1.4.0")
    def setCheckpointInterval(self, value):
        """
        Sets the value of :py:attr:`checkpointInterval`.
        """
        return self._set(checkpointInterval=value)

    @since("1.4.0")
    def setSeed(self, value):
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)

    @since("1.4.0")
    def setStepSize(self, value):
        """
        Sets the value of :py:attr:`stepSize`.
        """
        return self._set(stepSize=value)

    @since("3.0.0")
    def setWeightCol(self, value):
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)

    @since("3.0.0")
    def setMinWeightFractionPerNode(self, value):
        """
        Sets the value of :py:attr:`minWeightFractionPerNode`.
        """
        return self._set(minWeightFractionPerNode=value)
1526
1527
class GBTRegressionModel(
    JavaRegressionModel, _TreeEnsembleModel, _GBTRegressorParams,
    JavaMLWritable, JavaMLReadable
):
    """
    Regression model produced by fitting a :class:`GBTRegressor`.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        Each feature's importance is the average of its importance across all trees in the
        ensemble. The importance vector is normalized to sum to 1. This method is suggested by
        Hastie et al. (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning,
        2nd Edition." 2001.) and follows the implementation from scikit-learn.

        .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
        """
        java_attr = "featureImportances"
        return self._call_java(java_attr)

    @property
    @since("2.0.0")
    def trees(self):
        """Trees in this ensemble. Warning: These have null parent Estimators."""
        # Wrap every JVM-side tree model in its Python counterpart.
        java_trees = list(self._call_java("trees"))
        return list(map(DecisionTreeRegressionModel, java_trees))

    @since("2.4.0")
    def evaluateEachIteration(self, dataset, loss):
        """
        Compute the error or loss after each boosting iteration.

        :param dataset:
            Test dataset to evaluate model on, where dataset is an
            instance of :py:class:`pyspark.sql.DataFrame`
        :param loss:
            The loss function used to compute error.
            Supported options: squared, absolute
        """
        java_method = "evaluateEachIteration"
        return self._call_java(java_method, dataset, loss)
1572
1573
class _AFTSurvivalRegressionParams(_JavaPredictorParams, HasMaxIter, HasTol, HasFitIntercept,
                                   HasAggregationDepth):
    """
    Shared params for :py:class:`AFTSurvivalRegression` and
    :py:class:`AFTSurvivalRegressionModel`.

    .. versionadded:: 3.0.0
    """

    censorCol = Param(
        Params._dummy(), "censorCol",
        ("censor column name. The value of this column could be 0 or 1. "
         "If the value is 1, it means the event has occurred i.e. "
         "uncensored; otherwise censored."),
        typeConverter=TypeConverters.toString)
    quantileProbabilities = Param(
        Params._dummy(), "quantileProbabilities",
        ("quantile probabilities array. Values of the quantile probabilities array "
         "should be in the range (0, 1) and the array should be non-empty."),
        typeConverter=TypeConverters.toListFloat)
    quantilesCol = Param(
        Params._dummy(), "quantilesCol",
        ("quantiles column name. This column will output quantiles of "
         "corresponding quantileProbabilities if it is set."),
        typeConverter=TypeConverters.toString)

    @since("1.6.0")
    def getCensorCol(self):
        """
        Return the censor column name (explicitly set or default).
        """
        param = self.censorCol
        return self.getOrDefault(param)

    @since("1.6.0")
    def getQuantileProbabilities(self):
        """
        Return the quantile probabilities (explicitly set or default).
        """
        param = self.quantileProbabilities
        return self.getOrDefault(param)

    @since("1.6.0")
    def getQuantilesCol(self):
        """
        Return the quantiles output column name (explicitly set or default).
        """
        param = self.quantilesCol
        return self.getOrDefault(param)
1618
1619
@inherit_doc
class AFTSurvivalRegression(JavaRegressor, _AFTSurvivalRegressionParams,
                            JavaMLWritable, JavaMLReadable):
    """
    Accelerated Failure Time (AFT) Model Survival Regression

    Fit a parametric AFT survival regression model based on the Weibull distribution
    of the survival time.

    .. seealso:: `AFT Model <https://en.wikipedia.org/wiki/Accelerated_failure_time_model>`_

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0), 1.0),
    ...     (1e-40, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
    >>> aftsr = AFTSurvivalRegression()
    >>> aftsr.setMaxIter(10)
    AFTSurvivalRegression...
    >>> aftsr.getMaxIter()
    10
    >>> aftsr.clear(aftsr.maxIter)
    >>> model = aftsr.fit(df)
    >>> model.setFeaturesCol("features")
    AFTSurvivalRegressionModel...
    >>> model.predict(Vectors.dense(6.3))
    1.0
    >>> model.predictQuantiles(Vectors.dense(6.3))
    DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052])
    >>> model.transform(df).show()
    +-------+---------+------+----------+
    |  label| features|censor|prediction|
    +-------+---------+------+----------+
    |    1.0|    [1.0]|   1.0|       1.0|
    |1.0E-40|(1,[],[])|   0.0|       1.0|
    +-------+---------+------+----------+
    ...
    >>> aftsr_path = temp_path + "/aftsr"
    >>> aftsr.save(aftsr_path)
    >>> aftsr2 = AFTSurvivalRegression.load(aftsr_path)
    >>> aftsr2.getMaxIter()
    100
    >>> model_path = temp_path + "/aftsr_model"
    >>> model.save(model_path)
    >>> model2 = AFTSurvivalRegressionModel.load(model_path)
    >>> model.coefficients == model2.coefficients
    True
    >>> model.intercept == model2.intercept
    True
    >>> model.scale == model2.scale
    True

    .. versionadded:: 1.6.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                 quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
                 quantilesCol=None, aggregationDepth=2):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                 quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
                 quantilesCol=None, aggregationDepth=2)
        """
        super(AFTSurvivalRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid)
        self._setDefault(censorCol="censor",
                         quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
                         maxIter=100, tol=1E-6)
        # Forward the keyword arguments captured for this call to setParams.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.6.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                  quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
                  quantilesCol=None, aggregationDepth=2):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
                  quantilesCol=None, aggregationDepth=2)
        Sets params for AFT survival regression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return AFTSurvivalRegressionModel(java_model)

    @since("1.6.0")
    def setCensorCol(self, value):
        """
        Sets the value of :py:attr:`censorCol`.
        """
        return self._set(censorCol=value)

    @since("1.6.0")
    def setQuantileProbabilities(self, value):
        """
        Sets the value of :py:attr:`quantileProbabilities`.
        """
        return self._set(quantileProbabilities=value)

    @since("1.6.0")
    def setQuantilesCol(self, value):
        """
        Sets the value of :py:attr:`quantilesCol`.
        """
        return self._set(quantilesCol=value)

    @since("1.6.0")
    def setMaxIter(self, value):
        """
        Sets the value of :py:attr:`maxIter`.
        """
        return self._set(maxIter=value)

    @since("1.6.0")
    def setTol(self, value):
        """
        Sets the value of :py:attr:`tol`.
        """
        return self._set(tol=value)

    @since("1.6.0")
    def setFitIntercept(self, value):
        """
        Sets the value of :py:attr:`fitIntercept`.
        """
        return self._set(fitIntercept=value)

    @since("2.1.0")
    def setAggregationDepth(self, value):
        """
        Sets the value of :py:attr:`aggregationDepth`.
        """
        return self._set(aggregationDepth=value)
1760
1761
class AFTSurvivalRegressionModel(JavaRegressionModel, _AFTSurvivalRegressionParams,
                                 JavaMLWritable, JavaMLReadable):
    """
    Survival regression model produced by fitting an :class:`AFTSurvivalRegression`.

    .. versionadded:: 1.6.0
    """

    @since("3.0.0")
    def setQuantileProbabilities(self, value):
        """
        Assign the quantile probabilities used by :py:attr:`quantileProbabilities`.
        """
        updates = {"quantileProbabilities": value}
        return self._set(**updates)

    @since("3.0.0")
    def setQuantilesCol(self, value):
        """
        Assign the output column name for :py:attr:`quantilesCol`.
        """
        updates = {"quantilesCol": value}
        return self._set(**updates)

    @property
    @since("2.0.0")
    def coefficients(self):
        """
        Coefficients of the fitted model.
        """
        java_attr = "coefficients"
        return self._call_java(java_attr)

    @property
    @since("1.6.0")
    def intercept(self):
        """
        Intercept of the fitted model.
        """
        java_attr = "intercept"
        return self._call_java(java_attr)

    @property
    @since("1.6.0")
    def scale(self):
        """
        Scale parameter of the fitted model.
        """
        java_attr = "scale"
        return self._call_java(java_attr)

    @since("2.0.0")
    def predictQuantiles(self, features):
        """
        Predict the quantiles for a single feature vector (delegated to the JVM model).
        """
        java_method = "predictQuantiles"
        return self._call_java(java_method, features)
1814
1815
class _GeneralizedLinearRegressionParams(_JavaPredictorParams, HasFitIntercept, HasMaxIter,
                                         HasTol, HasRegParam, HasWeightCol, HasSolver,
                                         HasAggregationDepth):
    """
    Shared params for :py:class:`GeneralizedLinearRegression` and
    :py:class:`GeneralizedLinearRegressionModel`.

    .. versionadded:: 3.0.0
    """

    family = Param(
        Params._dummy(), "family",
        ("The name of family which is a description of "
         "the error distribution to be used in the model. Supported options: "
         "gaussian (default), binomial, poisson, gamma and tweedie."),
        typeConverter=TypeConverters.toString)
    link = Param(
        Params._dummy(), "link",
        ("The name of link function which provides the "
         "relationship between the linear predictor and the mean of the distribution "
         "function. Supported options: identity, log, inverse, logit, probit, cloglog "
         "and sqrt."),
        typeConverter=TypeConverters.toString)
    linkPredictionCol = Param(
        Params._dummy(), "linkPredictionCol",
        ("link prediction (linear "
         "predictor) column name"),
        typeConverter=TypeConverters.toString)
    variancePower = Param(
        Params._dummy(), "variancePower",
        ("The power in the variance function "
         "of the Tweedie distribution which characterizes the relationship "
         "between the variance and mean of the distribution. Only applicable "
         "for the Tweedie family. Supported values: 0 and [1, Inf)."),
        typeConverter=TypeConverters.toFloat)
    linkPower = Param(
        Params._dummy(), "linkPower",
        ("The index in the power link function. "
         "Only applicable to the Tweedie family."),
        typeConverter=TypeConverters.toFloat)
    # NOTE(review): HasSolver (a base class) presumably declares a generic ``solver`` param;
    # this class-level definition shadows it with a GLR-specific description.
    solver = Param(
        Params._dummy(), "solver",
        ("The solver algorithm for optimization. Supported "
         "options: irls."),
        typeConverter=TypeConverters.toString)
    offsetCol = Param(
        Params._dummy(), "offsetCol",
        ("The offset column name. If this is not set "
         "or empty, we treat all instance offsets as 0.0"),
        typeConverter=TypeConverters.toString)

    @since("2.0.0")
    def getFamily(self):
        """
        Return the error-distribution family name (explicitly set or default).
        """
        param = self.family
        return self.getOrDefault(param)

    @since("2.0.0")
    def getLinkPredictionCol(self):
        """
        Return the link-prediction column name (explicitly set or default).
        """
        param = self.linkPredictionCol
        return self.getOrDefault(param)

    @since("2.0.0")
    def getLink(self):
        """
        Return the link function name (explicitly set or default).
        """
        param = self.link
        return self.getOrDefault(param)

    @since("2.2.0")
    def getVariancePower(self):
        """
        Return the Tweedie variance power (explicitly set or default).
        """
        param = self.variancePower
        return self.getOrDefault(param)

    @since("2.2.0")
    def getLinkPower(self):
        """
        Return the power-link index (explicitly set or default).
        """
        param = self.linkPower
        return self.getOrDefault(param)

    @since("2.3.0")
    def getOffsetCol(self):
        """
        Return the offset column name (explicitly set or default).
        """
        param = self.offsetCol
        return self.getOrDefault(param)
1891
1892
1893 @inherit_doc
1894 class GeneralizedLinearRegression(JavaRegressor, _GeneralizedLinearRegressionParams,
1895 JavaMLWritable, JavaMLReadable):
1896 """
1897 Generalized Linear Regression.
1898
1899 Fit a Generalized Linear Model specified by giving a symbolic description of the linear
1900 predictor (link function) and a description of the error distribution (family). It supports
1901 "gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for
1902 each family is listed below. The first link function of each family is the default one.
1903
1904 * "gaussian" -> "identity", "log", "inverse"
1905
1906 * "binomial" -> "logit", "probit", "cloglog"
1907
1908 * "poisson" -> "log", "identity", "sqrt"
1909
1910 * "gamma" -> "inverse", "identity", "log"
1911
1912 * "tweedie" -> power link function specified through "linkPower". \
1913 The default link power in the tweedie family is 1 - variancePower.
1914
1915 .. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
1916
1917 >>> from pyspark.ml.linalg import Vectors
1918 >>> df = spark.createDataFrame([
1919 ... (1.0, Vectors.dense(0.0, 0.0)),
1920 ... (1.0, Vectors.dense(1.0, 2.0)),
1921 ... (2.0, Vectors.dense(0.0, 0.0)),
1922 ... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
1923 >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
1924 >>> glr.setRegParam(0.1)
1925 GeneralizedLinearRegression...
1926 >>> glr.getRegParam()
1927 0.1
1928 >>> glr.clear(glr.regParam)
1929 >>> glr.setMaxIter(10)
1930 GeneralizedLinearRegression...
1931 >>> glr.getMaxIter()
1932 10
1933 >>> glr.clear(glr.maxIter)
1934 >>> model = glr.fit(df)
1935 >>> model.setFeaturesCol("features")
1936 GeneralizedLinearRegressionModel...
1937 >>> model.getMaxIter()
1938 25
1939 >>> model.getAggregationDepth()
1940 2
1941 >>> transformed = model.transform(df)
1942 >>> abs(transformed.head().prediction - 1.5) < 0.001
1943 True
1944 >>> abs(transformed.head().p - 1.5) < 0.001
1945 True
1946 >>> model.coefficients
1947 DenseVector([1.5..., -1.0...])
1948 >>> model.numFeatures
1949 2
1950 >>> abs(model.intercept - 1.5) < 0.001
1951 True
1952 >>> glr_path = temp_path + "/glr"
1953 >>> glr.save(glr_path)
1954 >>> glr2 = GeneralizedLinearRegression.load(glr_path)
1955 >>> glr.getFamily() == glr2.getFamily()
1956 True
1957 >>> model_path = temp_path + "/glr_model"
1958 >>> model.save(model_path)
1959 >>> model2 = GeneralizedLinearRegressionModel.load(model_path)
1960 >>> model.intercept == model2.intercept
1961 True
1962 >>> model.coefficients[0] == model2.coefficients[0]
1963 True
1964
1965 .. versionadded:: 2.0.0
1966 """
1967
1968 @keyword_only
1969 def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
1970 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
1971 regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
1972 variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2):
1973 """
1974 __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
1975 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
1976 regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
1977 variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2)
1978 """
1979 super(GeneralizedLinearRegression, self).__init__()
1980 self._java_obj = self._new_java_obj(
1981 "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
1982 self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
1983 variancePower=0.0, aggregationDepth=2)
1984 kwargs = self._input_kwargs
1985
1986 self.setParams(**kwargs)
1987
1988 @keyword_only
1989 @since("2.0.0")
1990 def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
1991 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
1992 regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
1993 variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2):
1994 """
1995 setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
1996 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
1997 regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
1998 variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2)
1999 Sets params for generalized linear regression.
2000 """
2001 kwargs = self._input_kwargs
2002 return self._set(**kwargs)
2003
2004 def _create_model(self, java_model):
2005 return GeneralizedLinearRegressionModel(java_model)
2006
2007 @since("2.0.0")
2008 def setFamily(self, value):
2009 """
2010 Sets the value of :py:attr:`family`.
2011 """
2012 return self._set(family=value)
2013
2014 @since("2.0.0")
2015 def setLinkPredictionCol(self, value):
2016 """
2017 Sets the value of :py:attr:`linkPredictionCol`.
2018 """
2019 return self._set(linkPredictionCol=value)
2020
2021 @since("2.0.0")
2022 def setLink(self, value):
2023 """
2024 Sets the value of :py:attr:`link`.
2025 """
2026 return self._set(link=value)
2027
2028 @since("2.2.0")
2029 def setVariancePower(self, value):
2030 """
2031 Sets the value of :py:attr:`variancePower`.
2032 """
2033 return self._set(variancePower=value)
2034
2035 @since("2.2.0")
2036 def setLinkPower(self, value):
2037 """
2038 Sets the value of :py:attr:`linkPower`.
2039 """
2040 return self._set(linkPower=value)
2041
2042 @since("2.3.0")
2043 def setOffsetCol(self, value):
2044 """
2045 Sets the value of :py:attr:`offsetCol`.
2046 """
2047 return self._set(offsetCol=value)
2048
2049 @since("2.0.0")
2050 def setMaxIter(self, value):
2051 """
2052 Sets the value of :py:attr:`maxIter`.
2053 """
2054 return self._set(maxIter=value)
2055
2056 @since("2.0.0")
2057 def setRegParam(self, value):
2058 """
2059 Sets the value of :py:attr:`regParam`.
2060 """
2061 return self._set(regParam=value)
2062
2063 @since("2.0.0")
2064 def setTol(self, value):
2065 """
2066 Sets the value of :py:attr:`tol`.
2067 """
2068 return self._set(tol=value)
2069
2070 @since("2.0.0")
2071 def setFitIntercept(self, value):
2072 """
2073 Sets the value of :py:attr:`fitIntercept`.
2074 """
2075 return self._set(fitIntercept=value)
2076
2077 @since("2.0.0")
2078 def setWeightCol(self, value):
2079 """
2080 Sets the value of :py:attr:`weightCol`.
2081 """
2082 return self._set(weightCol=value)
2083
2084 @since("2.0.0")
2085 def setSolver(self, value):
2086 """
2087 Sets the value of :py:attr:`solver`.
2088 """
2089 return self._set(solver=value)
2090
2091 @since("3.0.0")
2092 def setAggregationDepth(self, value):
2093 """
2094 Sets the value of :py:attr:`aggregationDepth`.
2095 """
2096 return self._set(aggregationDepth=value)
2097
2098
class GeneralizedLinearRegressionModel(JavaRegressionModel, _GeneralizedLinearRegressionParams,
                                       JavaMLWritable, JavaMLReadable, HasTrainingSummary):
    """
    Model produced by fitting a :class:`GeneralizedLinearRegression` estimator.

    .. versionadded:: 2.0.0
    """

    @since("3.0.0")
    def setLinkPredictionCol(self, value):
        """Set :py:attr:`linkPredictionCol` on this model and return the model."""
        return self._set(linkPredictionCol=value)

    @property
    @since("2.0.0")
    def coefficients(self):
        """
        Coefficients of the fitted model (one per feature).
        """
        return self._call_java("coefficients")

    @property
    @since("2.0.0")
    def intercept(self):
        """
        Intercept of the fitted model.
        """
        return self._call_java("intercept")

    @property
    @since("2.0.0")
    def summary(self):
        """
        Training-set summary of the model (e.g. residuals, deviance, pValues),
        wrapped in a :class:`GeneralizedLinearRegressionTrainingSummary`.

        Raises a RuntimeError when no training summary is available
        (i.e. when ``hasSummary`` is False).
        """
        if not self.hasSummary:
            raise RuntimeError("No training summary available for this %s" %
                               self.__class__.__name__)
        java_summary = super(GeneralizedLinearRegressionModel, self).summary
        return GeneralizedLinearRegressionTrainingSummary(java_summary)

    @since("2.0.0")
    def evaluate(self, dataset):
        """
        Evaluate this model on a test dataset.

        :param dataset:
          Test dataset to evaluate the model on; must be an instance of
          :py:class:`pyspark.sql.DataFrame`, otherwise ValueError is raised.
        :return: a :class:`GeneralizedLinearRegressionSummary` for ``dataset``.
        """
        if not isinstance(dataset, DataFrame):
            raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
        return GeneralizedLinearRegressionSummary(self._call_java("evaluate", dataset))
2158
2159
class GeneralizedLinearRegressionSummary(JavaWrapper):
    """
    Results of evaluating a generalized linear regression model on a dataset.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def predictions(self):
        """
        DataFrame of predictions produced by the model's `transform` method.
        """
        return self._call_java("predictions")

    @property
    @since("2.0.0")
    def predictionCol(self):
        """
        Name of the column in :py:attr:`predictions` holding each instance's predicted value.
        If the original model's `predictionCol` is not set, a new column name is used.
        """
        return self._call_java("predictionCol")

    @property
    @since("2.2.0")
    def numInstances(self):
        """
        Count of instances in the :py:attr:`predictions` DataFrame.
        """
        return self._call_java("numInstances")

    @property
    @since("2.0.0")
    def rank(self):
        """
        Numeric rank of the fitted linear model.
        """
        return self._call_java("rank")

    @property
    @since("2.0.0")
    def degreesOfFreedom(self):
        """
        Degrees of freedom.
        """
        return self._call_java("degreesOfFreedom")

    @property
    @since("2.0.0")
    def residualDegreeOfFreedom(self):
        """
        Residual degrees of freedom of the fitted model.
        """
        return self._call_java("residualDegreeOfFreedom")

    @property
    @since("2.0.0")
    def residualDegreeOfFreedomNull(self):
        """
        Residual degrees of freedom of the null model.
        """
        return self._call_java("residualDegreeOfFreedomNull")

    @since("2.0.0")
    def residuals(self, residualsType="deviance"):
        """
        Residuals of the fitted model, of the requested type.

        :param residualsType: Which kind of residuals to return.
            Supported options: deviance (default), pearson, working, and response.
        """
        return self._call_java("residuals", residualsType)

    @property
    @since("2.0.0")
    def nullDeviance(self):
        """
        Deviance of the null model.
        """
        return self._call_java("nullDeviance")

    @property
    @since("2.0.0")
    def deviance(self):
        """
        Deviance of the fitted model.
        """
        return self._call_java("deviance")

    @property
    @since("2.0.0")
    def dispersion(self):
        """
        Dispersion of the fitted model.
        For the "binomial" and "poisson" families it is taken as 1.0; for other families
        it is estimated as the residual Pearson's Chi-Squared statistic (the sum of the
        squared Pearson residuals) divided by the residual degrees of freedom.
        """
        return self._call_java("dispersion")

    @property
    @since("2.0.0")
    def aic(self):
        """
        Akaike's "An Information Criterion" (AIC) of the fitted model.
        """
        return self._call_java("aic")
2268
2269
@inherit_doc
class GeneralizedLinearRegressionTrainingSummary(GeneralizedLinearRegressionSummary):
    """
    Training-time results of a generalized linear regression.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def numIterations(self):
        """
        How many iterations training ran for.
        """
        return self._call_java("numIterations")

    @property
    @since("2.0.0")
    def solver(self):
        """
        Numeric solver that was used for training.
        """
        return self._call_java("solver")

    @property
    @since("2.0.0")
    def coefficientStandardErrors(self):
        """
        Standard errors of the estimated coefficients and intercept.

        When :py:attr:`GeneralizedLinearRegression.fitIntercept` is True,
        the final element corresponds to the intercept.
        """
        return self._call_java("coefficientStandardErrors")

    @property
    @since("2.0.0")
    def tValues(self):
        """
        T-statistics of the estimated coefficients and intercept.

        When :py:attr:`GeneralizedLinearRegression.fitIntercept` is True,
        the final element corresponds to the intercept.
        """
        return self._call_java("tValues")

    @property
    @since("2.0.0")
    def pValues(self):
        """
        Two-sided p-values of the estimated coefficients and intercept.

        When :py:attr:`GeneralizedLinearRegression.fitIntercept` is True,
        the final element corresponds to the intercept.
        """
        return self._call_java("pValues")

    def __repr__(self):
        # Delegate to the JVM summary's own string representation.
        return self._call_java("toString")
2329
2330
class _FactorizationMachinesParams(_JavaPredictorParams, HasMaxIter, HasStepSize, HasTol,
                                   HasSolver, HasSeed, HasFitIntercept, HasRegParam):
    """
    Shared params of :py:class:`FMRegressor`, :py:class:`FMRegressionModel`,
    :py:class:`FMClassifier` and :py:class:`FMClassifierModel`.

    .. versionadded:: 3.0.0
    """

    factorSize = Param(Params._dummy(), "factorSize",
                       "Dimensionality of the factor vectors, "
                       "which are used to get pairwise interactions between variables",
                       typeConverter=TypeConverters.toInt)

    fitLinear = Param(Params._dummy(), "fitLinear",
                      "whether to fit linear term (aka 1-way term)",
                      typeConverter=TypeConverters.toBoolean)

    miniBatchFraction = Param(Params._dummy(), "miniBatchFraction",
                              "fraction of the input data "
                              "set that should be used for one iteration of gradient descent",
                              typeConverter=TypeConverters.toFloat)

    initStd = Param(Params._dummy(), "initStd",
                    "standard deviation of initial coefficients",
                    typeConverter=TypeConverters.toFloat)

    # Redeclared here so its description lists the FM-specific solvers.
    solver = Param(Params._dummy(), "solver",
                   "The solver algorithm for optimization. Supported "
                   "options: gd, adamW. (Default adamW)",
                   typeConverter=TypeConverters.toString)

    @since("3.0.0")
    def getFactorSize(self):
        """Return the value of factorSize, falling back to its default."""
        return self.getOrDefault(self.factorSize)

    @since("3.0.0")
    def getFitLinear(self):
        """Return the value of fitLinear, falling back to its default."""
        return self.getOrDefault(self.fitLinear)

    @since("3.0.0")
    def getMiniBatchFraction(self):
        """Return the value of miniBatchFraction, falling back to its default."""
        return self.getOrDefault(self.miniBatchFraction)

    @since("3.0.0")
    def getInitStd(self):
        """Return the value of initStd, falling back to its default."""
        return self.getOrDefault(self.initStd)
2384
2385
@inherit_doc
class FMRegressor(JavaRegressor, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable):
    """
    Factorization Machines learning algorithm for regression.

    Supported solvers:

    * gd (normal mini-batch gradient descent)
    * adamW (default)

    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.regression import FMRegressor
    >>> df = spark.createDataFrame([
    ...     (2.0, Vectors.dense(2.0)),
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>>
    >>> fm = FMRegressor(factorSize=2)
    >>> fm.setSeed(16)
    FMRegressor...
    >>> model = fm.fit(df)
    >>> model.getMaxIter()
    100
    >>> test0 = spark.createDataFrame([
    ...     (Vectors.dense(-2.0),),
    ...     (Vectors.dense(0.5),),
    ...     (Vectors.dense(1.0),),
    ...     (Vectors.dense(4.0),)], ["features"])
    >>> model.transform(test0).show(10, False)
    +--------+-------------------+
    |features|prediction         |
    +--------+-------------------+
    |[-2.0]  |-1.9989237712341565|
    |[0.5]   |0.4956682219523814 |
    |[1.0]   |0.994586620589689  |
    |[4.0]   |3.9880970124135344 |
    +--------+-------------------+
    ...
    >>> model.intercept
    -0.0032501766849261557
    >>> model.linear
    DenseVector([0.9978])
    >>> model.factors
    DenseMatrix(1, 2, [0.0173, 0.0021], 1)

    .. versionadded:: 3.0.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
                 miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
                 tol=1e-6, solver="adamW", seed=None):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \
                 miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \
                 tol=1e-6, solver="adamW", seed=None)
        """
        super(FMRegressor, self).__init__()
        # Backing JVM estimator; it shares this Python object's uid.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.FMRegressor", self.uid)
        # Python-side defaults matching the signature (seed has no explicit default here).
        self._setDefault(factorSize=8, fitIntercept=True, fitLinear=True, initStd=0.01,
                         maxIter=100, miniBatchFraction=1.0, regParam=0.0, solver="adamW",
                         stepSize=1.0, tol=1e-6)
        self.setParams(**self._input_kwargs)

    @keyword_only
    @since("3.0.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
                  miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
                  tol=1e-6, solver="adamW", seed=None):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \
                  miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \
                  tol=1e-6, solver="adamW", seed=None)
        Sets Params for this FMRegressor and returns it.
        """
        return self._set(**self._input_kwargs)

    def _create_model(self, java_model):
        """Wrap ``java_model`` in a Python :class:`FMRegressionModel`."""
        return FMRegressionModel(java_model)

    @since("3.0.0")
    def setFactorSize(self, value):
        """Set :py:attr:`factorSize` and return this estimator."""
        return self._set(factorSize=value)

    @since("3.0.0")
    def setFitLinear(self, value):
        """Set :py:attr:`fitLinear` and return this estimator."""
        return self._set(fitLinear=value)

    @since("3.0.0")
    def setMiniBatchFraction(self, value):
        """Set :py:attr:`miniBatchFraction` and return this estimator."""
        return self._set(miniBatchFraction=value)

    @since("3.0.0")
    def setInitStd(self, value):
        """Set :py:attr:`initStd` and return this estimator."""
        return self._set(initStd=value)

    @since("3.0.0")
    def setMaxIter(self, value):
        """Set :py:attr:`maxIter` and return this estimator."""
        return self._set(maxIter=value)

    @since("3.0.0")
    def setStepSize(self, value):
        """Set :py:attr:`stepSize` and return this estimator."""
        return self._set(stepSize=value)

    @since("3.0.0")
    def setTol(self, value):
        """Set :py:attr:`tol` and return this estimator."""
        return self._set(tol=value)

    @since("3.0.0")
    def setSolver(self, value):
        """Set :py:attr:`solver` and return this estimator."""
        return self._set(solver=value)

    @since("3.0.0")
    def setSeed(self, value):
        """Set :py:attr:`seed` and return this estimator."""
        return self._set(seed=value)

    @since("3.0.0")
    def setFitIntercept(self, value):
        """Set :py:attr:`fitIntercept` and return this estimator."""
        return self._set(fitIntercept=value)

    @since("3.0.0")
    def setRegParam(self, value):
        """Set :py:attr:`regParam` and return this estimator."""
        return self._set(regParam=value)
2549
2550
class FMRegressionModel(JavaRegressionModel, _FactorizationMachinesParams, JavaMLWritable,
                        JavaMLReadable):
    """
    Model produced by fitting an :class:`FMRegressor`.

    .. versionadded:: 3.0.0
    """

    @property
    @since("3.0.0")
    def intercept(self):
        """
        Intercept (bias) term of the fitted model.
        """
        return self._call_java("intercept")

    @property
    @since("3.0.0")
    def linear(self):
        """
        Linear (1-way) term of the fitted model.
        """
        return self._call_java("linear")

    @property
    @since("3.0.0")
    def factors(self):
        """
        Factor term of the fitted model.
        """
        return self._call_java("factors")
2582
2583
if __name__ == "__main__":
    # Doctest driver: runs this module's doctests against a local SparkSession.
    import doctest
    import tempfile
    from shutil import rmtree

    import pyspark.ml.regression
    from pyspark.sql import SparkSession

    globs = pyspark.ml.regression.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[2]")\
        .appName("ml.regression tests")\
        .getOrCreate()
    sc = spark.sparkContext
    globs['sc'] = sc
    globs['spark'] = spark
    # Scratch directory used by the save/load doctests.
    temp_path = tempfile.mkdtemp()
    globs['temp_path'] = temp_path
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Stop the session and remove the scratch directory even if testmod
        # raises; previously spark.stop() was skipped on an exception.
        spark.stop()
        try:
            rmtree(temp_path)
        except OSError:
            pass
    if failure_count:
        sys.exit(-1)