#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

from pyspark import since, keyword_only
from pyspark.ml.param.shared import *
from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
    _TreeEnsembleModel, _TreeEnsembleParams, _RandomForestParams, _GBTParams, \
    _HasVarianceImpurity, _TreeRegressorParams
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
    JavaPredictor, JavaPredictionModel, _JavaPredictorParams, JavaWrapper
from pyspark.ml.common import inherit_doc
from pyspark.sql import DataFrame


__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
           'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
           'GBTRegressor', 'GBTRegressionModel',
           'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel',
           'GeneralizedLinearRegressionSummary', 'GeneralizedLinearRegressionTrainingSummary',
           'IsotonicRegression', 'IsotonicRegressionModel',
           'LinearRegression', 'LinearRegressionModel',
           'LinearRegressionSummary', 'LinearRegressionTrainingSummary',
           'RandomForestRegressor', 'RandomForestRegressionModel',
           'FMRegressor', 'FMRegressionModel']


class JavaRegressor(JavaPredictor, _JavaPredictorParams):
    """
    Java Regressor for regression tasks.

    .. versionadded:: 3.0.0
    """
    pass


class JavaRegressionModel(JavaPredictionModel, _JavaPredictorParams):
    """
    Java Model produced by a ``JavaRegressor``.
    To be mixed in with :class:`pyspark.ml.JavaModel`.

    .. versionadded:: 3.0.0
    """
    pass


class _LinearRegressionParams(_JavaPredictorParams, HasRegParam, HasElasticNetParam, HasMaxIter,
                              HasTol, HasFitIntercept, HasStandardization, HasWeightCol, HasSolver,
                              HasAggregationDepth, HasLoss):
    """
    Params for :py:class:`LinearRegression` and :py:class:`LinearRegressionModel`.

    .. versionadded:: 3.0.0
    """

    solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
                   "options: auto, normal, l-bfgs.", typeConverter=TypeConverters.toString)

    loss = Param(Params._dummy(), "loss", "The loss function to be optimized. Supported " +
                 "options: squaredError, huber.", typeConverter=TypeConverters.toString)

    epsilon = Param(Params._dummy(), "epsilon", "The shape parameter to control the amount of " +
                    "robustness. Must be > 1.0. Only valid when loss is huber",
                    typeConverter=TypeConverters.toFloat)

    @since("2.3.0")
    def getEpsilon(self):
        """
        Gets the value of epsilon or its default value.
        """
        return self.getOrDefault(self.epsilon)


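# A minimal sketch (illustration only, not part of the Spark API) of the shape
# `epsilon` controls: up to the scale parameter sigma that Spark estimates
# jointly from the training data, huber loss is quadratic for small residuals
# and linear for large ones, with the crossover at `epsilon`.
def _huber_loss_sketch(residual, epsilon=1.35):
    """Classic Huber function of a (scale-normalized) residual."""
    r = abs(residual)
    if r <= epsilon:
        return 0.5 * r ** 2                # squared-error regime (small residuals)
    return epsilon * (r - 0.5 * epsilon)   # absolute-error regime (large residuals)

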
@inherit_doc
class LinearRegression(JavaRegressor, _LinearRegressionParams, JavaMLWritable, JavaMLReadable):
    """
    Linear regression.

    The learning objective is to minimize the specified loss function, with regularization.
    This supports two kinds of loss:

    * squaredError (a.k.a. squared loss)
    * huber (a hybrid of squared error for relatively small errors and absolute error for \
    relatively large ones, and we estimate the scale parameter from training data)

    This supports multiple types of regularization:

    * none (a.k.a. ordinary least squares)
    * L2 (ridge regression)
    * L1 (Lasso)
    * L2 + L1 (elastic net)

    Note: Fitting with huber loss only supports none and L2 regularization.

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, 2.0, Vectors.dense(1.0)),
    ...     (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
    >>> lr = LinearRegression(regParam=0.0, solver="normal", weightCol="weight")
    >>> lr.setMaxIter(5)
    LinearRegression...
    >>> lr.getMaxIter()
    5
    >>> lr.setRegParam(0.1)
    LinearRegression...
    >>> lr.getRegParam()
    0.1
    >>> lr.setRegParam(0.0)
    LinearRegression...
    >>> model = lr.fit(df)
    >>> model.setFeaturesCol("features")
    LinearRegressionModel...
    >>> model.setPredictionCol("newPrediction")
    LinearRegressionModel...
    >>> model.getMaxIter()
    5
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> abs(model.predict(test0.head().features) - (-1.0)) < 0.001
    True
    >>> abs(model.transform(test0).head().newPrediction - (-1.0)) < 0.001
    True
    >>> abs(model.coefficients[0] - 1.0) < 0.001
    True
    >>> abs(model.intercept - 0.0) < 0.001
    True
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> abs(model.transform(test1).head().newPrediction - 1.0) < 0.001
    True
    >>> lr.setParams("vector")
    Traceback (most recent call last):
        ...
    TypeError: Method setParams forces keyword arguments.
    >>> lr_path = temp_path + "/lr"
    >>> lr.save(lr_path)
    >>> lr2 = LinearRegression.load(lr_path)
    >>> lr2.getMaxIter()
    5
    >>> model_path = temp_path + "/lr_model"
    >>> model.save(model_path)
    >>> model2 = LinearRegressionModel.load(model_path)
    >>> model.coefficients[0] == model2.coefficients[0]
    True
    >>> model.intercept == model2.intercept
    True
    >>> model.numFeatures
    1
    >>> model.write().format("pmml").save(model_path + "_2")

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                 standardization=True, solver="auto", weightCol=None, aggregationDepth=2,
                 loss="squaredError", epsilon=1.35):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                 standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \
                 loss="squaredError", epsilon=1.35)
        """
        super(LinearRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.LinearRegression", self.uid)
        self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, loss="squaredError", epsilon=1.35)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                  standardization=True, solver="auto", weightCol=None, aggregationDepth=2,
                  loss="squaredError", epsilon=1.35):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                  standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \
                  loss="squaredError", epsilon=1.35)
        Sets params for linear regression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return LinearRegressionModel(java_model)

    @since("2.3.0")
    def setEpsilon(self, value):
        """
        Sets the value of :py:attr:`epsilon`.
        """
        return self._set(epsilon=value)

    def setMaxIter(self, value):
        """
        Sets the value of :py:attr:`maxIter`.
        """
        return self._set(maxIter=value)

    def setRegParam(self, value):
        """
        Sets the value of :py:attr:`regParam`.
        """
        return self._set(regParam=value)

    def setTol(self, value):
        """
        Sets the value of :py:attr:`tol`.
        """
        return self._set(tol=value)

    def setElasticNetParam(self, value):
        """
        Sets the value of :py:attr:`elasticNetParam`.
        """
        return self._set(elasticNetParam=value)

    def setFitIntercept(self, value):
        """
        Sets the value of :py:attr:`fitIntercept`.
        """
        return self._set(fitIntercept=value)

    def setStandardization(self, value):
        """
        Sets the value of :py:attr:`standardization`.
        """
        return self._set(standardization=value)

    def setWeightCol(self, value):
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)

    def setSolver(self, value):
        """
        Sets the value of :py:attr:`solver`.
        """
        return self._set(solver=value)

    def setAggregationDepth(self, value):
        """
        Sets the value of :py:attr:`aggregationDepth`.
        """
        return self._set(aggregationDepth=value)

    def setLoss(self, value):
        """
        Sets the value of :py:attr:`loss`.
        """
        return self._set(loss=value)


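# Usage sketch for huber loss, assuming an active SparkSession bound to `spark`
# as in the doctests above. huber only supports none and L2 regularization, so
# elasticNetParam must remain at its default of 0.0:
#
#     from pyspark.ml.linalg import Vectors
#     hdf = spark.createDataFrame([(1.0, Vectors.dense(1.0)),
#                                  (0.0, Vectors.dense(0.0))], ["label", "features"])
#     hlr = LinearRegression(loss="huber", epsilon=1.35, regParam=0.1)
#     hmodel = hlr.fit(hdf)
#     hmodel.scale   # sigma, the jointly estimated scale parameter

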
class LinearRegressionModel(JavaRegressionModel, _LinearRegressionParams, GeneralJavaMLWritable,
                            JavaMLReadable, HasTrainingSummary):
    """
    Model fitted by :class:`LinearRegression`.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def coefficients(self):
        """
        Model coefficients.
        """
        return self._call_java("coefficients")

    @property
    @since("1.4.0")
    def intercept(self):
        """
        Model intercept.
        """
        return self._call_java("intercept")

    @property
    @since("2.3.0")
    def scale(self):
        r"""
        The value by which :math:`\|y - X'w\|` is scaled down when loss is "huber", otherwise 1.0.
        """
        return self._call_java("scale")

    @property
    @since("2.0.0")
    def summary(self):
        """
        Gets summary (e.g. residuals, mse, r-squared) of model on
        training set. An exception is thrown if
        `trainingSummary is None`.
        """
        if self.hasSummary:
            return LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary)
        else:
            raise RuntimeError("No training summary available for this %s" %
                               self.__class__.__name__)

    @since("2.0.0")
    def evaluate(self, dataset):
        """
        Evaluates the model on a test dataset.

        :param dataset:
          Test dataset to evaluate model on, where dataset is an
          instance of :py:class:`pyspark.sql.DataFrame`
        """
        if not isinstance(dataset, DataFrame):
            raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
        java_lr_summary = self._call_java("evaluate", dataset)
        return LinearRegressionSummary(java_lr_summary)


class LinearRegressionSummary(JavaWrapper):
    """
    Linear regression results evaluated on a dataset.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def predictions(self):
        """
        DataFrame output by the model's `transform` method.
        """
        return self._call_java("predictions")

    @property
    @since("2.0.0")
    def predictionCol(self):
        """
        Field in "predictions" which gives the predicted value of
        the label at each instance.
        """
        return self._call_java("predictionCol")

    @property
    @since("2.0.0")
    def labelCol(self):
        """
        Field in "predictions" which gives the true label of each
        instance.
        """
        return self._call_java("labelCol")

    @property
    @since("2.0.0")
    def featuresCol(self):
        """
        Field in "predictions" which gives the features of each instance
        as a vector.
        """
        return self._call_java("featuresCol")

    @property
    @since("2.0.0")
    def explainedVariance(self):
        r"""
        Returns the explained variance regression score.
        explainedVariance = :math:`1 - \frac{variance(y - \hat{y})}{variance(y)}`

        .. seealso:: `Wikipedia explain variation
            <http://en.wikipedia.org/wiki/Explained_variation>`_

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LinearRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("explainedVariance")

    @property
    @since("2.0.0")
    def meanAbsoluteError(self):
        """
        Returns the mean absolute error, which is a risk function
        corresponding to the expected value of the absolute error
        loss or l1-norm loss.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LinearRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("meanAbsoluteError")

    @property
    @since("2.0.0")
    def meanSquaredError(self):
        """
        Returns the mean squared error, which is a risk function
        corresponding to the expected value of the squared error
        loss or quadratic loss.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LinearRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("meanSquaredError")

    @property
    @since("2.0.0")
    def rootMeanSquaredError(self):
        """
        Returns the root mean squared error, which is defined as the
        square root of the mean squared error.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LinearRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("rootMeanSquaredError")

    @property
    @since("2.0.0")
    def r2(self):
        """
        Returns R^2, the coefficient of determination.

        .. seealso:: `Wikipedia coefficient of determination
            <http://en.wikipedia.org/wiki/Coefficient_of_determination>`_

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LinearRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("r2")

    @property
    @since("2.4.0")
    def r2adj(self):
        """
        Returns Adjusted R^2, the adjusted coefficient of determination.

        .. seealso:: `Wikipedia coefficient of determination, Adjusted R^2
            <https://en.wikipedia.org/wiki/Coefficient_of_determination#Adjusted_R2>`_

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LinearRegression.weightCol`. This will change in later Spark versions.
        """
        return self._call_java("r2adj")

    @property
    @since("2.0.0")
    def residuals(self):
        """
        Residuals (label - predicted value)
        """
        return self._call_java("residuals")

    @property
    @since("2.0.0")
    def numInstances(self):
        """
        Number of instances in DataFrame predictions
        """
        return self._call_java("numInstances")

    @property
    @since("2.2.0")
    def degreesOfFreedom(self):
        """
        Degrees of freedom.
        """
        return self._call_java("degreesOfFreedom")

    @property
    @since("2.0.0")
    def devianceResiduals(self):
        """
        The weighted residuals, the usual residuals rescaled by the
        square root of the instance weights.
        """
        return self._call_java("devianceResiduals")

    @property
    @since("2.0.0")
    def coefficientStandardErrors(self):
        """
        Standard error of estimated coefficients and intercept.
        This value is only available when using the "normal" solver.

        If :py:attr:`LinearRegression.fitIntercept` is set to True,
        then the last element returned corresponds to the intercept.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("coefficientStandardErrors")

    @property
    @since("2.0.0")
    def tValues(self):
        """
        T-statistic of estimated coefficients and intercept.
        This value is only available when using the "normal" solver.

        If :py:attr:`LinearRegression.fitIntercept` is set to True,
        then the last element returned corresponds to the intercept.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("tValues")

    @property
    @since("2.0.0")
    def pValues(self):
        """
        Two-sided p-value of estimated coefficients and intercept.
        This value is only available when using the "normal" solver.

        If :py:attr:`LinearRegression.fitIntercept` is set to True,
        then the last element returned corresponds to the intercept.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("pValues")


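# A plain-Python sketch of the summary formulas documented above:
# explainedVariance = 1 - variance(y - yhat) / variance(y), and
# r2 = 1 - SS_res / SS_tot (illustration only; Spark computes these JVM-side).
def _regression_metrics_sketch(y, yhat):
    mean_y = sum(y) / len(y)
    residuals = [yi - pi for yi, pi in zip(y, yhat)]
    mean_r = sum(residuals) / len(residuals)
    var_y = sum((yi - mean_y) ** 2 for yi in y) / len(y)
    var_r = sum((r - mean_r) ** 2 for r in residuals) / len(residuals)
    explained_variance = 1.0 - var_r / var_y
    r2 = 1.0 - sum(r ** 2 for r in residuals) / sum((yi - mean_y) ** 2 for yi in y)
    return explained_variance, r2

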
@inherit_doc
class LinearRegressionTrainingSummary(LinearRegressionSummary):
    """
    Linear regression training results. Currently, the training summary ignores the
    training weights except for the objective trace.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def objectiveHistory(self):
        """
        Objective function (scaled loss + regularization) at each
        iteration.
        This value is only available when using the "l-bfgs" solver.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("objectiveHistory")

    @property
    @since("2.0.0")
    def totalIterations(self):
        """
        Number of training iterations until termination.
        This value is only available when using the "l-bfgs" solver.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("totalIterations")


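# Usage sketch, assuming an active SparkSession bound to `spark` and a training
# DataFrame `df` as in the doctests above: objectiveHistory and totalIterations
# are populated only for the iterative "l-bfgs" solver, not for "normal".
#
#     lbfgs_lr = LinearRegression(solver="l-bfgs", maxIter=10)
#     m = lbfgs_lr.fit(df)
#     m.summary.objectiveHistory   # one objective value per iteration
#     m.summary.totalIterations

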
class _IsotonicRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol):
    """
    Params for :py:class:`IsotonicRegression` and :py:class:`IsotonicRegressionModel`.

    .. versionadded:: 3.0.0
    """

    isotonic = Param(
        Params._dummy(), "isotonic",
        "whether the output sequence should be isotonic/increasing (true) or " +
        "antitonic/decreasing (false).", typeConverter=TypeConverters.toBoolean)
    featureIndex = Param(
        Params._dummy(), "featureIndex",
        "The index of the feature if featuresCol is a vector column, no effect otherwise.",
        typeConverter=TypeConverters.toInt)

    def getIsotonic(self):
        """
        Gets the value of isotonic or its default value.
        """
        return self.getOrDefault(self.isotonic)

    def getFeatureIndex(self):
        """
        Gets the value of featureIndex or its default value.
        """
        return self.getOrDefault(self.featureIndex)


@inherit_doc
class IsotonicRegression(JavaEstimator, _IsotonicRegressionParams, HasWeightCol,
                         JavaMLWritable, JavaMLReadable):
    """
    Isotonic regression. Currently implemented using the parallelized pool adjacent
    violators algorithm. Only the univariate (single feature) algorithm is supported.

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> ir = IsotonicRegression()
    >>> model = ir.fit(df)
    >>> model.setFeaturesCol("features")
    IsotonicRegressionModel...
    >>> model.numFeatures
    1
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.transform(test0).head().prediction
    0.0
    >>> model.predict(test0.head().features[model.getFeatureIndex()])
    0.0
    >>> model.boundaries
    DenseVector([0.0, 1.0])
    >>> ir_path = temp_path + "/ir"
    >>> ir.save(ir_path)
    >>> ir2 = IsotonicRegression.load(ir_path)
    >>> ir2.getIsotonic()
    True
    >>> model_path = temp_path + "/ir_model"
    >>> model.save(model_path)
    >>> model2 = IsotonicRegressionModel.load(model_path)
    >>> model.boundaries == model2.boundaries
    True
    >>> model.predictions == model2.predictions
    True

    .. versionadded:: 1.6.0
    """
    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 weightCol=None, isotonic=True, featureIndex=0):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 weightCol=None, isotonic=True, featureIndex=0)
        """
        super(IsotonicRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.IsotonicRegression", self.uid)
        self._setDefault(isotonic=True, featureIndex=0)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  weightCol=None, isotonic=True, featureIndex=0):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 weightCol=None, isotonic=True, featureIndex=0)
        Sets params for IsotonicRegression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return IsotonicRegressionModel(java_model)

    def setIsotonic(self, value):
        """
        Sets the value of :py:attr:`isotonic`.
        """
        return self._set(isotonic=value)

    def setFeatureIndex(self, value):
        """
        Sets the value of :py:attr:`featureIndex`.
        """
        return self._set(featureIndex=value)

    @since("1.6.0")
    def setFeaturesCol(self, value):
        """
        Sets the value of :py:attr:`featuresCol`.
        """
        return self._set(featuresCol=value)

    @since("1.6.0")
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self._set(predictionCol=value)

    @since("1.6.0")
    def setLabelCol(self, value):
        """
        Sets the value of :py:attr:`labelCol`.
        """
        return self._set(labelCol=value)

    @since("1.6.0")
    def setWeightCol(self, value):
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)


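# A minimal sketch of the (sequential, unweighted) pool adjacent violators idea
# behind this estimator; Spark's implementation is parallelized and weighted.
# Adjacent values that violate monotonicity are pooled into blocks fitted at
# the block mean.
def _pav_sketch(ys):
    blocks = []  # list of (block_sum, block_count)
    for y in ys:
        blocks.append((y, 1))
        # merge while the last block's mean is below the previous block's mean
        while len(blocks) > 1 and blocks[-1][0] * blocks[-2][1] < blocks[-2][0] * blocks[-1][1]:
            s2, c2 = blocks.pop()
            s1, c1 = blocks.pop()
            blocks.append((s1 + s2, c1 + c2))
    fitted = []
    for s, c in blocks:
        fitted.extend([s / c] * c)
    return fitted

# _pav_sketch([1.0, 3.0, 2.0, 4.0]) -> [1.0, 2.5, 2.5, 4.0]

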
class IsotonicRegressionModel(JavaModel, _IsotonicRegressionParams, JavaMLWritable,
                              JavaMLReadable):
    """
    Model fitted by :class:`IsotonicRegression`.

    .. versionadded:: 1.6.0
    """

    @since("3.0.0")
    def setFeaturesCol(self, value):
        """
        Sets the value of :py:attr:`featuresCol`.
        """
        return self._set(featuresCol=value)

    @since("3.0.0")
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self._set(predictionCol=value)

    def setFeatureIndex(self, value):
        """
        Sets the value of :py:attr:`featureIndex`.
        """
        return self._set(featureIndex=value)

    @property
    @since("1.6.0")
    def boundaries(self):
        """
        Boundaries in increasing order for which predictions are known.
        """
        return self._call_java("boundaries")

    @property
    @since("1.6.0")
    def predictions(self):
        """
        Predictions associated with the boundaries at the same index, monotone because of isotonic
        regression.
        """
        return self._call_java("predictions")

    @property
    @since("3.0.0")
    def numFeatures(self):
        """
        Returns the number of features the model was trained on. If unknown, returns -1
        """
        return self._call_java("numFeatures")

    @since("3.0.0")
    def predict(self, value):
        """
        Predict label for the given features.
        """
        return self._call_java("predict", value)


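# A sketch of how `boundaries` and `predictions` determine a prediction,
# assuming linear interpolation between known boundaries with clamping outside
# the fitted range (illustration only):
def _isotonic_predict_sketch(boundaries, predictions, x):
    if x <= boundaries[0]:
        return predictions[0]
    if x >= boundaries[-1]:
        return predictions[-1]
    for i in range(1, len(boundaries)):
        if x <= boundaries[i]:
            b0, b1 = boundaries[i - 1], boundaries[i]
            p0, p1 = predictions[i - 1], predictions[i]
            return p0 + (p1 - p0) * (x - b0) / (b1 - b0)

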
class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, HasVarianceCol):
    """
    Params for :py:class:`DecisionTreeRegressor` and :py:class:`DecisionTreeRegressionModel`.

    .. versionadded:: 3.0.0
    """

    pass


@inherit_doc
class DecisionTreeRegressor(JavaRegressor, _DecisionTreeRegressorParams, JavaMLWritable,
                            JavaMLReadable):
    """
    `Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
    learning algorithm for regression.
    It supports both continuous and categorical features.

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> dt = DecisionTreeRegressor(maxDepth=2)
    >>> dt.setVarianceCol("variance")
    DecisionTreeRegressor...
    >>> model = dt.fit(df)
    >>> model.getVarianceCol()
    'variance'
    >>> model.setLeafCol("leafId")
    DecisionTreeRegressionModel...
    >>> model.depth
    1
    >>> model.numNodes
    3
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> model.numFeatures
    1
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.predict(test0.head().features)
    0.0
    >>> result = model.transform(test0).head()
    >>> result.prediction
    0.0
    >>> model.predictLeaf(test0.head().features)
    0.0
    >>> result.leafId
    0.0
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0
    >>> dtr_path = temp_path + "/dtr"
    >>> dt.save(dtr_path)
    >>> dt2 = DecisionTreeRegressor.load(dtr_path)
    >>> dt2.getMaxDepth()
    2
    >>> model_path = temp_path + "/dtr_model"
    >>> model.save(model_path)
    >>> model2 = DecisionTreeRegressionModel.load(model_path)
    >>> model.numNodes == model2.numNodes
    True
    >>> model.depth == model2.depth
    True
    >>> model.transform(test1).head().variance
    0.0

    >>> df3 = spark.createDataFrame([
    ...     (1.0, 0.2, Vectors.dense(1.0)),
    ...     (1.0, 0.8, Vectors.dense(1.0)),
    ...     (0.0, 1.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
    >>> dt3 = DecisionTreeRegressor(maxDepth=2, weightCol="weight", varianceCol="variance")
    >>> model3 = dt3.fit(df3)
    >>> print(model3.toDebugString)
    DecisionTreeRegressionModel...depth=1, numNodes=3...

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
                 seed=None, varianceCol=None, weightCol=None, leafCol="",
                 minWeightFractionPerNode=0.0):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                 impurity="variance", seed=None, varianceCol=None, weightCol=None, \
                 leafCol="", minWeightFractionPerNode=0.0)
        """
        super(DecisionTreeRegressor, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="variance", leafCol="", minWeightFractionPerNode=0.0)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  impurity="variance", seed=None, varianceCol=None, weightCol=None,
                  leafCol="", minWeightFractionPerNode=0.0):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                  impurity="variance", seed=None, varianceCol=None, weightCol=None, \
                  leafCol="", minWeightFractionPerNode=0.0)
        Sets params for the DecisionTreeRegressor.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return DecisionTreeRegressionModel(java_model)

    @since("1.4.0")
    def setMaxDepth(self, value):
        """
        Sets the value of :py:attr:`maxDepth`.
        """
        return self._set(maxDepth=value)

    @since("1.4.0")
    def setMaxBins(self, value):
        """
        Sets the value of :py:attr:`maxBins`.
        """
        return self._set(maxBins=value)

    @since("1.4.0")
    def setMinInstancesPerNode(self, value):
        """
        Sets the value of :py:attr:`minInstancesPerNode`.
        """
        return self._set(minInstancesPerNode=value)

    @since("3.0.0")
    def setMinWeightFractionPerNode(self, value):
        """
        Sets the value of :py:attr:`minWeightFractionPerNode`.
        """
        return self._set(minWeightFractionPerNode=value)

    @since("1.4.0")
    def setMinInfoGain(self, value):
        """
        Sets the value of :py:attr:`minInfoGain`.
        """
        return self._set(minInfoGain=value)

    @since("1.4.0")
    def setMaxMemoryInMB(self, value):
        """
        Sets the value of :py:attr:`maxMemoryInMB`.
        """
        return self._set(maxMemoryInMB=value)

    @since("1.4.0")
    def setCacheNodeIds(self, value):
        """
        Sets the value of :py:attr:`cacheNodeIds`.
        """
        return self._set(cacheNodeIds=value)

    @since("1.4.0")
    def setImpurity(self, value):
        """
        Sets the value of :py:attr:`impurity`.
        """
        return self._set(impurity=value)

    @since("1.4.0")
    def setCheckpointInterval(self, value):
        """
        Sets the value of :py:attr:`checkpointInterval`.
        """
        return self._set(checkpointInterval=value)

    def setSeed(self, value):
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)

    @since("3.0.0")
    def setWeightCol(self, value):
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)

    @since("2.0.0")
    def setVarianceCol(self, value):
        """
        Sets the value of :py:attr:`varianceCol`.
        """
        return self._set(varianceCol=value)


@inherit_doc
class DecisionTreeRegressionModel(
    JavaRegressionModel, _DecisionTreeModel, _DecisionTreeRegressorParams,
    JavaMLWritable, JavaMLReadable
):
    """
    Model fitted by :class:`DecisionTreeRegressor`.

    .. versionadded:: 1.4.0
    """

    @since("3.0.0")
    def setVarianceCol(self, value):
        """
        Sets the value of :py:attr:`varianceCol`.
        """
        return self._set(varianceCol=value)

    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        This generalizes the idea of "Gini" importance to other losses,
        following the explanation of Gini importance from "Random Forests" documentation
        by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.

        This feature importance is calculated as follows:
          - importance(feature j) = sum (over nodes which split on feature j) of the gain,
            where gain is scaled by the number of instances passing through node
          - Normalize importances for tree to sum to 1.

        .. note:: Feature importance for single decision trees can have high variance due to
              correlated predictor variables. Consider using a :py:class:`RandomForestRegressor`
              to determine feature importance instead.
        """
        return self._call_java("featureImportances")


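# A minimal sketch of the importance computation documented in
# `featureImportances` above: per-node gains, scaled by the number of instances
# passing through the node, are summed per feature and normalized to sum to 1.
# `splits` is a hypothetical list of (feature_index, gain, num_instances)
# tuples, one entry per internal node of the tree.
def _tree_importance_sketch(splits, num_features):
    importances = [0.0] * num_features
    for feature, gain, count in splits:
        importances[feature] += gain * count
    total = sum(importances)
    return [v / total for v in importances] if total > 0 else importances

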
class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams):
    """
    Params for :py:class:`RandomForestRegressor` and :py:class:`RandomForestRegressionModel`.

    .. versionadded:: 3.0.0
    """
    pass


@inherit_doc
class RandomForestRegressor(JavaRegressor, _RandomForestRegressorParams, JavaMLWritable,
                            JavaMLReadable):
    """
    `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
    learning algorithm for regression.
    It supports both continuous and categorical features.

    >>> from numpy import allclose
    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2)
    >>> rf.getMinWeightFractionPerNode()
    0.0
    >>> rf.setSeed(42)
    RandomForestRegressor...
    >>> model = rf.fit(df)
    >>> model.getBootstrap()
    True
    >>> model.getSeed()
    42
    >>> model.setLeafCol("leafId")
    RandomForestRegressionModel...
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> allclose(model.treeWeights, [1.0, 1.0])
    True
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.predict(test0.head().features)
    0.0
    >>> model.predictLeaf(test0.head().features)
    DenseVector([0.0, 0.0])
    >>> result = model.transform(test0).head()
    >>> result.prediction
    0.0
    >>> result.leafId
    DenseVector([0.0, 0.0])
    >>> model.numFeatures
    1
    >>> model.trees
    [DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
    >>> model.getNumTrees
    2
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    0.5
    >>> rfr_path = temp_path + "/rfr"
    >>> rf.save(rfr_path)
    >>> rf2 = RandomForestRegressor.load(rfr_path)
    >>> rf2.getNumTrees()
    2
    >>> model_path = temp_path + "/rfr_model"
    >>> model.save(model_path)
    >>> model2 = RandomForestRegressionModel.load(model_path)
    >>> model.featureImportances == model2.featureImportances
    True

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                 impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
                 featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0,
                 weightCol=None, bootstrap=True):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                 impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
                 featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0, \
                 weightCol=None, bootstrap=True)
        """
        super(RandomForestRegressor, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.RandomForestRegressor", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="variance", subsamplingRate=1.0, numTrees=20,
                         featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0,
                         bootstrap=True)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
                  featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0,
                  weightCol=None, bootstrap=True):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                  impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
                  featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0, \
                  weightCol=None, bootstrap=True)
        Sets params for random forest regression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return RandomForestRegressionModel(java_model)

    def setMaxDepth(self, value):
        """
        Sets the value of :py:attr:`maxDepth`.
        """
        return self._set(maxDepth=value)

    def setMaxBins(self, value):
        """
        Sets the value of :py:attr:`maxBins`.
        """
        return self._set(maxBins=value)

    def setMinInstancesPerNode(self, value):
        """
        Sets the value of :py:attr:`minInstancesPerNode`.
        """
        return self._set(minInstancesPerNode=value)

    def setMinInfoGain(self, value):
        """
        Sets the value of :py:attr:`minInfoGain`.
        """
        return self._set(minInfoGain=value)

    def setMaxMemoryInMB(self, value):
        """
        Sets the value of :py:attr:`maxMemoryInMB`.
        """
        return self._set(maxMemoryInMB=value)

    def setCacheNodeIds(self, value):
        """
        Sets the value of :py:attr:`cacheNodeIds`.
        """
        return self._set(cacheNodeIds=value)

    @since("1.4.0")
    def setImpurity(self, value):
        """
        Sets the value of :py:attr:`impurity`.
        """
        return self._set(impurity=value)

    @since("1.4.0")
    def setNumTrees(self, value):
        """
        Sets the value of :py:attr:`numTrees`.
        """
        return self._set(numTrees=value)

    @since("3.0.0")
    def setBootstrap(self, value):
        """
        Sets the value of :py:attr:`bootstrap`.
        """
        return self._set(bootstrap=value)

    @since("1.4.0")
    def setSubsamplingRate(self, value):
        """
        Sets the value of :py:attr:`subsamplingRate`.
        """
        return self._set(subsamplingRate=value)

    @since("2.4.0")
    def setFeatureSubsetStrategy(self, value):
        """
        Sets the value of :py:attr:`featureSubsetStrategy`.
        """
        return self._set(featureSubsetStrategy=value)

    def setCheckpointInterval(self, value):
        """
        Sets the value of :py:attr:`checkpointInterval`.
        """
        return self._set(checkpointInterval=value)

    def setSeed(self, value):
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)

    @since("3.0.0")
    def setWeightCol(self, value):
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)

    @since("3.0.0")
    def setMinWeightFractionPerNode(self, value):
        """
        Sets the value of :py:attr:`minWeightFractionPerNode`.
        """
        return self._set(minWeightFractionPerNode=value)


class RandomForestRegressionModel(
    JavaRegressionModel, _TreeEnsembleModel, _RandomForestRegressorParams,
    JavaMLWritable, JavaMLReadable
):
    """
    Model fitted by :class:`RandomForestRegressor`.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def trees(self):
        """Trees in this ensemble. Warning: These have null parent Estimators."""
        return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]

    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        Each feature's importance is the average of its importance across all trees in the
        ensemble. The importance vector is normalized to sum to 1. This method is suggested by
        Hastie et al. (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning,
        2nd Edition." 2001.) and follows the implementation from scikit-learn.

        .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
        """
        return self._call_java("featureImportances")


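# A minimal sketch of the ensemble importance documented in
# `featureImportances` above: average the per-tree importance vectors, then
# renormalize so the result sums to 1.
def _ensemble_importance_sketch(per_tree_importances):
    num_trees = len(per_tree_importances)
    num_features = len(per_tree_importances[0])
    avg = [sum(tree[j] for tree in per_tree_importances) / num_trees
           for j in range(num_features)]
    total = sum(avg)
    return [v / total for v in avg] if total > 0 else avg

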
class _GBTRegressorParams(_GBTParams, _TreeRegressorParams):
    """
    Params for :py:class:`GBTRegressor` and :py:class:`GBTRegressionModel`.

    .. versionadded:: 3.0.0
    """

    supportedLossTypes = ["squared", "absolute"]

    lossType = Param(Params._dummy(), "lossType",
                     "Loss function which GBT tries to minimize (case-insensitive). " +
                     "Supported options: " + ", ".join(supportedLossTypes),
                     typeConverter=TypeConverters.toString)

    @since("1.4.0")
    def getLossType(self):
        """
        Gets the value of lossType or its default value.
        """
        return self.getOrDefault(self.lossType)


@inherit_doc
class GBTRegressor(JavaRegressor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
    """
    `Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
    learning algorithm for regression.
    It supports both continuous and categorical features.

    >>> from numpy import allclose
    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> gbt = GBTRegressor(maxDepth=2, seed=42, leafCol="leafId")
    >>> gbt.setMaxIter(5)
    GBTRegressor...
    >>> gbt.setMinWeightFractionPerNode(0.049)
    GBTRegressor...
    >>> gbt.getMaxIter()
    5
    >>> print(gbt.getImpurity())
    variance
    >>> print(gbt.getFeatureSubsetStrategy())
    all
    >>> model = gbt.fit(df)
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> model.numFeatures
    1
    >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
    True
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.predict(test0.head().features)
    0.0
    >>> model.predictLeaf(test0.head().features)
    DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
    >>> result = model.transform(test0).head()
    >>> result.prediction
    0.0
    >>> result.leafId
    DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0
    >>> gbtr_path = temp_path + "/gbtr"
    >>> gbt.save(gbtr_path)
    >>> gbt2 = GBTRegressor.load(gbtr_path)
    >>> gbt2.getMaxDepth()
    2
    >>> model_path = temp_path + "/gbtr_model"
    >>> model.save(model_path)
    >>> model2 = GBTRegressionModel.load(model_path)
    >>> model.featureImportances == model2.featureImportances
    True
    >>> model.treeWeights == model2.treeWeights
    True
    >>> model.trees
    [DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
    >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
    ...              ["label", "features"])
    >>> model.evaluateEachIteration(validation, "squared")
    [0.0, 0.0, 0.0, 0.0, 0.0]
    >>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
    >>> gbt.getValidationIndicatorCol()
    'validationIndicator'
    >>> gbt.getValidationTol()
    0.01

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
                 checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
                 impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
                 validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0,
                 weightCol=None):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
                 checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
                 impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
                 validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \
                 weightCol=None)
        """
        super(GBTRegressor, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
                         checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1,
                         impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
                         leafCol="", minWeightFractionPerNode=0.0)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
                  checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
                  impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
                  validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0,
                  weightCol=None):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
                  checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
                  impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
                  validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \
                  weightCol=None)
        Sets params for Gradient Boosted Tree Regression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return GBTRegressionModel(java_model)

    @since("1.4.0")
    def setMaxDepth(self, value):
        """
        Sets the value of :py:attr:`maxDepth`.
        """
        return self._set(maxDepth=value)

    @since("1.4.0")
    def setMaxBins(self, value):
        """
        Sets the value of :py:attr:`maxBins`.
        """
        return self._set(maxBins=value)

    @since("1.4.0")
    def setMinInstancesPerNode(self, value):
        """
        Sets the value of :py:attr:`minInstancesPerNode`.
        """
        return self._set(minInstancesPerNode=value)

    @since("1.4.0")
    def setMinInfoGain(self, value):
        """
        Sets the value of :py:attr:`minInfoGain`.
        """
        return self._set(minInfoGain=value)

    @since("1.4.0")
    def setMaxMemoryInMB(self, value):
        """
        Sets the value of :py:attr:`maxMemoryInMB`.
        """
        return self._set(maxMemoryInMB=value)

    @since("1.4.0")
    def setCacheNodeIds(self, value):
        """
        Sets the value of :py:attr:`cacheNodeIds`.
        """
        return self._set(cacheNodeIds=value)

    @since("1.4.0")
    def setImpurity(self, value):
        """
        Sets the value of :py:attr:`impurity`.
        """
        return self._set(impurity=value)

    @since("1.4.0")
    def setLossType(self, value):
        """
        Sets the value of :py:attr:`lossType`.
        """
        return self._set(lossType=value)

    @since("1.4.0")
    def setSubsamplingRate(self, value):
        """
        Sets the value of :py:attr:`subsamplingRate`.
        """
        return self._set(subsamplingRate=value)

    @since("2.4.0")
    def setFeatureSubsetStrategy(self, value):
        """
        Sets the value of :py:attr:`featureSubsetStrategy`.
        """
        return self._set(featureSubsetStrategy=value)

    @since("3.0.0")
    def setValidationIndicatorCol(self, value):
        """
        Sets the value of :py:attr:`validationIndicatorCol`.
        """
        return self._set(validationIndicatorCol=value)

    @since("1.4.0")
    def setMaxIter(self, value):
        """
        Sets the value of :py:attr:`maxIter`.
        """
        return self._set(maxIter=value)

    @since("1.4.0")
    def setCheckpointInterval(self, value):
        """
        Sets the value of :py:attr:`checkpointInterval`.
        """
        return self._set(checkpointInterval=value)

    @since("1.4.0")
    def setSeed(self, value):
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)

    @since("1.4.0")
    def setStepSize(self, value):
        """
        Sets the value of :py:attr:`stepSize`.
        """
        return self._set(stepSize=value)

    @since("3.0.0")
    def setWeightCol(self, value):
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)

    @since("3.0.0")
    def setMinWeightFractionPerNode(self, value):
        """
        Sets the value of :py:attr:`minWeightFractionPerNode`.
        """
        return self._set(minWeightFractionPerNode=value)


1528 class GBTRegressionModel(
1529     JavaRegressionModel, _TreeEnsembleModel, _GBTRegressorParams,
1530     JavaMLWritable, JavaMLReadable
1531 ):
1532     """
1533     Model fitted by :class:`GBTRegressor`.
1534 
1535     .. versionadded:: 1.4.0
1536     """
1537 
1538     @property
1539     @since("2.0.0")
1540     def featureImportances(self):
1541         """
1542         Estimate of the importance of each feature.
1543 
1544         Each feature's importance is the average of its importance across all trees in the ensemble.
1545         The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
1546         (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
1547         and follows the implementation from scikit-learn.
1548 
1549         .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
1550         """
1551         return self._call_java("featureImportances")
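
         # Illustrative sketch (assumes `model` is a fitted GBTRegressionModel):
         # the returned vector has one entry per feature and is normalized to
         # sum to 1.
         #
         #   imp = model.featureImportances        # SparseVector or DenseVector
         #   abs(sum(imp.toArray()) - 1.0) < 1e-6  # True for a non-trivial fit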
1552 
1553     @property
1554     @since("2.0.0")
1555     def trees(self):
1556         """Trees in this ensemble. Warning: These have null parent Estimators."""
1557         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
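
         # Illustrative sketch (assumes `model` is a fitted GBTRegressionModel):
         # the ensemble holds one tree per boosting iteration, fewer only if
         # validation stopped training early.
         #
         #   for tree in model.trees:
         #       print(tree.depth, tree.numNodes)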
1558 
1559     @since("2.4.0")
1560     def evaluateEachIteration(self, dataset, loss):
1561         """
1562         Method to compute error or loss for every iteration of gradient boosting.
1563 
1564         :param dataset:
1565             Test dataset to evaluate model on, where dataset is an
1566             instance of :py:class:`pyspark.sql.DataFrame`
1567         :param loss:
1568             The loss function used to compute error.
1569             Supported options: squared, absolute
1570         """
1571         return self._call_java("evaluateEachIteration", dataset, loss)
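
         # Illustrative sketch (assumes `model` is a fitted GBTRegressionModel
         # and `validation` is a DataFrame with the training schema): the list
         # holds one loss value per iteration, so its minimum suggests how many
         # trees are worth keeping.
         #
         #   errors = model.evaluateEachIteration(validation, "squared")
         #   best_num_trees = errors.index(min(errors)) + 1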
1572 
1573 
1574 class _AFTSurvivalRegressionParams(_JavaPredictorParams, HasMaxIter, HasTol, HasFitIntercept,
1575                                    HasAggregationDepth):
1576     """
1577     Params for :py:class:`AFTSurvivalRegression` and :py:class:`AFTSurvivalRegressionModel`.
1578 
1579     .. versionadded:: 3.0.0
1580     """
1581 
1582     censorCol = Param(
1583         Params._dummy(), "censorCol",
1584         "censor column name. The value of this column could be 0 or 1. " +
1585         "If the value is 1, it means the event has occurred, i.e. " +
1586         "uncensored; otherwise censored.", typeConverter=TypeConverters.toString)
1587     quantileProbabilities = Param(
1588         Params._dummy(), "quantileProbabilities",
1589         "quantile probabilities array. Values of the quantile probabilities array " +
1590         "should be in the range (0, 1) and the array should be non-empty.",
1591         typeConverter=TypeConverters.toListFloat)
1592     quantilesCol = Param(
1593         Params._dummy(), "quantilesCol",
1594         "quantiles column name. This column will output quantiles of " +
1595         "corresponding quantileProbabilities if it is set.",
1596         typeConverter=TypeConverters.toString)
1597 
1598     @since("1.6.0")
1599     def getCensorCol(self):
1600         """
1601         Gets the value of censorCol or its default value.
1602         """
1603         return self.getOrDefault(self.censorCol)
1604 
1605     @since("1.6.0")
1606     def getQuantileProbabilities(self):
1607         """
1608         Gets the value of quantileProbabilities or its default value.
1609         """
1610         return self.getOrDefault(self.quantileProbabilities)
1611 
1612     @since("1.6.0")
1613     def getQuantilesCol(self):
1614         """
1615         Gets the value of quantilesCol or its default value.
1616         """
1617         return self.getOrDefault(self.quantilesCol)
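
         # Illustrative sketch of the censoring convention (assumes an active
         # SparkSession `spark`): censor == 1.0 marks an observed event, while
         # censor == 0.0 marks an observation whose true survival time is only
         # known to exceed the label.
         #
         #   from pyspark.ml.linalg import Vectors
         #   df = spark.createDataFrame(
         #       [(2.0, Vectors.dense(1.0), 1.0),   # event observed at t = 2.0
         #        (5.0, Vectors.dense(0.5), 0.0)],  # censored at t = 5.0
         #       ["label", "features", "censor"])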
1618 
1619 
1620 @inherit_doc
1621 class AFTSurvivalRegression(JavaRegressor, _AFTSurvivalRegressionParams,
1622                             JavaMLWritable, JavaMLReadable):
1623     """
1624     Accelerated Failure Time (AFT) Survival Regression.
1625 
1626     Fit a parametric AFT survival regression model based on the Weibull distribution
1627     of the survival time.
1628 
1629     .. seealso:: `AFT Model <https://en.wikipedia.org/wiki/Accelerated_failure_time_model>`_
1630 
1631     >>> from pyspark.ml.linalg import Vectors
1632     >>> df = spark.createDataFrame([
1633     ...     (1.0, Vectors.dense(1.0), 1.0),
1634     ...     (1e-40, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
1635     >>> aftsr = AFTSurvivalRegression()
1636     >>> aftsr.setMaxIter(10)
1637     AFTSurvivalRegression...
1638     >>> aftsr.getMaxIter()
1639     10
1640     >>> aftsr.clear(aftsr.maxIter)
1641     >>> model = aftsr.fit(df)
1642     >>> model.setFeaturesCol("features")
1643     AFTSurvivalRegressionModel...
1644     >>> model.predict(Vectors.dense(6.3))
1645     1.0
1646     >>> model.predictQuantiles(Vectors.dense(6.3))
1647     DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052])
1648     >>> model.transform(df).show()
1649     +-------+---------+------+----------+
1650     |  label| features|censor|prediction|
1651     +-------+---------+------+----------+
1652     |    1.0|    [1.0]|   1.0|       1.0|
1653     |1.0E-40|(1,[],[])|   0.0|       1.0|
1654     +-------+---------+------+----------+
1655     ...
1656     >>> aftsr_path = temp_path + "/aftsr"
1657     >>> aftsr.save(aftsr_path)
1658     >>> aftsr2 = AFTSurvivalRegression.load(aftsr_path)
1659     >>> aftsr2.getMaxIter()
1660     100
1661     >>> model_path = temp_path + "/aftsr_model"
1662     >>> model.save(model_path)
1663     >>> model2 = AFTSurvivalRegressionModel.load(model_path)
1664     >>> model.coefficients == model2.coefficients
1665     True
1666     >>> model.intercept == model2.intercept
1667     True
1668     >>> model.scale == model2.scale
1669     True
1670 
1671     .. versionadded:: 1.6.0
1672     """
1673 
1674     @keyword_only
1675     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
1676                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
1677                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
1678                  quantilesCol=None, aggregationDepth=2):
1679         """
1680         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
1681                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
1682                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
1683                  quantilesCol=None, aggregationDepth=2)
1684         """
1685         super(AFTSurvivalRegression, self).__init__()
1686         self._java_obj = self._new_java_obj(
1687             "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid)
1688         self._setDefault(censorCol="censor",
1689                          quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
1690                          maxIter=100, tol=1E-6)
1691         kwargs = self._input_kwargs
1692         self.setParams(**kwargs)
1693 
1694     @keyword_only
1695     @since("1.6.0")
1696     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
1697                   fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
1698                   quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
1699                   quantilesCol=None, aggregationDepth=2):
1700         """
1701         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
1702                   fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
1703                   quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
1704                   quantilesCol=None, aggregationDepth=2)
             Sets params for accelerated failure time (AFT) survival regression.
1705         """
1706         kwargs = self._input_kwargs
1707         return self._set(**kwargs)
1708 
1709     def _create_model(self, java_model):
1710         return AFTSurvivalRegressionModel(java_model)
1711 
1712     @since("1.6.0")
1713     def setCensorCol(self, value):
1714         """
1715         Sets the value of :py:attr:`censorCol`.
1716         """
1717         return self._set(censorCol=value)
1718 
1719     @since("1.6.0")
1720     def setQuantileProbabilities(self, value):
1721         """
1722         Sets the value of :py:attr:`quantileProbabilities`.
1723         """
1724         return self._set(quantileProbabilities=value)
1725 
1726     @since("1.6.0")
1727     def setQuantilesCol(self, value):
1728         """
1729         Sets the value of :py:attr:`quantilesCol`.
1730         """
1731         return self._set(quantilesCol=value)
1732 
1733     @since("1.6.0")
1734     def setMaxIter(self, value):
1735         """
1736         Sets the value of :py:attr:`maxIter`.
1737         """
1738         return self._set(maxIter=value)
1739 
1740     @since("1.6.0")
1741     def setTol(self, value):
1742         """
1743         Sets the value of :py:attr:`tol`.
1744         """
1745         return self._set(tol=value)
1746 
1747     @since("1.6.0")
1748     def setFitIntercept(self, value):
1749         """
1750         Sets the value of :py:attr:`fitIntercept`.
1751         """
1752         return self._set(fitIntercept=value)
1753 
1754     @since("2.1.0")
1755     def setAggregationDepth(self, value):
1756         """
1757         Sets the value of :py:attr:`aggregationDepth`.
1758         """
1759         return self._set(aggregationDepth=value)
1760 
1761 
1762 class AFTSurvivalRegressionModel(JavaRegressionModel, _AFTSurvivalRegressionParams,
1763                                  JavaMLWritable, JavaMLReadable):
1764     """
1765     Model fitted by :class:`AFTSurvivalRegression`.
1766 
1767     .. versionadded:: 1.6.0
1768     """
1769 
1770     @since("3.0.0")
1771     def setQuantileProbabilities(self, value):
1772         """
1773         Sets the value of :py:attr:`quantileProbabilities`.
1774         """
1775         return self._set(quantileProbabilities=value)
1776 
1777     @since("3.0.0")
1778     def setQuantilesCol(self, value):
1779         """
1780         Sets the value of :py:attr:`quantilesCol`.
1781         """
1782         return self._set(quantilesCol=value)
1783 
1784     @property
1785     @since("2.0.0")
1786     def coefficients(self):
1787         """
1788         Model coefficients.
1789         """
1790         return self._call_java("coefficients")
1791 
1792     @property
1793     @since("1.6.0")
1794     def intercept(self):
1795         """
1796         Model intercept.
1797         """
1798         return self._call_java("intercept")
1799 
1800     @property
1801     @since("1.6.0")
1802     def scale(self):
1803         """
1804         Model scale parameter.
1805         """
1806         return self._call_java("scale")
1807 
1808     @since("2.0.0")
1809     def predictQuantiles(self, features):
1810         """
1811         Predicted quantiles at the probabilities set in :py:attr:`quantileProbabilities`.
1812         """
1813         return self._call_java("predictQuantiles", features)
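
         # Illustrative sketch (assumes `model` is a fitted
         # AFTSurvivalRegressionModel): the result holds one predicted quantile
         # per configured probability.
         #
         #   from pyspark.ml.linalg import Vectors
         #   model.setQuantileProbabilities([0.25, 0.5, 0.75])
         #   model.predictQuantiles(Vectors.dense(6.3))  # DenseVector, length 3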
1814 
1815 
1816 class _GeneralizedLinearRegressionParams(_JavaPredictorParams, HasFitIntercept, HasMaxIter,
1817                                          HasTol, HasRegParam, HasWeightCol, HasSolver,
1818                                          HasAggregationDepth):
1819     """
1820     Params for :py:class:`GeneralizedLinearRegression` and
1821     :py:class:`GeneralizedLinearRegressionModel`.
1822 
1823     .. versionadded:: 3.0.0
1824     """
1825 
1826     family = Param(Params._dummy(), "family", "The name of family which is a description of " +
1827                    "the error distribution to be used in the model. Supported options: " +
1828                    "gaussian (default), binomial, poisson, gamma and tweedie.",
1829                    typeConverter=TypeConverters.toString)
1830     link = Param(Params._dummy(), "link", "The name of link function which provides the " +
1831                  "relationship between the linear predictor and the mean of the distribution " +
1832                  "function. Supported options: identity, log, inverse, logit, probit, cloglog " +
1833                  "and sqrt.", typeConverter=TypeConverters.toString)
1834     linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " +
1835                               "predictor) column name", typeConverter=TypeConverters.toString)
1836     variancePower = Param(Params._dummy(), "variancePower", "The power in the variance function " +
1837                           "of the Tweedie distribution which characterizes the relationship " +
1838                           "between the variance and mean of the distribution. Only applicable " +
1839                           "for the Tweedie family. Supported values: 0 and [1, Inf).",
1840                           typeConverter=TypeConverters.toFloat)
1841     linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " +
1842                       "Only applicable to the Tweedie family.",
1843                       typeConverter=TypeConverters.toFloat)
1844     solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
1845                    "options: irls.", typeConverter=TypeConverters.toString)
1846     offsetCol = Param(Params._dummy(), "offsetCol", "The offset column name. If this is not set " +
1847                       "or empty, we treat all instance offsets as 0.0.",
1848                       typeConverter=TypeConverters.toString)
1849 
1850     @since("2.0.0")
1851     def getFamily(self):
1852         """
1853         Gets the value of family or its default value.
1854         """
1855         return self.getOrDefault(self.family)
1856 
1857     @since("2.0.0")
1858     def getLinkPredictionCol(self):
1859         """
1860         Gets the value of linkPredictionCol or its default value.
1861         """
1862         return self.getOrDefault(self.linkPredictionCol)
1863 
1864     @since("2.0.0")
1865     def getLink(self):
1866         """
1867         Gets the value of link or its default value.
1868         """
1869         return self.getOrDefault(self.link)
1870 
1871     @since("2.2.0")
1872     def getVariancePower(self):
1873         """
1874         Gets the value of variancePower or its default value.
1875         """
1876         return self.getOrDefault(self.variancePower)
1877 
1878     @since("2.2.0")
1879     def getLinkPower(self):
1880         """
1881         Gets the value of linkPower or its default value.
1882         """
1883         return self.getOrDefault(self.linkPower)
1884 
1885     @since("2.3.0")
1886     def getOffsetCol(self):
1887         """
1888         Gets the value of offsetCol or its default value.
1889         """
1890         return self.getOrDefault(self.offsetCol)
1891 
1892 
1893 @inherit_doc
1894 class GeneralizedLinearRegression(JavaRegressor, _GeneralizedLinearRegressionParams,
1895                                   JavaMLWritable, JavaMLReadable):
1896     """
1897     Generalized Linear Regression.
1898 
1899     Fit a Generalized Linear Model specified by giving a symbolic description of the linear
1900     predictor (link function) and a description of the error distribution (family). It supports
1901     "gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for
1902     each family are listed below. The first link function of each family is the default one.
1903 
1904     * "gaussian" -> "identity", "log", "inverse"
1905 
1906     * "binomial" -> "logit", "probit", "cloglog"
1907 
1908     * "poisson"  -> "log", "identity", "sqrt"
1909 
1910     * "gamma"    -> "inverse", "identity", "log"
1911 
1912     * "tweedie"  -> power link function specified through "linkPower". \
1913                     The default link power in the tweedie family is 1 - variancePower.
1914 
1915     .. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
1916 
1917     >>> from pyspark.ml.linalg import Vectors
1918     >>> df = spark.createDataFrame([
1919     ...     (1.0, Vectors.dense(0.0, 0.0)),
1920     ...     (1.0, Vectors.dense(1.0, 2.0)),
1921     ...     (2.0, Vectors.dense(0.0, 0.0)),
1922     ...     (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
1923     >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
1924     >>> glr.setRegParam(0.1)
1925     GeneralizedLinearRegression...
1926     >>> glr.getRegParam()
1927     0.1
1928     >>> glr.clear(glr.regParam)
1929     >>> glr.setMaxIter(10)
1930     GeneralizedLinearRegression...
1931     >>> glr.getMaxIter()
1932     10
1933     >>> glr.clear(glr.maxIter)
1934     >>> model = glr.fit(df)
1935     >>> model.setFeaturesCol("features")
1936     GeneralizedLinearRegressionModel...
1937     >>> model.getMaxIter()
1938     25
1939     >>> model.getAggregationDepth()
1940     2
1941     >>> transformed = model.transform(df)
1942     >>> abs(transformed.head().prediction - 1.5) < 0.001
1943     True
1944     >>> abs(transformed.head().p - 1.5) < 0.001
1945     True
1946     >>> model.coefficients
1947     DenseVector([1.5..., -1.0...])
1948     >>> model.numFeatures
1949     2
1950     >>> abs(model.intercept - 1.5) < 0.001
1951     True
1952     >>> glr_path = temp_path + "/glr"
1953     >>> glr.save(glr_path)
1954     >>> glr2 = GeneralizedLinearRegression.load(glr_path)
1955     >>> glr.getFamily() == glr2.getFamily()
1956     True
1957     >>> model_path = temp_path + "/glr_model"
1958     >>> model.save(model_path)
1959     >>> model2 = GeneralizedLinearRegressionModel.load(model_path)
1960     >>> model.intercept == model2.intercept
1961     True
1962     >>> model.coefficients[0] == model2.coefficients[0]
1963     True
1964 
1965     .. versionadded:: 2.0.0
1966     """
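
         # A minimal Tweedie sketch (illustrative; assumes an active
         # SparkSession): with family="tweedie" the link is chosen through
         # linkPower rather than link, and an unset linkPower defaults to
         # 1 - variancePower (here 1 - 1.5 = -0.5).
         #
         #   glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.5)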
1967 
1968     @keyword_only
1969     def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
1970                  family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
1971                  regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
1972                  variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2):
1973         """
1974         __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
1975                  family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
1976                  regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
1977                  variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2)
1978         """
1979         super(GeneralizedLinearRegression, self).__init__()
1980         self._java_obj = self._new_java_obj(
1981             "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
1982         self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
1983                          variancePower=0.0, aggregationDepth=2)
1984         kwargs = self._input_kwargs
1985 
1986         self.setParams(**kwargs)
1987 
1988     @keyword_only
1989     @since("2.0.0")
1990     def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
1991                   family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
1992                   regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
1993                   variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2):
1994         """
1995         setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
1996                   family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
1997                   regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
1998                   variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2)
1999         Sets params for generalized linear regression.
2000         """
2001         kwargs = self._input_kwargs
2002         return self._set(**kwargs)
2003 
2004     def _create_model(self, java_model):
2005         return GeneralizedLinearRegressionModel(java_model)
2006 
2007     @since("2.0.0")
2008     def setFamily(self, value):
2009         """
2010         Sets the value of :py:attr:`family`.
2011         """
2012         return self._set(family=value)
2013 
2014     @since("2.0.0")
2015     def setLinkPredictionCol(self, value):
2016         """
2017         Sets the value of :py:attr:`linkPredictionCol`.
2018         """
2019         return self._set(linkPredictionCol=value)
2020 
2021     @since("2.0.0")
2022     def setLink(self, value):
2023         """
2024         Sets the value of :py:attr:`link`.
2025         """
2026         return self._set(link=value)
2027 
2028     @since("2.2.0")
2029     def setVariancePower(self, value):
2030         """
2031         Sets the value of :py:attr:`variancePower`.
2032         """
2033         return self._set(variancePower=value)
2034 
2035     @since("2.2.0")
2036     def setLinkPower(self, value):
2037         """
2038         Sets the value of :py:attr:`linkPower`.
2039         """
2040         return self._set(linkPower=value)
2041 
2042     @since("2.3.0")
2043     def setOffsetCol(self, value):
2044         """
2045         Sets the value of :py:attr:`offsetCol`.
2046         """
2047         return self._set(offsetCol=value)
2048 
2049     @since("2.0.0")
2050     def setMaxIter(self, value):
2051         """
2052         Sets the value of :py:attr:`maxIter`.
2053         """
2054         return self._set(maxIter=value)
2055 
2056     @since("2.0.0")
2057     def setRegParam(self, value):
2058         """
2059         Sets the value of :py:attr:`regParam`.
2060         """
2061         return self._set(regParam=value)
2062 
2063     @since("2.0.0")
2064     def setTol(self, value):
2065         """
2066         Sets the value of :py:attr:`tol`.
2067         """
2068         return self._set(tol=value)
2069 
2070     @since("2.0.0")
2071     def setFitIntercept(self, value):
2072         """
2073         Sets the value of :py:attr:`fitIntercept`.
2074         """
2075         return self._set(fitIntercept=value)
2076 
2077     @since("2.0.0")
2078     def setWeightCol(self, value):
2079         """
2080         Sets the value of :py:attr:`weightCol`.
2081         """
2082         return self._set(weightCol=value)
2083 
2084     @since("2.0.0")
2085     def setSolver(self, value):
2086         """
2087         Sets the value of :py:attr:`solver`.
2088         """
2089         return self._set(solver=value)
2090 
2091     @since("3.0.0")
2092     def setAggregationDepth(self, value):
2093         """
2094         Sets the value of :py:attr:`aggregationDepth`.
2095         """
2096         return self._set(aggregationDepth=value)
2097 
2098 
2099 class GeneralizedLinearRegressionModel(JavaRegressionModel, _GeneralizedLinearRegressionParams,
2100                                        JavaMLWritable, JavaMLReadable, HasTrainingSummary):
2101     """
2102     Model fitted by :class:`GeneralizedLinearRegression`.
2103 
2104     .. versionadded:: 2.0.0
2105     """
2106 
2107     @since("3.0.0")
2108     def setLinkPredictionCol(self, value):
2109         """
2110         Sets the value of :py:attr:`linkPredictionCol`.
2111         """
2112         return self._set(linkPredictionCol=value)
2113 
2114     @property
2115     @since("2.0.0")
2116     def coefficients(self):
2117         """
2118         Model coefficients.
2119         """
2120         return self._call_java("coefficients")
2121 
2122     @property
2123     @since("2.0.0")
2124     def intercept(self):
2125         """
2126         Model intercept.
2127         """
2128         return self._call_java("intercept")
2129 
2130     @property
2131     @since("2.0.0")
2132     def summary(self):
2133         """
2134         Gets summary (e.g. residuals, deviance, pValues) of the model on
2135         the training set. An exception is thrown if
2136         `trainingSummary is None`.
2137         """
2138         if self.hasSummary:
2139             return GeneralizedLinearRegressionTrainingSummary(
2140                 super(GeneralizedLinearRegressionModel, self).summary)
2141         else:
2142             raise RuntimeError("No training summary available for this %s" %
2143                                self.__class__.__name__)
2144 
2145     @since("2.0.0")
2146     def evaluate(self, dataset):
2147         """
2148         Evaluates the model on a test dataset.
2149 
2150         :param dataset:
2151           Test dataset to evaluate model on, where dataset is an
2152           instance of :py:class:`pyspark.sql.DataFrame`
2153         """
2154         if not isinstance(dataset, DataFrame):
2155             raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
2156         java_glr_summary = self._call_java("evaluate", dataset)
2157         return GeneralizedLinearRegressionSummary(java_glr_summary)
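
         # Illustrative sketch (assumes `model` is a fitted
         # GeneralizedLinearRegressionModel and `test` is a DataFrame with the
         # training schema):
         #
         #   test_summary = model.evaluate(test)
         #   test_summary.deviance  # deviance of the model on held-out data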
2158 
2159 
2160 class GeneralizedLinearRegressionSummary(JavaWrapper):
2161     """
2162     Generalized linear regression results evaluated on a dataset.
2163 
2164     .. versionadded:: 2.0.0
2165     """
2166 
2167     @property
2168     @since("2.0.0")
2169     def predictions(self):
2170         """
2171         Predictions output by the model's `transform` method.
2172         """
2173         return self._call_java("predictions")
2174 
2175     @property
2176     @since("2.0.0")
2177     def predictionCol(self):
2178         """
2179         Field in :py:attr:`predictions` which gives the predicted value of each instance.
2180         This is set to a new column name if the original model's `predictionCol` is not set.
2181         """
2182         return self._call_java("predictionCol")
2183 
2184     @property
2185     @since("2.2.0")
2186     def numInstances(self):
2187         """
2188         Number of instances in the :py:attr:`predictions` DataFrame.
2189         """
2190         return self._call_java("numInstances")
2191 
2192     @property
2193     @since("2.0.0")
2194     def rank(self):
2195         """
2196         The numeric rank of the fitted linear model.
2197         """
2198         return self._call_java("rank")
2199 
2200     @property
2201     @since("2.0.0")
2202     def degreesOfFreedom(self):
2203         """
2204         Degrees of freedom.
2205         """
2206         return self._call_java("degreesOfFreedom")
2207 
2208     @property
2209     @since("2.0.0")
2210     def residualDegreeOfFreedom(self):
2211         """
2212         The residual degrees of freedom.
2213         """
2214         return self._call_java("residualDegreeOfFreedom")
2215 
2216     @property
2217     @since("2.0.0")
2218     def residualDegreeOfFreedomNull(self):
2219         """
2220         The residual degrees of freedom for the null model.
2221         """
2222         return self._call_java("residualDegreeOfFreedomNull")
2223 
2224     @since("2.0.0")
2225     def residuals(self, residualsType="deviance"):
2226         """
2227         Get the residuals of the fitted model by type.
2228 
2229         :param residualsType: The type of residuals which should be returned.
2230                               Supported options: deviance (default), pearson, working, and response.
2231         """
2232         return self._call_java("residuals", residualsType)
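
         # Illustrative sketch (assumes `summary` is a
         # GeneralizedLinearRegressionSummary): each call returns a DataFrame
         # with the residuals of the requested type.
         #
         #   dev = summary.residuals()               # deviance (default)
         #   pearson = summary.residuals("pearson")  # Pearson residuals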
2233 
2234     @property
2235     @since("2.0.0")
2236     def nullDeviance(self):
2237         """
2238         The deviance for the null model.
2239         """
2240         return self._call_java("nullDeviance")
2241 
2242     @property
2243     @since("2.0.0")
2244     def deviance(self):
2245         """
2246         The deviance for the fitted model.
2247         """
2248         return self._call_java("deviance")
2249 
2250     @property
2251     @since("2.0.0")
2252     def dispersion(self):
2253         """
2254         The dispersion of the fitted model.
2255         It is taken as 1.0 for the "binomial" and "poisson" families, and otherwise
2256         estimated by the residual Pearson's Chi-Squared statistic (which is defined as
2257         sum of the squares of the Pearson residuals) divided by the residual degrees of freedom.
2258         """
2259         return self._call_java("dispersion")
2260 
2261     @property
2262     @since("2.0.0")
2263     def aic(self):
2264         """
2265         Akaike's "An Information Criterion" (AIC) for the fitted model.
2266         """
2267         return self._call_java("aic")
2268 
2269 
2270 @inherit_doc
2271 class GeneralizedLinearRegressionTrainingSummary(GeneralizedLinearRegressionSummary):
2272     """
2273     Generalized linear regression training results.
2274 
2275     .. versionadded:: 2.0.0
2276     """
2277 
2278     @property
2279     @since("2.0.0")
2280     def numIterations(self):
2281         """
2282         Number of training iterations.
2283         """
2284         return self._call_java("numIterations")
2285 
2286     @property
2287     @since("2.0.0")
2288     def solver(self):
2289         """
2290         The numeric solver used for training.
2291         """
2292         return self._call_java("solver")
2293 
2294     @property
2295     @since("2.0.0")
2296     def coefficientStandardErrors(self):
2297         """
2298         Standard error of estimated coefficients and intercept.
2299 
2300         If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True,
2301         then the last element returned corresponds to the intercept.
2302         """
2303         return self._call_java("coefficientStandardErrors")
2304 
2305     @property
2306     @since("2.0.0")
2307     def tValues(self):
2308         """
2309         T-statistic of estimated coefficients and intercept.
2310 
2311         If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True,
2312         then the last element returned corresponds to the intercept.
2313         """
2314         return self._call_java("tValues")
2315 
2316     @property
2317     @since("2.0.0")
2318     def pValues(self):
2319         """
2320         Two-sided p-value of estimated coefficients and intercept.
2321 
2322         If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True,
2323         then the last element returned corresponds to the intercept.
2324         """
2325         return self._call_java("pValues")
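
         # Illustrative sketch (assumes `model` has a training summary): with
         # fitIntercept=True the last entry of each statistic belongs to the
         # intercept, so the leading entries line up with the coefficients.
         #
         #   s = model.summary
         #   for coef, p in zip(model.coefficients, s.pValues):
         #       print(coef, p)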
2326 
2327     def __repr__(self):
2328         return self._call_java("toString")
2329 
2330 
2331 class _FactorizationMachinesParams(_JavaPredictorParams, HasMaxIter, HasStepSize, HasTol,
2332                                    HasSolver, HasSeed, HasFitIntercept, HasRegParam):
2333     """
2334     Params for :py:class:`FMRegressor`, :py:class:`FMRegressionModel`, :py:class:`FMClassifier`
2335     and :py:class:`FMClassifierModel`.
2336 
2337     .. versionadded:: 3.0.0
2338     """
2339 
2340     factorSize = Param(Params._dummy(), "factorSize", "Dimensionality of the factor vectors, " +
2341                        "which are used to get pairwise interactions between variables",
2342                        typeConverter=TypeConverters.toInt)
2343 
2344     fitLinear = Param(Params._dummy(), "fitLinear", "whether to fit linear term (aka 1-way term)",
2345                       typeConverter=TypeConverters.toBoolean)
2346 
2347     miniBatchFraction = Param(Params._dummy(), "miniBatchFraction", "fraction of the input data " +
2348                               "set that should be used for one iteration of gradient descent",
2349                               typeConverter=TypeConverters.toFloat)
2350 
2351     initStd = Param(Params._dummy(), "initStd", "standard deviation of initial coefficients",
2352                     typeConverter=TypeConverters.toFloat)
2353 
2354     solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
2355                    "options: gd, adamW. (Default: adamW)", typeConverter=TypeConverters.toString)
2356 
2357     @since("3.0.0")
2358     def getFactorSize(self):
2359         """
2360         Gets the value of factorSize or its default value.
2361         """
2362         return self.getOrDefault(self.factorSize)
2363 
2364     @since("3.0.0")
2365     def getFitLinear(self):
2366         """
2367         Gets the value of fitLinear or its default value.
2368         """
2369         return self.getOrDefault(self.fitLinear)
2370 
2371     @since("3.0.0")
2372     def getMiniBatchFraction(self):
2373         """
2374         Gets the value of miniBatchFraction or its default value.
2375         """
2376         return self.getOrDefault(self.miniBatchFraction)
2377 
2378     @since("3.0.0")
2379     def getInitStd(self):
2380         """
2381         Gets the value of initStd or its default value.
2382         """
2383         return self.getOrDefault(self.initStd)
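
         # Factorization machines model the prediction as (illustrative
         # notation only):
         #
         #   y(x) = w0 + sum_i w_i * x_i + sum_{i<j} <v_i, v_j> * x_i * x_j
         #
         # where w0 is the intercept (fitIntercept), w_i the 1-way linear term
         # (fitLinear), and v_i the factor vectors whose length is factorSize.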
2384 
2385 
2386 @inherit_doc
2387 class FMRegressor(JavaRegressor, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable):
2388     """
2389     Factorization Machines learning algorithm for regression.
2390 
2391     Supported options for solver:
2392 
2393     * gd (normal mini-batch gradient descent)
2394     * adamW (default)
2395 
2396     >>> from pyspark.ml.linalg import Vectors
2397     >>> from pyspark.ml.regression import FMRegressor
2398     >>> df = spark.createDataFrame([
2399     ...     (2.0, Vectors.dense(2.0)),
2400     ...     (1.0, Vectors.dense(1.0)),
2401     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
2402     >>>
2403     >>> fm = FMRegressor(factorSize=2)
2404     >>> fm.setSeed(16)
2405     FMRegressor...
2406     >>> model = fm.fit(df)
2407     >>> model.getMaxIter()
2408     100
2409     >>> test0 = spark.createDataFrame([
2410     ...     (Vectors.dense(-2.0),),
2411     ...     (Vectors.dense(0.5),),
2412     ...     (Vectors.dense(1.0),),
2413     ...     (Vectors.dense(4.0),)], ["features"])
2414     >>> model.transform(test0).show(10, False)
2415     +--------+-------------------+
2416     |features|prediction         |
2417     +--------+-------------------+
2418     |[-2.0]  |-1.9989237712341565|
2419     |[0.5]   |0.4956682219523814 |
2420     |[1.0]   |0.994586620589689  |
2421     |[4.0]   |3.9880970124135344 |
2422     +--------+-------------------+
2423     ...
2424     >>> model.intercept
2425     -0.0032501766849261557
2426     >>> model.linear
2427     DenseVector([0.9978])
2428     >>> model.factors
2429     DenseMatrix(1, 2, [0.0173, 0.0021], 1)
2430 
2431     .. versionadded:: 3.0.0
2432     """
2433 
2434     @keyword_only
2435     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2436                  factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
2437                  miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
2438                  tol=1e-6, solver="adamW", seed=None):
2439         """
2440         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2441                  factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \
2442                  miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \
2443                  tol=1e-6, solver="adamW", seed=None)
2444         """
2445         super(FMRegressor, self).__init__()
2446         self._java_obj = self._new_java_obj(
2447             "org.apache.spark.ml.regression.FMRegressor", self.uid)
2448         self._setDefault(factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
2449                          miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
2450                          tol=1e-6, solver="adamW")
2451         kwargs = self._input_kwargs
2452         self.setParams(**kwargs)
2453 
2454     @keyword_only
2455     @since("3.0.0")
2456     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2457                   factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
2458                   miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
2459                   tol=1e-6, solver="adamW", seed=None):
2460         """
2461         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2462                   factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \
2463                   miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \
2464                   tol=1e-6, solver="adamW", seed=None)
2465         Sets Params for FMRegressor.
2466         """
2467         kwargs = self._input_kwargs
2468         return self._set(**kwargs)
2469 
2470     def _create_model(self, java_model):
2471         return FMRegressionModel(java_model)
2472 
2473     @since("3.0.0")
2474     def setFactorSize(self, value):
2475         """
2476         Sets the value of :py:attr:`factorSize`.
2477         """
2478         return self._set(factorSize=value)
2479 
2480     @since("3.0.0")
2481     def setFitLinear(self, value):
2482         """
2483         Sets the value of :py:attr:`fitLinear`.
2484         """
2485         return self._set(fitLinear=value)
2486 
2487     @since("3.0.0")
2488     def setMiniBatchFraction(self, value):
2489         """
2490         Sets the value of :py:attr:`miniBatchFraction`.
2491         """
2492         return self._set(miniBatchFraction=value)
2493 
2494     @since("3.0.0")
2495     def setInitStd(self, value):
2496         """
2497         Sets the value of :py:attr:`initStd`.
2498         """
2499         return self._set(initStd=value)
2500 
2501     @since("3.0.0")
2502     def setMaxIter(self, value):
2503         """
2504         Sets the value of :py:attr:`maxIter`.
2505         """
2506         return self._set(maxIter=value)
2507 
2508     @since("3.0.0")
2509     def setStepSize(self, value):
2510         """
2511         Sets the value of :py:attr:`stepSize`.
2512         """
2513         return self._set(stepSize=value)
2514 
2515     @since("3.0.0")
2516     def setTol(self, value):
2517         """
2518         Sets the value of :py:attr:`tol`.
2519         """
2520         return self._set(tol=value)
2521 
2522     @since("3.0.0")
2523     def setSolver(self, value):
2524         """
2525         Sets the value of :py:attr:`solver`.
2526         """
2527         return self._set(solver=value)
2528 
2529     @since("3.0.0")
2530     def setSeed(self, value):
2531         """
2532         Sets the value of :py:attr:`seed`.
2533         """
2534         return self._set(seed=value)
2535 
2536     @since("3.0.0")
2537     def setFitIntercept(self, value):
2538         """
2539         Sets the value of :py:attr:`fitIntercept`.
2540         """
2541         return self._set(fitIntercept=value)
2542 
2543     @since("3.0.0")
2544     def setRegParam(self, value):
2545         """
2546         Sets the value of :py:attr:`regParam`.
2547         """
2548         return self._set(regParam=value)
2549 
2550 
2551 class FMRegressionModel(JavaRegressionModel, _FactorizationMachinesParams, JavaMLWritable,
2552                         JavaMLReadable):
2553     """
2554     Model fitted by :class:`FMRegressor`.
2555 
2556     .. versionadded:: 3.0.0
2557     """
2558 
2559     @property
2560     @since("3.0.0")
2561     def intercept(self):
2562         """
2563         Model intercept.
2564         """
2565         return self._call_java("intercept")
2566 
2567     @property
2568     @since("3.0.0")
2569     def linear(self):
2570         """
2571         Model linear term.
2572         """
2573         return self._call_java("linear")
2574 
2575     @property
2576     @since("3.0.0")
2577     def factors(self):
2578         """
2579         Model factor term.
2580         """
2581         return self._call_java("factors")
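
         # Illustrative sketch (assumes `model` is a fitted FMRegressionModel),
         # tying these attributes to the factorization machines formula:
         #
         #   w0 = model.intercept  # scalar bias term
         #   w = model.linear      # DenseVector: one weight per feature
         #   V = model.factors     # numFeatures x factorSize factor matrix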
2582 
2583 
2584 if __name__ == "__main__":
2585     import doctest
2586     import pyspark.ml.regression
2587     from pyspark.sql import SparkSession
2588     globs = pyspark.ml.regression.__dict__.copy()
2591     spark = SparkSession.builder\
2592         .master("local[2]")\
2593         .appName("ml.regression tests")\
2594         .getOrCreate()
2595     sc = spark.sparkContext
2596     globs['sc'] = sc
2597     globs['spark'] = spark
2598     import tempfile
2599     temp_path = tempfile.mkdtemp()
2600     globs['temp_path'] = temp_path
2601     try:
2602         (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
2603         spark.stop()
2604     finally:
2605         from shutil import rmtree
2606         try:
2607             rmtree(temp_path)
2608         except OSError:
2609             pass
2610     if failure_count:
2611         sys.exit(-1)