#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import operator
import sys
from multiprocessing.pool import ThreadPool

from pyspark import since, keyword_only
from pyspark.ml import Estimator, Model
from pyspark.ml.param.shared import *
from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
    _TreeEnsembleModel, _RandomForestParams, _GBTParams, \
    _HasVarianceImpurity, _TreeClassifierParams, _TreeEnsembleParams
from pyspark.ml.regression import _FactorizationMachinesParams, DecisionTreeRegressionModel
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
    JavaPredictor, _JavaPredictorParams, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc, _java2py, _py2java
from pyspark.ml.linalg import Vectors
from pyspark.sql import DataFrame
from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
from pyspark.storagelevel import StorageLevel

__all__ = ['LinearSVC', 'LinearSVCModel',
           'LogisticRegression', 'LogisticRegressionModel',
           'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary',
           'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary',
           'DecisionTreeClassifier', 'DecisionTreeClassificationModel',
           'GBTClassifier', 'GBTClassificationModel',
           'RandomForestClassifier', 'RandomForestClassificationModel',
           'NaiveBayes', 'NaiveBayesModel',
           'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel',
           'OneVsRest', 'OneVsRestModel',
           'FMClassifier', 'FMClassificationModel']


class _JavaClassifierParams(HasRawPredictionCol, _JavaPredictorParams):
    """
    Java Classifier Params for classification tasks.

    .. versionadded:: 3.0.0
    """
    pass


@inherit_doc
class JavaClassifier(JavaPredictor, _JavaClassifierParams):
    """
    Java Classifier for classification tasks.
    Classes are indexed {0, 1, ..., numClasses - 1}.
    """

    @since("3.0.0")
    def setRawPredictionCol(self, value):
        """
        Sets the value of :py:attr:`rawPredictionCol`.
        """
        return self._set(rawPredictionCol=value)


@inherit_doc
class JavaClassificationModel(JavaPredictionModel, _JavaClassifierParams):
    """
    Java Model produced by a ``Classifier``.
    Classes are indexed {0, 1, ..., numClasses - 1}.
    To be mixed in with :class:`pyspark.ml.JavaModel`.
0082     """
0083 
0084     @since("3.0.0")
0085     def setRawPredictionCol(self, value):
0086         """
0087         Sets the value of :py:attr:`rawPredictionCol`.
0088         """
0089         return self._set(rawPredictionCol=value)
0090 
0091     @property
0092     @since("2.1.0")
0093     def numClasses(self):
0094         """
0095         Number of classes (values which the label can take).
0096         """
0097         return self._call_java("numClasses")
0098 
0099     @since("3.0.0")
0100     def predictRaw(self, value):
0101         """
0102         Raw prediction for each possible label.
0103         """
0104         return self._call_java("predictRaw", value)
0105 
0106 
0107 class _JavaProbabilisticClassifierParams(HasProbabilityCol, HasThresholds, _JavaClassifierParams):
0108     """
0109     Params for :py:class:`JavaProbabilisticClassifier` and
0110     :py:class:`JavaProbabilisticClassificationModel`.
0111 
0112     .. versionadded:: 3.0.0
0113     """
0114     pass
0115 
0116 
0117 @inherit_doc
0118 class JavaProbabilisticClassifier(JavaClassifier, _JavaProbabilisticClassifierParams):
0119     """
0120     Java Probabilistic Classifier for classification tasks.
0121     """
0122 
0123     @since("3.0.0")
0124     def setProbabilityCol(self, value):
0125         """
0126         Sets the value of :py:attr:`probabilityCol`.
0127         """
0128         return self._set(probabilityCol=value)
0129 
0130     @since("3.0.0")
0131     def setThresholds(self, value):
0132         """
0133         Sets the value of :py:attr:`thresholds`.
0134         """
0135         return self._set(thresholds=value)
0136 
0137 
0138 @inherit_doc
0139 class JavaProbabilisticClassificationModel(JavaClassificationModel,
0140                                            _JavaProbabilisticClassifierParams):
0141     """
0142     Java Model produced by a ``ProbabilisticClassifier``.
0143     """
0144 
0145     @since("3.0.0")
0146     def setProbabilityCol(self, value):
0147         """
0148         Sets the value of :py:attr:`probabilityCol`.
0149         """
0150         return self._set(probabilityCol=value)
0151 
0152     @since("3.0.0")
0153     def setThresholds(self, value):
0154         """
0155         Sets the value of :py:attr:`thresholds`.
0156         """
0157         return self._set(thresholds=value)
0158 
0159     @since("3.0.0")
0160     def predictProbability(self, value):
0161         """
0162         Predict the probability of each class given the features.
0163         """
0164         return self._call_java("predictProbability", value)
0165 
0166 
0167 class _LinearSVCParams(_JavaClassifierParams, HasRegParam, HasMaxIter, HasFitIntercept, HasTol,
0168                        HasStandardization, HasWeightCol, HasAggregationDepth, HasThreshold):
0169     """
0170     Params for :py:class:`LinearSVC` and :py:class:`LinearSVCModel`.
0171 
0172     .. versionadded:: 3.0.0
0173     """
0174 
0175     threshold = Param(Params._dummy(), "threshold",
0176                       "The threshold in binary classification applied to the linear model"
0177                       " prediction.  This threshold can be any real number, where Inf will make"
0178                       " all predictions 0.0 and -Inf will make all predictions 1.0.",
0179                       typeConverter=TypeConverters.toFloat)
0180 
0181 
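# A minimal sketch (not part of the Spark API; the helper name is
# hypothetical) of how the LinearSVC threshold described above acts on the
# raw margin: +Inf forces every prediction to 0.0 and -Inf forces every
# prediction to 1.0.
def _example_apply_svc_threshold(margin, threshold):
    # margin is the linear model's prediction w.x + b for the positive class
    return 1.0 if margin > threshold else 0.0

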
@inherit_doc
class LinearSVC(JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadable):
    """
    `Linear SVM Classifier <https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM>`_

    This binary classifier optimizes the hinge loss using the OWLQN optimizer.
    Currently, it supports only L2 regularization.

    >>> from pyspark.sql import Row
    >>> from pyspark.ml.linalg import Vectors
    >>> df = sc.parallelize([
    ...     Row(label=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
    ...     Row(label=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
    >>> svm = LinearSVC()
    >>> svm.getMaxIter()
    100
    >>> svm.setMaxIter(5)
    LinearSVC...
    >>> svm.getMaxIter()
    5
    >>> svm.getRegParam()
    0.0
    >>> svm.setRegParam(0.01)
    LinearSVC...
    >>> svm.getRegParam()
    0.01
    >>> model = svm.fit(df)
    >>> model.setPredictionCol("newPrediction")
    LinearSVCModel...
    >>> model.getPredictionCol()
    'newPrediction'
    >>> model.setThreshold(0.5)
    LinearSVCModel...
    >>> model.getThreshold()
    0.5
    >>> model.coefficients
    DenseVector([0.0, -0.2792, -0.1833])
    >>> model.intercept
    1.0206118982229047
    >>> model.numClasses
    2
    >>> model.numFeatures
    3
    >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, -1.0, -1.0))]).toDF()
    >>> model.predict(test0.head().features)
    1.0
    >>> model.predictRaw(test0.head().features)
    DenseVector([-1.4831, 1.4831])
    >>> result = model.transform(test0).head()
    >>> result.newPrediction
    1.0
    >>> result.rawPrediction
    DenseVector([-1.4831, 1.4831])
    >>> svm_path = temp_path + "/svm"
    >>> svm.save(svm_path)
    >>> svm2 = LinearSVC.load(svm_path)
    >>> svm2.getMaxIter()
    5
    >>> model_path = temp_path + "/svm_model"
    >>> model.save(model_path)
    >>> model2 = LinearSVCModel.load(model_path)
    >>> model.coefficients[0] == model2.coefficients[0]
    True
    >>> model.intercept == model2.intercept
    True

    .. versionadded:: 2.2.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction",
                 fitIntercept=True, standardization=True, threshold=0.0, weightCol=None,
                 aggregationDepth=2):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \
                 fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \
                 aggregationDepth=2)
0261         """
0262         super(LinearSVC, self).__init__()
0263         self._java_obj = self._new_java_obj(
0264             "org.apache.spark.ml.classification.LinearSVC", self.uid)
0265         self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, fitIntercept=True,
0266                          standardization=True, threshold=0.0, aggregationDepth=2)
0267         kwargs = self._input_kwargs
0268         self.setParams(**kwargs)
0269 
0270     @keyword_only
0271     @since("2.2.0")
0272     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
0273                   maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction",
0274                   fitIntercept=True, standardization=True, threshold=0.0, weightCol=None,
0275                   aggregationDepth=2):
0276         """
0277         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
0278                   maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \
0279                   fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \
                  aggregationDepth=2)
        Sets params for Linear SVM Classifier.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return LinearSVCModel(java_model)

    @since("2.2.0")
    def setMaxIter(self, value):
        """
        Sets the value of :py:attr:`maxIter`.
        """
        return self._set(maxIter=value)

    @since("2.2.0")
    def setRegParam(self, value):
        """
        Sets the value of :py:attr:`regParam`.
        """
        return self._set(regParam=value)

    @since("2.2.0")
    def setTol(self, value):
        """
        Sets the value of :py:attr:`tol`.
        """
        return self._set(tol=value)

    @since("2.2.0")
    def setFitIntercept(self, value):
        """
        Sets the value of :py:attr:`fitIntercept`.
        """
        return self._set(fitIntercept=value)

    @since("2.2.0")
    def setStandardization(self, value):
        """
        Sets the value of :py:attr:`standardization`.
        """
        return self._set(standardization=value)

    @since("2.2.0")
    def setThreshold(self, value):
        """
        Sets the value of :py:attr:`threshold`.
        """
        return self._set(threshold=value)

    @since("2.2.0")
    def setWeightCol(self, value):
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)

    @since("2.2.0")
    def setAggregationDepth(self, value):
        """
        Sets the value of :py:attr:`aggregationDepth`.
        """
        return self._set(aggregationDepth=value)


class LinearSVCModel(JavaClassificationModel, _LinearSVCParams, JavaMLWritable, JavaMLReadable):
    """
    Model fitted by LinearSVC.

    .. versionadded:: 2.2.0
    """

    @since("3.0.0")
    def setThreshold(self, value):
        """
        Sets the value of :py:attr:`threshold`.
        """
        return self._set(threshold=value)

    @property
    @since("2.2.0")
    def coefficients(self):
        """
        Model coefficients of Linear SVM Classifier.
        """
        return self._call_java("coefficients")

    @property
    @since("2.2.0")
    def intercept(self):
        """
        Model intercept of Linear SVM Classifier.
        """
        return self._call_java("intercept")


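# Hedged illustration (helper name hypothetical): the raw prediction of a
# fitted LinearSVCModel is the margin w.x + b, so it can be reproduced from
# the `coefficients` and `intercept` properties above; predictRaw returns
# the vector [-margin, margin].
def _example_svc_margin(model, features):
    # features is a pyspark.ml.linalg Vector with the same length as
    # model.coefficients
    return float(model.coefficients.dot(features)) + model.intercept

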
class _LogisticRegressionParams(_JavaProbabilisticClassifierParams, HasRegParam,
                                HasElasticNetParam, HasMaxIter, HasFitIntercept, HasTol,
                                HasStandardization, HasWeightCol, HasAggregationDepth,
                                HasThreshold):
    """
    Params for :py:class:`LogisticRegression` and :py:class:`LogisticRegressionModel`.

    .. versionadded:: 3.0.0
    """

    threshold = Param(Params._dummy(), "threshold",
                      "Threshold in binary classification prediction, in range [0, 1]." +
                      " If threshold and thresholds are both set, they must match." +
                      " E.g., if threshold is p, then thresholds must be equal to [1-p, p].",
                      typeConverter=TypeConverters.toFloat)

    family = Param(Params._dummy(), "family",
                   "The name of family which is a description of the label distribution to " +
                   "be used in the model. Supported options: auto, binomial, multinomial",
                   typeConverter=TypeConverters.toString)

    lowerBoundsOnCoefficients = Param(Params._dummy(), "lowerBoundsOnCoefficients",
                                      "The lower bounds on coefficients if fitting under bound "
                                      "constrained optimization. The bound matrix must be "
                                      "compatible with the shape "
                                      "(1, number of features) for binomial regression, or "
                                      "(number of classes, number of features) "
                                      "for multinomial regression.",
                                      typeConverter=TypeConverters.toMatrix)

    upperBoundsOnCoefficients = Param(Params._dummy(), "upperBoundsOnCoefficients",
                                      "The upper bounds on coefficients if fitting under bound "
                                      "constrained optimization. The bound matrix must be "
                                      "compatible with the shape "
                                      "(1, number of features) for binomial regression, or "
                                      "(number of classes, number of features) "
                                      "for multinomial regression.",
                                      typeConverter=TypeConverters.toMatrix)

    lowerBoundsOnIntercepts = Param(Params._dummy(), "lowerBoundsOnIntercepts",
                                    "The lower bounds on intercepts if fitting under bound "
                                    "constrained optimization. The bounds vector size must be "
                                    "equal with 1 for binomial regression, or the number of "
                                    "classes for multinomial regression.",
                                    typeConverter=TypeConverters.toVector)

    upperBoundsOnIntercepts = Param(Params._dummy(), "upperBoundsOnIntercepts",
                                    "The upper bounds on intercepts if fitting under bound "
                                    "constrained optimization. The bound vector size must be "
                                    "equal with 1 for binomial regression, or the number of "
                                    "classes for multinomial regression.",
                                    typeConverter=TypeConverters.toVector)

    @since("1.4.0")
    def setThreshold(self, value):
        """
        Sets the value of :py:attr:`threshold`.
        Clears value of :py:attr:`thresholds` if it has been set.
        """
        self._set(threshold=value)
        self.clear(self.thresholds)
        return self

    @since("1.4.0")
    def getThreshold(self):
        """
        Get threshold for binary classification.

        If :py:attr:`thresholds` is set with length 2 (i.e., binary classification),
        this returns the equivalent threshold:
        :math:`\\frac{1}{1 + \\frac{thresholds(0)}{thresholds(1)}}`.
        Otherwise, returns :py:attr:`threshold` if set or its default value if unset.
        """
        self._checkThresholdConsistency()
        if self.isSet(self.thresholds):
            ts = self.getOrDefault(self.thresholds)
            if len(ts) != 2:
                raise ValueError("Logistic Regression getThreshold only applies to" +
                                 " binary classification, but thresholds has length != 2." +
0456                                  "  thresholds: " + ",".join(ts))
            return 1.0/(1.0 + ts[0]/ts[1])
        else:
            return self.getOrDefault(self.threshold)

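    # Worked example of the equivalence above (illustration only): with
    # thresholds = [0.6, 0.4], getThreshold() returns
    # 1.0 / (1.0 + 0.6 / 0.4) = 0.4, and a binary threshold of 0.4 maps back
    # to thresholds = [0.6, 0.4] via getThresholds() below.
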
0461     @since("1.5.0")
0462     def setThresholds(self, value):
0463         """
0464         Sets the value of :py:attr:`thresholds`.
0465         Clears value of :py:attr:`threshold` if it has been set.
0466         """
0467         self._set(thresholds=value)
0468         self.clear(self.threshold)
0469         return self
0470 
0471     @since("1.5.0")
0472     def getThresholds(self):
0473         """
0474         If :py:attr:`thresholds` is set, return its value.
0475         Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary
0476         classification: (1-threshold, threshold).
0477         If neither are set, throw an error.
0478         """
0479         self._checkThresholdConsistency()
0480         if not self.isSet(self.thresholds) and self.isSet(self.threshold):
0481             t = self.getOrDefault(self.threshold)
0482             return [1.0-t, t]
0483         else:
0484             return self.getOrDefault(self.thresholds)
0485 
0486     def _checkThresholdConsistency(self):
0487         if self.isSet(self.threshold) and self.isSet(self.thresholds):
0488             ts = self.getOrDefault(self.thresholds)
0489             if len(ts) != 2:
0490                 raise ValueError("Logistic Regression getThreshold only applies to" +
0491                                  " binary classification, but thresholds has length != 2." +
0492                                  " thresholds: {0}".format(str(ts)))
0493             t = 1.0/(1.0 + ts[0]/ts[1])
0494             t2 = self.getOrDefault(self.threshold)
0495             if abs(t2 - t) >= 1E-5:
0496                 raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
0497                                  " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
0498 
0499     @since("2.1.0")
0500     def getFamily(self):
0501         """
0502         Gets the value of :py:attr:`family` or its default value.
0503         """
0504         return self.getOrDefault(self.family)
0505 
0506     @since("2.3.0")
0507     def getLowerBoundsOnCoefficients(self):
0508         """
0509         Gets the value of :py:attr:`lowerBoundsOnCoefficients`
0510         """
0511         return self.getOrDefault(self.lowerBoundsOnCoefficients)
0512 
0513     @since("2.3.0")
0514     def getUpperBoundsOnCoefficients(self):
0515         """
0516         Gets the value of :py:attr:`upperBoundsOnCoefficients`
0517         """
0518         return self.getOrDefault(self.upperBoundsOnCoefficients)
0519 
0520     @since("2.3.0")
0521     def getLowerBoundsOnIntercepts(self):
0522         """
0523         Gets the value of :py:attr:`lowerBoundsOnIntercepts`
0524         """
0525         return self.getOrDefault(self.lowerBoundsOnIntercepts)
0526 
0527     @since("2.3.0")
0528     def getUpperBoundsOnIntercepts(self):
0529         """
0530         Gets the value of :py:attr:`upperBoundsOnIntercepts`
0531         """
0532         return self.getOrDefault(self.upperBoundsOnIntercepts)
0533 
0534 
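# Hedged usage sketch of the bound-constrained optimization params above:
# constraining a binomial LogisticRegression to non-negative coefficients.
# The matrix shape must be (1, number of features) for binomial models; the
# helper name and feature count are illustrative only.
def _example_bounded_logistic_regression(num_features):
    from pyspark.ml.linalg import Matrices
    lower = Matrices.dense(1, num_features, [0.0] * num_features)
    return LogisticRegression(family="binomial",
                              lowerBoundsOnCoefficients=lower,
                              lowerBoundsOnIntercepts=Vectors.dense(0.0))

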
@inherit_doc
class LogisticRegression(JavaProbabilisticClassifier, _LogisticRegressionParams, JavaMLWritable,
                         JavaMLReadable):
    """
    Logistic regression.
    This class supports multinomial logistic (softmax) and binomial logistic regression.

    >>> from pyspark.sql import Row
    >>> from pyspark.ml.linalg import Vectors
    >>> bdf = sc.parallelize([
    ...     Row(label=1.0, weight=1.0, features=Vectors.dense(0.0, 5.0)),
    ...     Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)),
    ...     Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)),
    ...     Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF()
    >>> blor = LogisticRegression(weightCol="weight")
    >>> blor.getRegParam()
    0.0
    >>> blor.setRegParam(0.01)
    LogisticRegression...
    >>> blor.getRegParam()
    0.01
    >>> blor.setMaxIter(10)
    LogisticRegression...
    >>> blor.getMaxIter()
    10
    >>> blor.clear(blor.maxIter)
    >>> blorModel = blor.fit(bdf)
    >>> blorModel.setFeaturesCol("features")
    LogisticRegressionModel...
    >>> blorModel.setProbabilityCol("newProbability")
    LogisticRegressionModel...
    >>> blorModel.getProbabilityCol()
    'newProbability'
    >>> blorModel.setThreshold(0.1)
    LogisticRegressionModel...
    >>> blorModel.getThreshold()
    0.1
    >>> blorModel.coefficients
    DenseVector([-1.080..., -0.646...])
    >>> blorModel.intercept
    3.112...
    >>> data_path = "data/mllib/sample_multiclass_classification_data.txt"
    >>> mdf = spark.read.format("libsvm").load(data_path)
    >>> mlor = LogisticRegression(regParam=0.1, elasticNetParam=1.0, family="multinomial")
    >>> mlorModel = mlor.fit(mdf)
    >>> mlorModel.coefficientMatrix
    SparseMatrix(3, 4, [0, 1, 2, 3], [3, 2, 1], [1.87..., -2.75..., -0.50...], 1)
    >>> mlorModel.interceptVector
    DenseVector([0.04..., -0.42..., 0.37...])
    >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF()
    >>> blorModel.predict(test0.head().features)
    1.0
    >>> blorModel.predictRaw(test0.head().features)
    DenseVector([-3.54..., 3.54...])
    >>> blorModel.predictProbability(test0.head().features)
    DenseVector([0.028, 0.972])
    >>> result = blorModel.transform(test0).head()
    >>> result.prediction
    1.0
    >>> result.newProbability
    DenseVector([0.02..., 0.97...])
    >>> result.rawPrediction
    DenseVector([-3.54..., 3.54...])
    >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
    >>> blorModel.transform(test1).head().prediction
    1.0
    >>> blor.setParams("vector")
    Traceback (most recent call last):
        ...
    TypeError: Method setParams forces keyword arguments.
    >>> lr_path = temp_path + "/lr"
    >>> blor.save(lr_path)
    >>> lr2 = LogisticRegression.load(lr_path)
    >>> lr2.getRegParam()
    0.01
    >>> model_path = temp_path + "/lr_model"
    >>> blorModel.save(model_path)
    >>> model2 = LogisticRegressionModel.load(model_path)
    >>> blorModel.coefficients[0] == model2.coefficients[0]
    True
    >>> blorModel.intercept == model2.intercept
    True
    >>> model2
    LogisticRegressionModel: uid=..., numClasses=2, numFeatures=2

    .. versionadded:: 1.3.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                 threshold=0.5, thresholds=None, probabilityCol="probability",
                 rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
                 aggregationDepth=2, family="auto",
                 lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None,
                 lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None):
0632         """
0633         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
0634                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
0635                  threshold=0.5, thresholds=None, probabilityCol="probability", \
0636                  rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
0637                  aggregationDepth=2, family="auto", \
0638                  lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \
                 lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None)
        If the threshold and thresholds Params are both set, they must be equivalent.
        """
        super(LogisticRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.LogisticRegression", self.uid)
        self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5, family="auto")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)
        self._checkThresholdConsistency()

    @keyword_only
    @since("1.3.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                  threshold=0.5, thresholds=None, probabilityCol="probability",
                  rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
                  aggregationDepth=2, family="auto",
                  lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None,
                  lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                  threshold=0.5, thresholds=None, probabilityCol="probability", \
                  rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
                  aggregationDepth=2, family="auto", \
                  lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \
                  lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None)
        Sets params for logistic regression.
        If the threshold and thresholds Params are both set, they must be equivalent.
        """
        kwargs = self._input_kwargs
        self._set(**kwargs)
        self._checkThresholdConsistency()
        return self

    def _create_model(self, java_model):
        return LogisticRegressionModel(java_model)

    @since("2.1.0")
    def setFamily(self, value):
        """
        Sets the value of :py:attr:`family`.
        """
        return self._set(family=value)

    @since("2.3.0")
    def setLowerBoundsOnCoefficients(self, value):
        """
        Sets the value of :py:attr:`lowerBoundsOnCoefficients`
        """
        return self._set(lowerBoundsOnCoefficients=value)

    @since("2.3.0")
    def setUpperBoundsOnCoefficients(self, value):
        """
        Sets the value of :py:attr:`upperBoundsOnCoefficients`
        """
        return self._set(upperBoundsOnCoefficients=value)

    @since("2.3.0")
    def setLowerBoundsOnIntercepts(self, value):
        """
        Sets the value of :py:attr:`lowerBoundsOnIntercepts`
        """
        return self._set(lowerBoundsOnIntercepts=value)

    @since("2.3.0")
    def setUpperBoundsOnIntercepts(self, value):
        """
        Sets the value of :py:attr:`upperBoundsOnIntercepts`
        """
        return self._set(upperBoundsOnIntercepts=value)

    def setMaxIter(self, value):
        """
        Sets the value of :py:attr:`maxIter`.
        """
        return self._set(maxIter=value)

    def setRegParam(self, value):
        """
        Sets the value of :py:attr:`regParam`.
        """
        return self._set(regParam=value)

    def setTol(self, value):
        """
        Sets the value of :py:attr:`tol`.
        """
        return self._set(tol=value)

    def setElasticNetParam(self, value):
        """
        Sets the value of :py:attr:`elasticNetParam`.
        """
        return self._set(elasticNetParam=value)

    def setFitIntercept(self, value):
        """
        Sets the value of :py:attr:`fitIntercept`.
        """
        return self._set(fitIntercept=value)

    def setStandardization(self, value):
        """
        Sets the value of :py:attr:`standardization`.
        """
        return self._set(standardization=value)

    def setWeightCol(self, value):
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)

    def setAggregationDepth(self, value):
        """
        Sets the value of :py:attr:`aggregationDepth`.
        """
        return self._set(aggregationDepth=value)


class LogisticRegressionModel(JavaProbabilisticClassificationModel, _LogisticRegressionParams,
                              JavaMLWritable, JavaMLReadable, HasTrainingSummary):
    """
    Model fitted by LogisticRegression.

    .. versionadded:: 1.3.0
    """

    @property
    @since("2.0.0")
    def coefficients(self):
        """
        Model coefficients of binomial logistic regression.
        An exception is thrown in the case of multinomial logistic regression.
        """
        return self._call_java("coefficients")

    @property
    @since("1.4.0")
    def intercept(self):
        """
        Model intercept of binomial logistic regression.
        An exception is thrown in the case of multinomial logistic regression.
        """
        return self._call_java("intercept")

    @property
    @since("2.1.0")
    def coefficientMatrix(self):
        """
        Model coefficients.
        """
        return self._call_java("coefficientMatrix")

    @property
    @since("2.1.0")
    def interceptVector(self):
        """
        Model intercept.
        """
        return self._call_java("interceptVector")

    @property
    @since("2.0.0")
    def summary(self):
        """
        Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model
        trained on the training set. An exception is thrown if `trainingSummary is None`.
        """
        if self.hasSummary:
            if self.numClasses <= 2:
                return BinaryLogisticRegressionTrainingSummary(super(LogisticRegressionModel,
                                                                     self).summary)
            else:
                return LogisticRegressionTrainingSummary(super(LogisticRegressionModel,
                                                               self).summary)
        else:
            raise RuntimeError("No training summary available for this %s" %
                               self.__class__.__name__)

    @since("2.0.0")
    def evaluate(self, dataset):
        """
        Evaluates the model on a test dataset.

        :param dataset:
          Test dataset to evaluate model on, where dataset is an
          instance of :py:class:`pyspark.sql.DataFrame`
        """
        if not isinstance(dataset, DataFrame):
            raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
        java_blr_summary = self._call_java("evaluate", dataset)
        if self.numClasses <= 2:
            return BinaryLogisticRegressionSummary(java_blr_summary)
        else:
            return LogisticRegressionSummary(java_blr_summary)


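# Hedged sketch: evaluating a fitted binary model on held-out data via the
# `evaluate` method above. The returned BinaryLogisticRegressionSummary
# exposes the metrics defined further below; argument names are illustrative.
def _example_evaluate_binary(blor_model, test_df):
    # test_df is a pyspark.sql.DataFrame with label and features columns
    summary = blor_model.evaluate(test_df)
    return summary.accuracy, summary.areaUnderROC

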
class LogisticRegressionSummary(JavaWrapper):
    """
    Abstraction for Logistic Regression Results for a given model.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def predictions(self):
        """
        DataFrame output by the model's `transform` method.
0852         """
0853         return self._call_java("predictions")
0854 
0855     @property
0856     @since("2.0.0")
0857     def probabilityCol(self):
0858         """
0859         Field in "predictions" which gives the probability
0860         of each class as a vector.
0861         """
0862         return self._call_java("probabilityCol")
0863 
0864     @property
0865     @since("2.3.0")
0866     def predictionCol(self):
0867         """
0868         Field in "predictions" which gives the prediction of each class.
0869         """
0870         return self._call_java("predictionCol")
0871 
0872     @property
0873     @since("2.0.0")
0874     def labelCol(self):
0875         """
0876         Field in "predictions" which gives the true label of each
0877         instance.
0878         """
0879         return self._call_java("labelCol")
0880 
0881     @property
0882     @since("2.0.0")
0883     def featuresCol(self):
0884         """
0885         Field in "predictions" which gives the features of each instance
0886         as a vector.
0887         """
0888         return self._call_java("featuresCol")
0889 
0890     @property
0891     @since("2.3.0")
0892     def labels(self):
0893         """
0894         Returns the sequence of labels in ascending order. This order matches the order used
0895         in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.
0896 
        Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the
        training set is missing a label, then all of the arrays over labels
        (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the
        expected numClasses.
0901         """
0902         return self._call_java("labels")
0903 
0904     @property
0905     @since("2.3.0")
0906     def truePositiveRateByLabel(self):
0907         """
0908         Returns true positive rate for each label (category).
0909         """
0910         return self._call_java("truePositiveRateByLabel")
0911 
0912     @property
0913     @since("2.3.0")
0914     def falsePositiveRateByLabel(self):
0915         """
0916         Returns false positive rate for each label (category).
0917         """
0918         return self._call_java("falsePositiveRateByLabel")
0919 
0920     @property
0921     @since("2.3.0")
0922     def precisionByLabel(self):
0923         """
0924         Returns precision for each label (category).
0925         """
0926         return self._call_java("precisionByLabel")
0927 
0928     @property
0929     @since("2.3.0")
0930     def recallByLabel(self):
0931         """
0932         Returns recall for each label (category).
0933         """
0934         return self._call_java("recallByLabel")
0935 
0936     @since("2.3.0")
0937     def fMeasureByLabel(self, beta=1.0):
0938         """
0939         Returns f-measure for each label (category).
0940         """
0941         return self._call_java("fMeasureByLabel", beta)
0942 
0943     @property
0944     @since("2.3.0")
0945     def accuracy(self):
0946         """
        Returns accuracy: the fraction of correctly classified instances
        out of the total number of instances.
0950         """
0951         return self._call_java("accuracy")
0952 
0953     @property
0954     @since("2.3.0")
0955     def weightedTruePositiveRate(self):
0956         """
0957         Returns weighted true positive rate.
        (This is the same as weighted recall.)
0959         """
0960         return self._call_java("weightedTruePositiveRate")
0961 
0962     @property
0963     @since("2.3.0")
0964     def weightedFalsePositiveRate(self):
0965         """
0966         Returns weighted false positive rate.
0967         """
0968         return self._call_java("weightedFalsePositiveRate")
0969 
0970     @property
0971     @since("2.3.0")
0972     def weightedRecall(self):
0973         """
0974         Returns weighted averaged recall.
        (This is the same as weighted true positive rate.)
0976         """
0977         return self._call_java("weightedRecall")
0978 
0979     @property
0980     @since("2.3.0")
0981     def weightedPrecision(self):
0982         """
0983         Returns weighted averaged precision.
0984         """
0985         return self._call_java("weightedPrecision")
0986 
0987     @since("2.3.0")
0988     def weightedFMeasure(self, beta=1.0):
0989         """
0990         Returns weighted averaged f-measure.
0991         """
0992         return self._call_java("weightedFMeasure", beta)
0993 
0994 
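# Hedged sketch of the F-measure that fMeasureByLabel and weightedFMeasure
# above parameterize with beta; the helper is illustrative, not part of the
# Spark API.
def _example_f_measure(precision, recall, beta=1.0):
    # F_beta = (1 + beta^2) * P * R / (beta^2 * P + R); beta = 1.0 gives F1
    b2 = beta * beta
    return (1.0 + b2) * precision * recall / (b2 * precision + recall)

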
@inherit_doc
class LogisticRegressionTrainingSummary(LogisticRegressionSummary):
    """
    Abstraction for multinomial Logistic Regression Training results.
    Currently, the training summary ignores the training weights except
    for the objective trace.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def objectiveHistory(self):
        """
        Objective function (scaled loss + regularization) at each
        iteration.
        """
        return self._call_java("objectiveHistory")

    @property
    @since("2.0.0")
    def totalIterations(self):
        """
        Number of training iterations until termination.
        """
        return self._call_java("totalIterations")


@inherit_doc
class BinaryLogisticRegressionSummary(LogisticRegressionSummary):
    """
    Binary Logistic regression results for a given model.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def roc(self):
        """
        Returns the receiver operating characteristic (ROC) curve,
        which is a DataFrame having two fields (FPR, TPR) with
        (0.0, 0.0) prepended and (1.0, 1.0) appended to it.

        .. seealso:: `Wikipedia reference
            <http://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("roc")

    @property
    @since("2.0.0")
    def areaUnderROC(self):
        """
        Computes the area under the receiver operating characteristic
        (ROC) curve.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("areaUnderROC")

    @property
    @since("2.0.0")
    def pr(self):
        """
        Returns the precision-recall curve, which is a DataFrame
        containing two fields (recall, precision) with (0.0, 1.0) prepended
        to it.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("pr")

    @property
    @since("2.0.0")
    def fMeasureByThreshold(self):
        """
        Returns a DataFrame with two fields (threshold, F-Measure): the
        F-Measure curve with beta = 1.0.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("fMeasureByThreshold")

    @property
    @since("2.0.0")
    def precisionByThreshold(self):
        """
        Returns a DataFrame with two fields (threshold, precision): the
        precision curve. Every possible probability obtained in transforming
        the dataset is used as a threshold in calculating the precision.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("precisionByThreshold")

    @property
    @since("2.0.0")
    def recallByThreshold(self):
        """
        Returns a DataFrame with two fields (threshold, recall): the recall
        curve. Every possible probability obtained in transforming the
        dataset is used as a threshold in calculating the recall.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("recallByThreshold")


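# Hedged sketch tying the threshold curves above together: pick the decision
# threshold that maximizes F-Measure on a binary summary. The column names
# (threshold, F-Measure) follow the fMeasureByThreshold doc above; the
# helper name is illustrative.
def _example_best_threshold(binary_summary):
    best = (binary_summary.fMeasureByThreshold
            .orderBy("F-Measure", ascending=False)
            .first())
    return best["threshold"]

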
@inherit_doc
class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary,
                                              LogisticRegressionTrainingSummary):
    """
    Binary Logistic regression training results for a given model.

    .. versionadded:: 2.0.0
    """
    pass


@inherit_doc
class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams):
    """
    Params for :py:class:`DecisionTreeClassifier` and :py:class:`DecisionTreeClassificationModel`.
    """
    pass


@inherit_doc
class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifierParams,
                             JavaMLWritable, JavaMLReadable):
    """
    `Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
    learning algorithm for classification.
    It supports both binary and multiclass labels, as well as both continuous and categorical
    features.

    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.feature import StringIndexer
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
    >>> si_model = stringIndexer.fit(df)
    >>> td = si_model.transform(df)
    >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed", leafCol="leafId")
    >>> model = dt.fit(td)
    >>> model.getLabelCol()
    'indexed'
    >>> model.setFeaturesCol("features")
    DecisionTreeClassificationModel...
    >>> model.numNodes
    3
    >>> model.depth
    1
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> model.numFeatures
    1
    >>> model.numClasses
    2
    >>> print(model.toDebugString)
    DecisionTreeClassificationModel...depth=1, numNodes=3...
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.predict(test0.head().features)
    0.0
    >>> model.predictRaw(test0.head().features)
    DenseVector([1.0, 0.0])
    >>> model.predictProbability(test0.head().features)
    DenseVector([1.0, 0.0])
    >>> result = model.transform(test0).head()
    >>> result.prediction
    0.0
    >>> result.probability
    DenseVector([1.0, 0.0])
    >>> result.rawPrediction
    DenseVector([1.0, 0.0])
    >>> result.leafId
    0.0
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0
    >>> dtc_path = temp_path + "/dtc"
    >>> dt.save(dtc_path)
    >>> dt2 = DecisionTreeClassifier.load(dtc_path)
    >>> dt2.getMaxDepth()
    2
    >>> model_path = temp_path + "/dtc_model"
    >>> model.save(model_path)
    >>> model2 = DecisionTreeClassificationModel.load(model_path)
    >>> model.featureImportances == model2.featureImportances
    True

    >>> df3 = spark.createDataFrame([
    ...     (1.0, 0.2, Vectors.dense(1.0)),
    ...     (1.0, 0.8, Vectors.dense(1.0)),
    ...     (0.0, 1.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
    >>> si3 = StringIndexer(inputCol="label", outputCol="indexed")
    >>> si_model3 = si3.fit(df3)
    >>> td3 = si_model3.transform(df3)
    >>> dt3 = DecisionTreeClassifier(maxDepth=2, weightCol="weight", labelCol="indexed")
    >>> model3 = dt3.fit(td3)
    >>> print(model3.toDebugString)
    DecisionTreeClassificationModel...depth=1, numNodes=3...

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 probabilityCol="probability", rawPredictionCol="rawPrediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
                 seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 probabilityCol="probability", rawPredictionCol="rawPrediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
                 seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
        """
        super(DecisionTreeClassifier, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="gini", leafCol="", minWeightFractionPerNode=0.0)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  probabilityCol="probability", rawPredictionCol="rawPrediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  impurity="gini", seed=None, weightCol=None, leafCol="",
                  minWeightFractionPerNode=0.0):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
                  seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
        Sets params for the DecisionTreeClassifier.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return DecisionTreeClassificationModel(java_model)

    def setMaxDepth(self, value):
        """
        Sets the value of :py:attr:`maxDepth`.
        """
        return self._set(maxDepth=value)

    def setMaxBins(self, value):
        """
        Sets the value of :py:attr:`maxBins`.
        """
        return self._set(maxBins=value)

    def setMinInstancesPerNode(self, value):
        """
        Sets the value of :py:attr:`minInstancesPerNode`.
        """
        return self._set(minInstancesPerNode=value)

    @since("3.0.0")
    def setMinWeightFractionPerNode(self, value):
        """
        Sets the value of :py:attr:`minWeightFractionPerNode`.
        """
        return self._set(minWeightFractionPerNode=value)

    def setMinInfoGain(self, value):
        """
        Sets the value of :py:attr:`minInfoGain`.
        """
        return self._set(minInfoGain=value)

    def setMaxMemoryInMB(self, value):
        """
        Sets the value of :py:attr:`maxMemoryInMB`.
        """
        return self._set(maxMemoryInMB=value)

    def setCacheNodeIds(self, value):
        """
        Sets the value of :py:attr:`cacheNodeIds`.
        """
        return self._set(cacheNodeIds=value)

    @since("1.4.0")
    def setImpurity(self, value):
        """
        Sets the value of :py:attr:`impurity`.
        """
        return self._set(impurity=value)

    @since("1.4.0")
    def setCheckpointInterval(self, value):
        """
        Sets the value of :py:attr:`checkpointInterval`.
        """
        return self._set(checkpointInterval=value)

    def setSeed(self, value):
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)

    @since("3.0.0")
    def setWeightCol(self, value):
        """
        Sets the value of :py:attr:`weightCol`.
        """
        return self._set(weightCol=value)


@inherit_doc
class DecisionTreeClassificationModel(_DecisionTreeModel, JavaProbabilisticClassificationModel,
                                      _DecisionTreeClassifierParams, JavaMLWritable,
                                      JavaMLReadable):
    """
    Model fitted by DecisionTreeClassifier.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        This generalizes the idea of "Gini" importance to other losses,
        following the explanation of Gini importance from "Random Forests" documentation
        by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.

        This feature importance is calculated as follows:
          - importance(feature j) = sum (over nodes which split on feature j) of the gain,
            where gain is scaled by the number of instances passing through node
          - Normalize importances for tree to sum to 1.

        .. note:: Feature importance for single decision trees can have high variance due to
            correlated predictor variables. Consider using a :py:class:`RandomForestClassifier`
            to determine feature importance instead.
        """
        return self._call_java("featureImportances")


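# Hedged sketch of the two-step importance calculation described above: sum
# instance-scaled gains per split feature, then normalize so the importances
# sum to 1. The helper and its input format (an iterable of
# (feature, gain, n_instances) triples) are illustrative only.
def _example_tree_importances(splits):
    totals = {}
    for feature_j, gain, n_instances in splits:
        totals[feature_j] = totals.get(feature_j, 0.0) + gain * n_instances
    norm = sum(totals.values())
    return {j: v / norm for j, v in totals.items()}

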
@inherit_doc
class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams):
    """
    Params for :py:class:`RandomForestClassifier` and :py:class:`RandomForestClassificationModel`.
    """
    pass


@inherit_doc
class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifierParams,
                             JavaMLWritable, JavaMLReadable):
    """
    `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
    learning algorithm for classification.
    It supports both binary and multiclass labels, as well as both continuous and categorical
    features.

    >>> import numpy
    >>> from numpy import allclose
    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.feature import StringIndexer
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
    >>> si_model = stringIndexer.fit(df)
    >>> td = si_model.transform(df)
    >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42,
    ...     leafCol="leafId")
    >>> rf.getMinWeightFractionPerNode()
    0.0
    >>> model = rf.fit(td)
    >>> model.getLabelCol()
    'indexed'
    >>> model.setFeaturesCol("features")
    RandomForestClassificationModel...
    >>> model.setRawPredictionCol("newRawPrediction")
    RandomForestClassificationModel...
    >>> model.getBootstrap()
    True
    >>> model.getRawPredictionCol()
    'newRawPrediction'
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
    True
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.predict(test0.head().features)
    0.0
    >>> model.predictRaw(test0.head().features)
    DenseVector([2.0, 0.0])
    >>> model.predictProbability(test0.head().features)
    DenseVector([1.0, 0.0])
    >>> result = model.transform(test0).head()
    >>> result.prediction
    0.0
    >>> numpy.argmax(result.probability)
    0
    >>> numpy.argmax(result.newRawPrediction)
    0
    >>> result.leafId
    DenseVector([0.0, 0.0, 0.0])
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0
    >>> model.trees
    [DecisionTreeClassificationModel...depth=..., DecisionTreeClassificationModel...]
    >>> rfc_path = temp_path + "/rfc"
    >>> rf.save(rfc_path)
    >>> rf2 = RandomForestClassifier.load(rfc_path)
    >>> rf2.getNumTrees()
    3
    >>> model_path = temp_path + "/rfc_model"
    >>> model.save(model_path)
    >>> model2 = RandomForestClassificationModel.load(model_path)
    >>> model.featureImportances == model2.featureImportances
    True

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 probabilityCol="probability", rawPredictionCol="rawPrediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
                 numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0,
                 leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 probabilityCol="probability", rawPredictionCol="rawPrediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
                 numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \
                 leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)
        """
        super(RandomForestClassifier, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.RandomForestClassifier", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="gini", numTrees=20, featureSubsetStrategy="auto",
                         subsamplingRate=1.0, leafCol="", minWeightFractionPerNode=0.0,
                         bootstrap=True)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  probabilityCol="probability", rawPredictionCol="rawPrediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
1475                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
1476                   impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0,
1477                   leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True):
1478         """
1479         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
1480                   probabilityCol="probability", rawPredictionCol="rawPrediction", \
1481                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
1482                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
1483                   impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, \
1484                   leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)
1485         Sets params for random forest classification.
1486         """
1487         kwargs = self._input_kwargs
1488         return self._set(**kwargs)
1489 
1490     def _create_model(self, java_model):
1491         return RandomForestClassificationModel(java_model)
1492 
1493     def setMaxDepth(self, value):
1494         """
1495         Sets the value of :py:attr:`maxDepth`.
1496         """
1497         return self._set(maxDepth=value)
1498 
1499     def setMaxBins(self, value):
1500         """
1501         Sets the value of :py:attr:`maxBins`.
1502         """
1503         return self._set(maxBins=value)
1504 
1505     def setMinInstancesPerNode(self, value):
1506         """
1507         Sets the value of :py:attr:`minInstancesPerNode`.
1508         """
1509         return self._set(minInstancesPerNode=value)
1510 
1511     def setMinInfoGain(self, value):
1512         """
1513         Sets the value of :py:attr:`minInfoGain`.
1514         """
1515         return self._set(minInfoGain=value)
1516 
1517     def setMaxMemoryInMB(self, value):
1518         """
1519         Sets the value of :py:attr:`maxMemoryInMB`.
1520         """
1521         return self._set(maxMemoryInMB=value)
1522 
1523     def setCacheNodeIds(self, value):
1524         """
1525         Sets the value of :py:attr:`cacheNodeIds`.
1526         """
1527         return self._set(cacheNodeIds=value)
1528 
1529     @since("1.4.0")
1530     def setImpurity(self, value):
1531         """
1532         Sets the value of :py:attr:`impurity`.
1533         """
1534         return self._set(impurity=value)
1535 
1536     @since("1.4.0")
1537     def setNumTrees(self, value):
1538         """
1539         Sets the value of :py:attr:`numTrees`.
1540         """
1541         return self._set(numTrees=value)
1542 
1543     @since("3.0.0")
1544     def setBootstrap(self, value):
1545         """
1546         Sets the value of :py:attr:`bootstrap`.
1547         """
1548         return self._set(bootstrap=value)
1549 
1550     @since("1.4.0")
1551     def setSubsamplingRate(self, value):
1552         """
1553         Sets the value of :py:attr:`subsamplingRate`.
1554         """
1555         return self._set(subsamplingRate=value)
1556 
1557     @since("2.4.0")
1558     def setFeatureSubsetStrategy(self, value):
1559         """
1560         Sets the value of :py:attr:`featureSubsetStrategy`.
1561         """
1562         return self._set(featureSubsetStrategy=value)
1563 
1564     def setSeed(self, value):
1565         """
1566         Sets the value of :py:attr:`seed`.
1567         """
1568         return self._set(seed=value)
1569 
1570     def setCheckpointInterval(self, value):
1571         """
1572         Sets the value of :py:attr:`checkpointInterval`.
1573         """
1574         return self._set(checkpointInterval=value)
1575 
1576     @since("3.0.0")
1577     def setWeightCol(self, value):
1578         """
1579         Sets the value of :py:attr:`weightCol`.
1580         """
1581         return self._set(weightCol=value)
1582 
1583     @since("3.0.0")
1584     def setMinWeightFractionPerNode(self, value):
1585         """
1586         Sets the value of :py:attr:`minWeightFractionPerNode`.
1587         """
1588         return self._set(minWeightFractionPerNode=value)
1589 
1590 
1591 class RandomForestClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
1592                                       _RandomForestClassifierParams, JavaMLWritable,
1593                                       JavaMLReadable):
1594     """
1595     Model fitted by RandomForestClassifier.
1596 
1597     .. versionadded:: 1.4.0
1598     """
1599 
1600     @property
1601     @since("2.0.0")
1602     def featureImportances(self):
1603         """
1604         Estimate of the importance of each feature.
1605 
1606         Each feature's importance is the average of its importance across all trees in the ensemble.
1607         The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
1608         (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
1609         and follows the implementation from scikit-learn.
1610 
1611         .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances`
1612         """
1613         return self._call_java("featureImportances")
1614 
1615     @property
1616     @since("2.0.0")
1617     def trees(self):
1618         """Trees in this ensemble. Warning: These have null parent Estimators."""
1619         return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
1620 
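# A minimal sketch (illustrative only) of the averaging described in
# RandomForestClassificationModel.featureImportances: the ensemble
# importance is the per-feature average of the per-tree importance
# vectors, renormalized to sum to 1. `per_tree` is a hypothetical list of
# equal-length importance lists, one per tree.
def _sketch_forest_importances(per_tree):
    num_features = len(per_tree[0])
    avg = [sum(tree[j] for tree in per_tree) / len(per_tree)
           for j in range(num_features)]
    total = sum(avg)
    return [v / total if total else 0.0 for v in avg]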
1621 
1622 class _GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
1623     """
1624     Params for :py:class:`GBTClassifier` and :py:class:`GBTClassifierModel`.
1625 
1626     .. versionadded:: 3.0.0
1627     """
1628 
1629     supportedLossTypes = ["logistic"]
1630 
1631     lossType = Param(Params._dummy(), "lossType",
1632                      "Loss function which GBT tries to minimize (case-insensitive). " +
1633                      "Supported options: " + ", ".join(supportedLossTypes),
1634                      typeConverter=TypeConverters.toString)
1635 
1636     @since("1.4.0")
1637     def getLossType(self):
1638         """
1639         Gets the value of lossType or its default value.
1640         """
1641         return self.getOrDefault(self.lossType)
1642 
1643 
1644 @inherit_doc
1645 class GBTClassifier(JavaProbabilisticClassifier, _GBTClassifierParams,
1646                     JavaMLWritable, JavaMLReadable):
1647     """
1648     `Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
1649     learning algorithm for classification.
1650     It supports binary labels, as well as both continuous and categorical features.
1651 
1652     The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
1653 
1654     Notes on Gradient Boosting vs. TreeBoost:
1655     - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
1656     - Both algorithms learn tree ensembles by minimizing loss functions.
1657     - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
1658       based on the loss function, whereas the original gradient boosting method does not.
1659     - We expect to implement TreeBoost in the future:
1660       `SPARK-4240 <https://issues.apache.org/jira/browse/SPARK-4240>`_
1661 
1662     .. note:: Multiclass labels are not currently supported.
1663 
1664     >>> from numpy import allclose
1665     >>> from pyspark.ml.linalg import Vectors
1666     >>> from pyspark.ml.feature import StringIndexer
1667     >>> df = spark.createDataFrame([
1668     ...     (1.0, Vectors.dense(1.0)),
1669     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
1670     >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
1671     >>> si_model = stringIndexer.fit(df)
1672     >>> td = si_model.transform(df)
1673     >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42,
1674     ...     leafCol="leafId")
1675     >>> gbt.setMaxIter(5)
1676     GBTClassifier...
1677     >>> gbt.setMinWeightFractionPerNode(0.049)
1678     GBTClassifier...
1679     >>> gbt.getMaxIter()
1680     5
1681     >>> gbt.getFeatureSubsetStrategy()
1682     'all'
1683     >>> model = gbt.fit(td)
1684     >>> model.getLabelCol()
1685     'indexed'
1686     >>> model.setFeaturesCol("features")
1687     GBTClassificationModel...
1688     >>> model.setThresholds([0.3, 0.7])
1689     GBTClassificationModel...
1690     >>> model.getThresholds()
1691     [0.3, 0.7]
1692     >>> model.featureImportances
1693     SparseVector(1, {0: 1.0})
1694     >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
1695     True
1696     >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
1697     >>> model.predict(test0.head().features)
1698     0.0
1699     >>> model.predictRaw(test0.head().features)
1700     DenseVector([1.1697, -1.1697])
1701     >>> model.predictProbability(test0.head().features)
1702     DenseVector([0.9121, 0.0879])
1703     >>> result = model.transform(test0).head()
1704     >>> result.prediction
1705     0.0
1706     >>> result.leafId
1707     DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
1708     >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
1709     >>> model.transform(test1).head().prediction
1710     1.0
1711     >>> model.totalNumNodes
1712     15
1713     >>> print(model.toDebugString)
1714     GBTClassificationModel...numTrees=5...
1715     >>> gbtc_path = temp_path + "/gbtc"
1716     >>> gbt.save(gbtc_path)
1717     >>> gbt2 = GBTClassifier.load(gbtc_path)
1718     >>> gbt2.getMaxDepth()
1719     2
1720     >>> model_path = temp_path + "/gbtc_model"
1721     >>> model.save(model_path)
1722     >>> model2 = GBTClassificationModel.load(model_path)
1723     >>> model.featureImportances == model2.featureImportances
1724     True
1725     >>> model.treeWeights == model2.treeWeights
1726     True
1727     >>> model.trees
1728     [DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
1729     >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
1730     ...              ["indexed", "features"])
1731     >>> model.evaluateEachIteration(validation)
1732     [0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
1733     >>> model.numClasses
1734     2
1735     >>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
1736     >>> gbt.getValidationIndicatorCol()
1737     'validationIndicator'
1738     >>> gbt.getValidationTol()
1739     0.01
1740 
1741     .. versionadded:: 1.4.0
1742     """
1743 
1744     @keyword_only
1745     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
1746                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
1747                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
1748                  maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance",
1749                  featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None,
1750                  leafCol="", minWeightFractionPerNode=0.0, weightCol=None):
1751         """
1752         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
1753                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
1754                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
1755                  lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
1756                  impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
1757                  validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \
1758                  weightCol=None)
1759         """
1760         super(GBTClassifier, self).__init__()
1761         self._java_obj = self._new_java_obj(
1762             "org.apache.spark.ml.classification.GBTClassifier", self.uid)
1763         self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
1764                          maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
1765                          lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0,
1766                          impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
1767                          leafCol="", minWeightFractionPerNode=0.0)
1768         kwargs = self._input_kwargs
1769         self.setParams(**kwargs)
1770 
1771     @keyword_only
1772     @since("1.4.0")
1773     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
1774                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
1775                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
1776                   lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
1777                   impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
1778                   validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0,
1779                   weightCol=None):
1780         """
1781         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
1782                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
1783                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
1784                   lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
1785                   impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
1786                   validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \
1787                   weightCol=None)
1788         Sets params for Gradient Boosted Tree Classification.
1789         """
1790         kwargs = self._input_kwargs
1791         return self._set(**kwargs)
1792 
1793     def _create_model(self, java_model):
1794         return GBTClassificationModel(java_model)
1795 
1796     def setMaxDepth(self, value):
1797         """
1798         Sets the value of :py:attr:`maxDepth`.
1799         """
1800         return self._set(maxDepth=value)
1801 
1802     def setMaxBins(self, value):
1803         """
1804         Sets the value of :py:attr:`maxBins`.
1805         """
1806         return self._set(maxBins=value)
1807 
1808     def setMinInstancesPerNode(self, value):
1809         """
1810         Sets the value of :py:attr:`minInstancesPerNode`.
1811         """
1812         return self._set(minInstancesPerNode=value)
1813 
1814     def setMinInfoGain(self, value):
1815         """
1816         Sets the value of :py:attr:`minInfoGain`.
1817         """
1818         return self._set(minInfoGain=value)
1819 
1820     def setMaxMemoryInMB(self, value):
1821         """
1822         Sets the value of :py:attr:`maxMemoryInMB`.
1823         """
1824         return self._set(maxMemoryInMB=value)
1825 
1826     def setCacheNodeIds(self, value):
1827         """
1828         Sets the value of :py:attr:`cacheNodeIds`.
1829         """
1830         return self._set(cacheNodeIds=value)
1831 
1832     @since("1.4.0")
1833     def setImpurity(self, value):
1834         """
1835         Sets the value of :py:attr:`impurity`.
1836         """
1837         return self._set(impurity=value)
1838 
1839     @since("1.4.0")
1840     def setLossType(self, value):
1841         """
1842         Sets the value of :py:attr:`lossType`.
1843         """
1844         return self._set(lossType=value)
1845 
1846     @since("1.4.0")
1847     def setSubsamplingRate(self, value):
1848         """
1849         Sets the value of :py:attr:`subsamplingRate`.
1850         """
1851         return self._set(subsamplingRate=value)
1852 
1853     @since("2.4.0")
1854     def setFeatureSubsetStrategy(self, value):
1855         """
1856         Sets the value of :py:attr:`featureSubsetStrategy`.
1857         """
1858         return self._set(featureSubsetStrategy=value)
1859 
1860     @since("3.0.0")
1861     def setValidationIndicatorCol(self, value):
1862         """
1863         Sets the value of :py:attr:`validationIndicatorCol`.
1864         """
1865         return self._set(validationIndicatorCol=value)
1866 
1867     @since("1.4.0")
1868     def setMaxIter(self, value):
1869         """
1870         Sets the value of :py:attr:`maxIter`.
1871         """
1872         return self._set(maxIter=value)
1873 
1874     @since("1.4.0")
1875     def setCheckpointInterval(self, value):
1876         """
1877         Sets the value of :py:attr:`checkpointInterval`.
1878         """
1879         return self._set(checkpointInterval=value)
1880 
1881     @since("1.4.0")
1882     def setSeed(self, value):
1883         """
1884         Sets the value of :py:attr:`seed`.
1885         """
1886         return self._set(seed=value)
1887 
1888     @since("1.4.0")
1889     def setStepSize(self, value):
1890         """
1891         Sets the value of :py:attr:`stepSize`.
1892         """
1893         return self._set(stepSize=value)
1894 
1895     @since("3.0.0")
1896     def setWeightCol(self, value):
1897         """
1898         Sets the value of :py:attr:`weightCol`.
1899         """
1900         return self._set(weightCol=value)
1901 
1902     @since("3.0.0")
1903     def setMinWeightFractionPerNode(self, value):
1904         """
1905         Sets the value of :py:attr:`minWeightFractionPerNode`.
1906         """
1907         return self._set(minWeightFractionPerNode=value)
1908 
1909 
1910 class GBTClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
1911                              _GBTClassifierParams, JavaMLWritable, JavaMLReadable):
1912     """
1913     Model fitted by GBTClassifier.
1914 
1915     .. versionadded:: 1.4.0
1916     """
1917 
1918     @property
1919     @since("2.0.0")
1920     def featureImportances(self):
1921         """
1922         Estimate of the importance of each feature.
1923 
1924         Each feature's importance is the average of its importance across all trees in the ensemble.
1925         The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
1926         (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
1927         and follows the implementation from scikit-learn.
1928 
1929         .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances`
1930         """
1931         return self._call_java("featureImportances")
1932 
1933     @property
1934     @since("2.0.0")
1935     def trees(self):
1936         """Trees in this ensemble. Warning: These have null parent Estimators."""
1937         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
1938 
1939     @since("2.4.0")
1940     def evaluateEachIteration(self, dataset):
1941         """
1942         Method to compute error or loss for every iteration of gradient boosting.
1943 
1944         :param dataset:
1945             Test dataset to evaluate model on, where dataset is an
1946             instance of :py:class:`pyspark.sql.DataFrame`
1947         """
1948         return self._call_java("evaluateEachIteration", dataset)
1949 
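# A small usage sketch (illustrative only): one common use of
# GBTClassificationModel.evaluateEachIteration is choosing an effective
# number of boosting iterations after training, namely the (1-based)
# index of the smallest validation loss. The loss list below mirrors the
# doctest output above; the helper name is hypothetical.
def _sketch_best_num_iterations(losses):
    return min(range(len(losses)), key=losses.__getitem__) + 1

# _sketch_best_num_iterations([0.25, 0.23, 0.21, 0.19, 0.18]) returns 5,
# i.e. every one of the five trees helps on that validation set.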
1950 
1951 class _NaiveBayesParams(_JavaPredictorParams, HasWeightCol):
1952     """
1953     Params for :py:class:`NaiveBayes` and :py:class:`NaiveBayesModel`.
1954 
1955     .. versionadded:: 3.0.0
1956     """
1957 
1958     smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " +
1959                       "default is 1.0", typeConverter=TypeConverters.toFloat)
1960     modelType = Param(Params._dummy(), "modelType", "The model type which is a string " +
1961                       "(case-sensitive). Supported options: multinomial (default), bernoulli, " +
1962                       "gaussian and complement.",
1963                       typeConverter=TypeConverters.toString)
1964 
1965     @since("1.5.0")
1966     def getSmoothing(self):
1967         """
1968         Gets the value of smoothing or its default value.
1969         """
1970         return self.getOrDefault(self.smoothing)
1971 
1972     @since("1.5.0")
1973     def getModelType(self):
1974         """
1975         Gets the value of modelType or its default value.
1976         """
1977         return self.getOrDefault(self.modelType)
1978 
1979 
1980 @inherit_doc
1981 class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds, HasWeightCol,
1982                  JavaMLWritable, JavaMLReadable):
1983     """
1984     Naive Bayes Classifiers.
1985     It supports both Multinomial and Bernoulli NB. `Multinomial NB
1986     <http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html>`_
1987     can handle finitely supported discrete data. For example, by converting documents into
1988     TF-IDF vectors, it can be used for document classification. By making every vector a
1989     binary (0/1) vector, it can also be used as `Bernoulli NB
1990     <http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html>`_.
1991     The input feature values for Multinomial NB and Bernoulli NB must be nonnegative.
1992     Since 3.0.0, it supports Complement NB, which is an adaptation of Multinomial NB.
1993     Specifically, Complement NB uses statistics from the complement of each class to compute
1994     the model's coefficients. The inventors of Complement NB show empirically that the parameter
1995     estimates for CNB are more stable than those for Multinomial NB. Like Multinomial NB, the
1996     input feature values for Complement NB must be nonnegative.
1997     Since 3.0.0, it also supports `Gaussian NB
1998     <https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Gaussian_naive_Bayes>`_,
1999     which can handle continuous data.
2000 
2001     >>> from pyspark.sql import Row
2002     >>> from pyspark.ml.linalg import Vectors
2003     >>> df = spark.createDataFrame([
2004     ...     Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
2005     ...     Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
2006     ...     Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0]))])
2007     >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight")
2008     >>> model = nb.fit(df)
2009     >>> model.setFeaturesCol("features")
2010     NaiveBayesModel...
2011     >>> model.getSmoothing()
2012     1.0
2013     >>> model.pi
2014     DenseVector([-0.81..., -0.58...])
2015     >>> model.theta
2016     DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1)
2017     >>> model.sigma
2018     DenseMatrix(0, 0, [...], ...)
2019     >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
2020     >>> model.predict(test0.head().features)
2021     1.0
2022     >>> model.predictRaw(test0.head().features)
2023     DenseVector([-1.72..., -0.99...])
2024     >>> model.predictProbability(test0.head().features)
2025     DenseVector([0.32..., 0.67...])
2026     >>> result = model.transform(test0).head()
2027     >>> result.prediction
2028     1.0
2029     >>> result.probability
2030     DenseVector([0.32..., 0.67...])
2031     >>> result.rawPrediction
2032     DenseVector([-1.72..., -0.99...])
2033     >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
2034     >>> model.transform(test1).head().prediction
2035     1.0
2036     >>> nb_path = temp_path + "/nb"
2037     >>> nb.save(nb_path)
2038     >>> nb2 = NaiveBayes.load(nb_path)
2039     >>> nb2.getSmoothing()
2040     1.0
2041     >>> model_path = temp_path + "/nb_model"
2042     >>> model.save(model_path)
2043     >>> model2 = NaiveBayesModel.load(model_path)
2044     >>> model.pi == model2.pi
2045     True
2046     >>> model.theta == model2.theta
2047     True
2048     >>> nb = nb.setThresholds([0.01, 10.00])
2049     >>> model3 = nb.fit(df)
2050     >>> result = model3.transform(test0).head()
2051     >>> result.prediction
2052     0.0
2053     >>> nb3 = NaiveBayes().setModelType("gaussian")
2054     >>> model4 = nb3.fit(df)
2055     >>> model4.getModelType()
2056     'gaussian'
2057     >>> model4.sigma
2058     DenseMatrix(2, 2, [0.0, 0.25, 0.0, 0.0], 1)
2059     >>> nb5 = NaiveBayes(smoothing=1.0, modelType="complement", weightCol="weight")
2060     >>> model5 = nb5.fit(df)
2061     >>> model5.getModelType()
2062     'complement'
2063     >>> model5.theta
2064     DenseMatrix(2, 2, [...], 1)
2065     >>> model5.sigma
2066     DenseMatrix(0, 0, [...], ...)
2067 
2068     .. versionadded:: 1.5.0
2069     """
2070 
2071     @keyword_only
2072     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2073                  probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
2074                  modelType="multinomial", thresholds=None, weightCol=None):
2075         """
2076         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2077                  probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
2078                  modelType="multinomial", thresholds=None, weightCol=None)
2079         """
2080         super(NaiveBayes, self).__init__()
2081         self._java_obj = self._new_java_obj(
2082             "org.apache.spark.ml.classification.NaiveBayes", self.uid)
2083         self._setDefault(smoothing=1.0, modelType="multinomial")
2084         kwargs = self._input_kwargs
2085         self.setParams(**kwargs)
2086 
2087     @keyword_only
2088     @since("1.5.0")
2089     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2090                   probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
2091                   modelType="multinomial", thresholds=None, weightCol=None):
2092         """
2093         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2094                   probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
2095                   modelType="multinomial", thresholds=None, weightCol=None)
2096         Sets params for Naive Bayes.
2097         """
2098         kwargs = self._input_kwargs
2099         return self._set(**kwargs)
2100 
2101     def _create_model(self, java_model):
2102         return NaiveBayesModel(java_model)
2103 
2104     @since("1.5.0")
2105     def setSmoothing(self, value):
2106         """
2107         Sets the value of :py:attr:`smoothing`.
2108         """
2109         return self._set(smoothing=value)
2110 
2111     @since("1.5.0")
2112     def setModelType(self, value):
2113         """
2114         Sets the value of :py:attr:`modelType`.
2115         """
2116         return self._set(modelType=value)
2117 
2118     def setWeightCol(self, value):
2119         """
2120         Sets the value of :py:attr:`weightCol`.
2121         """
2122         return self._set(weightCol=value)
2123 
2124 
2125 class NaiveBayesModel(JavaProbabilisticClassificationModel, _NaiveBayesParams, JavaMLWritable,
2126                       JavaMLReadable):
2127     """
2128     Model fitted by NaiveBayes.
2129 
2130     .. versionadded:: 1.5.0
2131     """
2132 
2133     @property
2134     @since("2.0.0")
2135     def pi(self):
2136         """
2137         Log of class priors.
2138         """
2139         return self._call_java("pi")
2140 
2141     @property
2142     @since("2.0.0")
2143     def theta(self):
2144         """
2145         Log of class conditional probabilities.
2146         """
2147         return self._call_java("theta")
2148 
2149     @property
2150     @since("3.0.0")
2151     def sigma(self):
2152         """
2153         Variance of each feature.
2154         """
2155         return self._call_java("sigma")
2156 
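# A minimal sketch (illustrative only) of how pi and theta combine for the
# multinomial model: the raw prediction for a feature vector x is the
# vector of unnormalized class log-posteriors pi + theta * x, which is
# what predictRaw returns in the doctest above. `pi` is assumed to be a
# length-k list and `theta` a k x d list of lists.
def _sketch_multinomial_raw(pi, theta, x):
    return [pi[i] + sum(theta[i][j] * x[j] for j in range(len(x)))
            for i in range(len(pi))]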
2157 
2158 class _MultilayerPerceptronParams(_JavaProbabilisticClassifierParams, HasSeed, HasMaxIter,
2159                                   HasTol, HasStepSize, HasSolver, HasBlockSize):
2160     """
2161     Params for :py:class:`MultilayerPerceptronClassifier`.
2162 
2163     .. versionadded:: 3.0.0
2164     """
2165 
2166     layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer. " +
2167                    "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " +
2168                    "neurons and an output layer of 10 neurons.",
2169                    typeConverter=TypeConverters.toListInt)
2170     solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
2171                    "options: l-bfgs, gd.", typeConverter=TypeConverters.toString)
2172     initialWeights = Param(Params._dummy(), "initialWeights", "The initial weights of the model.",
2173                            typeConverter=TypeConverters.toVector)
2174 
2175     @since("1.6.0")
2176     def getLayers(self):
2177         """
2178         Gets the value of layers or its default value.
2179         """
2180         return self.getOrDefault(self.layers)
2181 
2182     @since("2.0.0")
2183     def getInitialWeights(self):
2184         """
2185         Gets the value of initialWeights or its default value.
2186         """
2187         return self.getOrDefault(self.initialWeights)
2188 
2189 
2190 @inherit_doc
2191 class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPerceptronParams,
2192                                      JavaMLWritable, JavaMLReadable):
2193     """
2194     Classifier trainer based on the Multilayer Perceptron.
2195     Each inner layer uses the sigmoid activation function; the output layer uses softmax.
2196     The number of inputs must equal the size of the feature vectors, and the number of
2197     outputs must equal the total number of labels.
2198 
2199     >>> from pyspark.ml.linalg import Vectors
2200     >>> df = spark.createDataFrame([
2201     ...     (0.0, Vectors.dense([0.0, 0.0])),
2202     ...     (1.0, Vectors.dense([0.0, 1.0])),
2203     ...     (1.0, Vectors.dense([1.0, 0.0])),
2204     ...     (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"])
2205     >>> mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], seed=123)
2206     >>> mlp.setMaxIter(100)
2207     MultilayerPerceptronClassifier...
2208     >>> mlp.getMaxIter()
2209     100
2210     >>> mlp.getBlockSize()
2211     128
2212     >>> mlp.setBlockSize(1)
2213     MultilayerPerceptronClassifier...
2214     >>> mlp.getBlockSize()
2215     1
2216     >>> model = mlp.fit(df)
2217     >>> model.setFeaturesCol("features")
2218     MultilayerPerceptronClassificationModel...
2219     >>> model.getMaxIter()
2220     100
2221     >>> model.getLayers()
2222     [2, 2, 2]
2223     >>> model.weights.size
2224     12
2225     >>> testDF = spark.createDataFrame([
2226     ...     (Vectors.dense([1.0, 0.0]),),
2227     ...     (Vectors.dense([0.0, 0.0]),)], ["features"])
2228     >>> model.predict(testDF.head().features)
2229     1.0
2230     >>> model.predictRaw(testDF.head().features)
2231     DenseVector([-16.208, 16.344])
2232     >>> model.predictProbability(testDF.head().features)
2233     DenseVector([0.0, 1.0])
2234     >>> model.transform(testDF).select("features", "prediction").show()
2235     +---------+----------+
2236     | features|prediction|
2237     +---------+----------+
2238     |[1.0,0.0]|       1.0|
2239     |[0.0,0.0]|       0.0|
2240     +---------+----------+
2241     ...
2242     >>> mlp_path = temp_path + "/mlp"
2243     >>> mlp.save(mlp_path)
2244     >>> mlp2 = MultilayerPerceptronClassifier.load(mlp_path)
2245     >>> mlp2.getBlockSize()
2246     1
2247     >>> model_path = temp_path + "/mlp_model"
2248     >>> model.save(model_path)
2249     >>> model2 = MultilayerPerceptronClassificationModel.load(model_path)
2250     >>> model.getLayers() == model2.getLayers()
2251     True
2252     >>> model.weights == model2.weights
2253     True
2254     >>> mlp2 = mlp2.setInitialWeights(list(range(0, 12)))
2255     >>> model3 = mlp2.fit(df)
2256     >>> model3.weights != model2.weights
2257     True
2258     >>> model3.getLayers() == model.getLayers()
2259     True
2260 
2261     .. versionadded:: 1.6.0
2262     """
2263 
2264     @keyword_only
2265     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2266                  maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03,
2267                  solver="l-bfgs", initialWeights=None, probabilityCol="probability",
2268                  rawPredictionCol="rawPrediction"):
2269         """
2270         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2271                  maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
2272                  solver="l-bfgs", initialWeights=None, probabilityCol="probability", \
2273                  rawPredictionCol="rawPrediction")
2274         """
2275         super(MultilayerPerceptronClassifier, self).__init__()
2276         self._java_obj = self._new_java_obj(
2277             "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
2278         self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
2279         kwargs = self._input_kwargs
2280         self.setParams(**kwargs)
2281 
2282     @keyword_only
2283     @since("1.6.0")
2284     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2285                   maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03,
2286                   solver="l-bfgs", initialWeights=None, probabilityCol="probability",
2287                   rawPredictionCol="rawPrediction"):
2288         """
2289         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2290                   maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
2291                   solver="l-bfgs", initialWeights=None, probabilityCol="probability", \
2292                   rawPredictionCol="rawPrediction")
2293         Sets params for MultilayerPerceptronClassifier.
2294         """
2295         kwargs = self._input_kwargs
2296         return self._set(**kwargs)
2297 
2298     def _create_model(self, java_model):
2299         return MultilayerPerceptronClassificationModel(java_model)
2300 
2301     @since("1.6.0")
2302     def setLayers(self, value):
2303         """
2304         Sets the value of :py:attr:`layers`.
2305         """
2306         return self._set(layers=value)
2307 
2308     @since("1.6.0")
2309     def setBlockSize(self, value):
2310         """
2311         Sets the value of :py:attr:`blockSize`.
2312         """
2313         return self._set(blockSize=value)
2314 
2315     @since("2.0.0")
2316     def setInitialWeights(self, value):
2317         """
2318         Sets the value of :py:attr:`initialWeights`.
2319         """
2320         return self._set(initialWeights=value)
2321 
2322     def setMaxIter(self, value):
2323         """
2324         Sets the value of :py:attr:`maxIter`.
2325         """
2326         return self._set(maxIter=value)
2327 
2328     def setSeed(self, value):
2329         """
2330         Sets the value of :py:attr:`seed`.
2331         """
2332         return self._set(seed=value)
2333 
2334     def setTol(self, value):
2335         """
2336         Sets the value of :py:attr:`tol`.
2337         """
2338         return self._set(tol=value)
2339 
2340     @since("2.0.0")
2341     def setStepSize(self, value):
2342         """
2343         Sets the value of :py:attr:`stepSize`.
2344         """
2345         return self._set(stepSize=value)
2346 
2347     def setSolver(self, value):
2348         """
2349         Sets the value of :py:attr:`solver`.
2350         """
2351         return self._set(solver=value)
2352 
2353 
2354 class MultilayerPerceptronClassificationModel(JavaProbabilisticClassificationModel,
2355                                               _MultilayerPerceptronParams, JavaMLWritable,
2356                                               JavaMLReadable):
2357     """
2358     Model fitted by MultilayerPerceptronClassifier.
2359 
2360     .. versionadded:: 1.6.0
2361     """
2362 
2363     @property
2364     @since("2.0.0")
2365     def weights(self):
2366         """
2367         The weights of the layers.
2368         """
2369         return self._call_java("weights")
2370 
2371 
2372 class _OneVsRestParams(_JavaClassifierParams, HasWeightCol):
2373     """
2374     Params for :py:class:`OneVsRest` and :py:class:`OneVsRestModel`.
2375     """
2376 
2377     classifier = Param(Params._dummy(), "classifier", "base binary classifier")
2378 
2379     @since("2.0.0")
2380     def getClassifier(self):
2381         """
2382         Gets the value of classifier or its default value.
2383         """
2384         return self.getOrDefault(self.classifier)
2385 
2386 
2387 @inherit_doc
2388 class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, JavaMLReadable, JavaMLWritable):
2389     """
2390     Reduction of Multiclass Classification to Binary Classification.
2391     Performs reduction using the one-against-all strategy.
2392     For a multiclass classification with k classes, train k models (one per class).
2393     Each example is scored against all k models, and the model with the highest score
2394     is picked to label the example.
2395 
2396     >>> from pyspark.sql import Row
2397     >>> from pyspark.ml.linalg import Vectors
2398     >>> data_path = "data/mllib/sample_multiclass_classification_data.txt"
2399     >>> df = spark.read.format("libsvm").load(data_path)
2400     >>> lr = LogisticRegression(regParam=0.01)
2401     >>> ovr = OneVsRest(classifier=lr)
2402     >>> ovr.getRawPredictionCol()
2403     'rawPrediction'
2404     >>> ovr.setPredictionCol("newPrediction")
2405     OneVsRest...
2406     >>> model = ovr.fit(df)
2407     >>> model.models[0].coefficients
2408     DenseVector([0.5..., -1.0..., 3.4..., 4.2...])
2409     >>> model.models[1].coefficients
2410     DenseVector([-2.1..., 3.1..., -2.6..., -2.3...])
2411     >>> model.models[2].coefficients
2412     DenseVector([0.3..., -3.4..., 1.0..., -1.1...])
2413     >>> [x.intercept for x in model.models]
2414     [-2.7..., -2.5..., -1.3...]
2415     >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0, 1.0, 1.0))]).toDF()
2416     >>> model.transform(test0).head().newPrediction
2417     0.0
2418     >>> test1 = sc.parallelize([Row(features=Vectors.sparse(4, [0], [1.0]))]).toDF()
2419     >>> model.transform(test1).head().newPrediction
2420     2.0
2421     >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4, 0.3, 0.2))]).toDF()
2422     >>> model.transform(test2).head().newPrediction
2423     0.0
2424     >>> model_path = temp_path + "/ovr_model"
2425     >>> model.save(model_path)
2426     >>> model2 = OneVsRestModel.load(model_path)
2427     >>> model2.transform(test0).head().newPrediction
2428     0.0
2429     >>> model.transform(test2).columns
2430     ['features', 'rawPrediction', 'newPrediction']
2431 
2432     .. versionadded:: 2.0.0
2433     """
2434 
2435     @keyword_only
2436     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2437                  rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
2438         """
2439         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2440                  rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1)
2441         """
2442         super(OneVsRest, self).__init__()
2443         self._setDefault(parallelism=1)
2444         kwargs = self._input_kwargs
2445         self._set(**kwargs)
2446 
2447     @keyword_only
2448     @since("2.0.0")
2449     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2450                   rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
2451         """
2452         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2453                   rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1)
2454         Sets params for OneVsRest.
2455         """
2456         kwargs = self._input_kwargs
2457         return self._set(**kwargs)
2458 
2459     @since("2.0.0")
2460     def setClassifier(self, value):
2461         """
2462         Sets the value of :py:attr:`classifier`.
2463         """
2464         return self._set(classifier=value)
2465 
2466     def setLabelCol(self, value):
2467         """
2468         Sets the value of :py:attr:`labelCol`.
2469         """
2470         return self._set(labelCol=value)
2471 
2472     def setFeaturesCol(self, value):
2473         """
2474         Sets the value of :py:attr:`featuresCol`.
2475         """
2476         return self._set(featuresCol=value)
2477 
2478     def setPredictionCol(self, value):
2479         """
2480         Sets the value of :py:attr:`predictionCol`.
2481         """
2482         return self._set(predictionCol=value)
2483 
2484     def setRawPredictionCol(self, value):
2485         """
2486         Sets the value of :py:attr:`rawPredictionCol`.
2487         """
2488         return self._set(rawPredictionCol=value)
2489 
2490     def setWeightCol(self, value):
2491         """
2492         Sets the value of :py:attr:`weightCol`.
2493         """
2494         return self._set(weightCol=value)
2495 
2496     def setParallelism(self, value):
2497         """
2498         Sets the value of :py:attr:`parallelism`.
2499         """
2500         return self._set(parallelism=value)
2501 
2502     def _fit(self, dataset):
2503         labelCol = self.getLabelCol()
2504         featuresCol = self.getFeaturesCol()
2505         predictionCol = self.getPredictionCol()
2506         classifier = self.getClassifier()
2507 
2508         numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1
2509 
2510         weightCol = None
2511         if (self.isDefined(self.weightCol) and self.getWeightCol()):
2512             if isinstance(classifier, HasWeightCol):
2513                 weightCol = self.getWeightCol()
2514             else:
2515                 warnings.warn("weightCol is ignored, "
2516                               "as it is not supported by {} now.".format(classifier))
2517 
2518         if weightCol:
2519             multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
2520         else:
2521             multiclassLabeled = dataset.select(labelCol, featuresCol)
2522 
2523         # persist if underlying dataset is not persistent.
2524         handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
2525         if handlePersistence:
2526             multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
2527 
2528         def trainSingleClass(index):
2529             binaryLabelCol = "mc2b$" + str(index)
2530             trainingDataset = multiclassLabeled.withColumn(
2531                 binaryLabelCol,
2532                 when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0))
2533             paramMap = dict([(classifier.labelCol, binaryLabelCol),
2534                             (classifier.featuresCol, featuresCol),
2535                             (classifier.predictionCol, predictionCol)])
2536             if weightCol:
2537                 paramMap[classifier.weightCol] = weightCol
2538             return classifier.fit(trainingDataset, paramMap)
2539 
2540         pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
2541 
2542         models = pool.map(trainSingleClass, range(numClasses))
2543 
2544         if handlePersistence:
2545             multiclassLabeled.unpersist()
2546 
2547         return self._copyValues(OneVsRestModel(models=models))
2548 
2549     @since("2.0.0")
2550     def copy(self, extra=None):
2551         """
2552         Creates a copy of this instance with a randomly generated uid
2553         and some extra params. This creates a deep copy of the embedded paramMap,
2554         and copies the embedded and extra parameters over.
2555 
2556         :param extra: Extra parameters to copy to the new instance
2557         :return: Copy of this instance
2558         """
2559         if extra is None:
2560             extra = dict()
2561         newOvr = Params.copy(self, extra)
2562         if self.isSet(self.classifier):
2563             newOvr.setClassifier(self.getClassifier().copy(extra))
2564         return newOvr
2565 
2566     @classmethod
2567     def _from_java(cls, java_stage):
2568         """
2569         Given a Java OneVsRest, create and return a Python wrapper of it.
2570         Used for ML persistence.
2571         """
2572         featuresCol = java_stage.getFeaturesCol()
2573         labelCol = java_stage.getLabelCol()
2574         predictionCol = java_stage.getPredictionCol()
2575         rawPredictionCol = java_stage.getRawPredictionCol()
2576         classifier = JavaParams._from_java(java_stage.getClassifier())
2577         parallelism = java_stage.getParallelism()
2578         py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol,
2579                        rawPredictionCol=rawPredictionCol, classifier=classifier,
2580                        parallelism=parallelism)
2581         if java_stage.isDefined(java_stage.getParam("weightCol")):
2582             py_stage.setWeightCol(java_stage.getWeightCol())
2583         py_stage._resetUid(java_stage.uid())
2584         return py_stage
2585 
2586     def _to_java(self):
2587         """
2588         Transfer this instance to a Java OneVsRest. Used for ML persistence.
2589 
2590         :return: Java object equivalent to this instance.
2591         """
2592         _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
2593                                              self.uid)
2594         _java_obj.setClassifier(self.getClassifier()._to_java())
2595         _java_obj.setParallelism(self.getParallelism())
2596         _java_obj.setFeaturesCol(self.getFeaturesCol())
2597         _java_obj.setLabelCol(self.getLabelCol())
2598         _java_obj.setPredictionCol(self.getPredictionCol())
2599         if (self.isDefined(self.weightCol) and self.getWeightCol()):
2600             _java_obj.setWeightCol(self.getWeightCol())
2601         _java_obj.setRawPredictionCol(self.getRawPredictionCol())
2602         return _java_obj
2603 
2604     def _make_java_param_pair(self, param, value):
2605         """
2606         Makes a Java param pair.
2607         """
2608         sc = SparkContext._active_spark_context
2609         param = self._resolveParam(param)
2610         _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
2611                                              self.uid)
2612         java_param = _java_obj.getParam(param.name)
2613         if isinstance(value, JavaParams):
2614             # used in the case of an estimator having another estimator as a parameter
2615             # the reason why this is not in _py2java in common.py is that importing
2616             # Estimator and Model in common.py results in a circular import with inherit_doc
2617             java_value = value._to_java()
2618         else:
2619             java_value = _py2java(sc, value)
2620         return java_param.w(java_value)
2621 
2622     def _transfer_param_map_to_java(self, pyParamMap):
2623         """
2624         Transforms a Python ParamMap into a Java ParamMap.
2625         """
2626         paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
2627         for param in self.params:
2628             if param in pyParamMap:
2629                 pair = self._make_java_param_pair(param, pyParamMap[param])
2630                 paramMap.put([pair])
2631         return paramMap
2632 
2633     def _transfer_param_map_from_java(self, javaParamMap):
2634         """
2635         Transforms a Java ParamMap into a Python ParamMap.
2636         """
2637         sc = SparkContext._active_spark_context
2638         paramMap = dict()
2639         for pair in javaParamMap.toList():
2640             param = pair.param()
2641             if self.hasParam(str(param.name())):
2642                 if param.name() == "classifier":
2643                     paramMap[self.getParam(param.name())] = JavaParams._from_java(pair.value())
2644                 else:
2645                     paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
2646         return paramMap
2647 
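# A minimal sketch (illustrative only) of the per-class relabeling at the
# heart of OneVsRest._fit: class `index` becomes the positive label and
# every other class the negative label before the base classifier is fit,
# once per class (optionally in parallel).
def _sketch_relabel(labels, index):
    return [1.0 if y == float(index) else 0.0 for y in labels]

# _sketch_relabel([0.0, 1.0, 2.0, 1.0], index=1) returns [0.0, 1.0, 0.0, 1.0].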
2648 
2649 class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
2650     """
2651     Model fitted by OneVsRest.
2652     This stores the models resulting from training k binary classifiers: one for each class.
2653     Each example is scored against all k models, and the model with the highest score
2654     is picked to label the example.
2655 
2656     .. versionadded:: 2.0.0
2657     """
2658 
2659     def setFeaturesCol(self, value):
2660         """
2661         Sets the value of :py:attr:`featuresCol`.
2662         """
2663         return self._set(featuresCol=value)
2664 
2665     def setPredictionCol(self, value):
2666         """
2667         Sets the value of :py:attr:`predictionCol`.
2668         """
2669         return self._set(predictionCol=value)
2670 
2671     def setRawPredictionCol(self, value):
2672         """
2673         Sets the value of :py:attr:`rawPredictionCol`.
2674         """
2675         return self._set(rawPredictionCol=value)
2676 
2677     def __init__(self, models):
2678         super(OneVsRestModel, self).__init__()
2679         self.models = models
2680         java_models = [model._to_java() for model in self.models]
2681         sc = SparkContext._active_spark_context
2682         java_models_array = JavaWrapper._new_java_array(java_models,
2683                                                         sc._gateway.jvm.org.apache.spark.ml
2684                                                         .classification.ClassificationModel)
2685         # TODO: need to set metadata
2686         metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
2687         self._java_obj = \
2688             JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
2689                                      self.uid, metadata.empty(), java_models_array)
2690 
2691     def _transform(self, dataset):
2692         # determine the input columns: these need to be passed through
2693         origCols = dataset.columns
2694 
2695         # add an accumulator column to store predictions of all the models
2696         accColName = "mbc$acc" + str(uuid.uuid4())
2697         initUDF = udf(lambda _: [], ArrayType(DoubleType()))
2698         newDataset = dataset.withColumn(accColName, initUDF(dataset[origCols[0]]))
2699 
2700         # persist if underlying dataset is not persistent.
2701         handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
2702         if handlePersistence:
2703             newDataset.persist(StorageLevel.MEMORY_AND_DISK)
2704 
2705         # update the accumulator column with the result of prediction of models
2706         aggregatedDataset = newDataset
2707         for index, model in enumerate(self.models):
2708             rawPredictionCol = self.getRawPredictionCol()
2709 
2710             columns = origCols + [rawPredictionCol, accColName]
2711 
2712             # add temporary column to store intermediate scores and update
2713             tmpColName = "mbc$tmp" + str(uuid.uuid4())
2714             updateUDF = udf(
2715                 lambda predictions, prediction: predictions + [prediction.tolist()[1]],
2716                 ArrayType(DoubleType()))
2717             transformedDataset = model.transform(aggregatedDataset).select(*columns)
2718             updatedDataset = transformedDataset.withColumn(
2719                 tmpColName,
2720                 updateUDF(transformedDataset[accColName], transformedDataset[rawPredictionCol]))
2721             newColumns = origCols + [tmpColName]
2722 
2723             # switch out the intermediate column with the accumulator column
2724             aggregatedDataset = updatedDataset\
2725                 .select(*newColumns).withColumnRenamed(tmpColName, accColName)
2726 
2727         if handlePersistence:
2728             newDataset.unpersist()
2729 
2730         if self.getRawPredictionCol():
2731             def func(predictions):
2732                 predArray = []
2733                 for x in predictions:
2734                     predArray.append(x)
2735                 return Vectors.dense(predArray)
2736 
2737             rawPredictionUDF = udf(func)
2738             aggregatedDataset = aggregatedDataset.withColumn(
2739                 self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName]))
2740 
2741         if self.getPredictionCol():
2742             # output the index of the classifier with the highest confidence as the prediction
2743             labelUDF = udf(lambda predictions: float(max(enumerate(predictions),
2744                            key=operator.itemgetter(1))[0]), DoubleType())
2745             aggregatedDataset = aggregatedDataset.withColumn(
2746                 self.getPredictionCol(), labelUDF(aggregatedDataset[accColName]))
2747         return aggregatedDataset.drop(accColName)
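    # The per-row reduction above, sketched in plain Python: each binary model
    # appends its positive-class raw score to the accumulator column, and the
    # final prediction is the index of the largest score (values illustrative):
    #
    #   >>> import operator
    #   >>> scores = [0.2, 1.7, -0.4]  # one raw score per binary model
    #   >>> float(max(enumerate(scores), key=operator.itemgetter(1))[0])
    #   1.0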
2748 
2749     @since("2.0.0")
2750     def copy(self, extra=None):
2751         """
2752         Creates a copy of this instance with a randomly generated uid
2753         and some extra params. This creates a deep copy of the embedded paramMap,
2754         and copies the embedded and extra parameters over.
2755 
2756         :param extra: Extra parameters to copy to the new instance
2757         :return: Copy of this instance
2758         """
2759         if extra is None:
2760             extra = dict()
2761         newModel = Params.copy(self, extra)
2762         newModel.models = [model.copy(extra) for model in self.models]
2763         return newModel
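    # Hedged usage sketch: an extra param map overrides params on the copy,
    # and the embedded binary models are copied as well (`ovrModel` is a
    # hypothetical fitted OneVsRestModel):
    #
    #   >>> ovrCopy = ovrModel.copy({ovrModel.predictionCol: "pred"})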
2764 
2765     @classmethod
2766     def _from_java(cls, java_stage):
2767         """
2768         Given a Java OneVsRestModel, create and return a Python wrapper of it.
2769         Used for ML persistence.
2770         """
2771         featuresCol = java_stage.getFeaturesCol()
2772         labelCol = java_stage.getLabelCol()
2773         predictionCol = java_stage.getPredictionCol()
2774         classifier = JavaParams._from_java(java_stage.getClassifier())
2775         models = [JavaParams._from_java(model) for model in java_stage.models()]
2776         py_stage = cls(models=models).setPredictionCol(predictionCol)\
2777             .setFeaturesCol(featuresCol)
2778         py_stage._set(labelCol=labelCol)
2779         if java_stage.isDefined(java_stage.getParam("weightCol")):
2780             py_stage._set(weightCol=java_stage.getWeightCol())
2781         py_stage._set(classifier=classifier)
2782         py_stage._resetUid(java_stage.uid())
2783         return py_stage
2784 
2785     def _to_java(self):
2786         """
2787         Transfer this instance to a Java OneVsRestModel. Used for ML persistence.
2788 
2789         :return: Java object equivalent to this instance.
2790         """
2791         sc = SparkContext._active_spark_context
2792         java_models = [model._to_java() for model in self.models]
2793         java_models_array = JavaWrapper._new_java_array(
2794             java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel)
2795         metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
2796         _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
2797                                              self.uid, metadata.empty(), java_models_array)
2798         _java_obj.set("classifier", self.getClassifier()._to_java())
2799         _java_obj.set("featuresCol", self.getFeaturesCol())
2800         _java_obj.set("labelCol", self.getLabelCol())
2801         _java_obj.set("predictionCol", self.getPredictionCol())
2802         if self.isDefined(self.weightCol) and self.getWeightCol():
2803             _java_obj.set("weightCol", self.getWeightCol())
2804         return _java_obj
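    # _from_java/_to_java above back the standard ML persistence API; a hedged
    # round-trip sketch (the path is hypothetical):
    #
    #   >>> ovrModel.write().overwrite().save(temp_path + "/ovr_model")
    #   >>> loaded = OneVsRestModel.load(temp_path + "/ovr_model")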
2805 
2806 
2807 @inherit_doc
2808 class FMClassifier(JavaProbabilisticClassifier, _FactorizationMachinesParams, JavaMLWritable,
2809                    JavaMLReadable):
2810     """
2811     Factorization Machines learning algorithm for classification.
2812 
2813     Supported options for :py:attr:`solver`:
2814 
2815     * gd (normal mini-batch gradient descent)
2816     * adamW (default)
2817 
2818     >>> from pyspark.ml.linalg import Vectors
2819     >>> from pyspark.ml.classification import FMClassifier
2820     >>> df = spark.createDataFrame([
2821     ...     (1.0, Vectors.dense(1.0)),
2822     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
2823     >>> fm = FMClassifier(factorSize=2)
2824     >>> fm.setSeed(11)
2825     FMClassifier...
2826     >>> model = fm.fit(df)
2827     >>> model.getMaxIter()
2828     100
2829     >>> test0 = spark.createDataFrame([
2830     ...     (Vectors.dense(-1.0),),
2831     ...     (Vectors.dense(0.5),),
2832     ...     (Vectors.dense(1.0),),
2833     ...     (Vectors.dense(2.0),)], ["features"])
2834     >>> model.predictRaw(test0.head().features)
2835     DenseVector([22.13..., -22.13...])
2836     >>> model.predictProbability(test0.head().features)
2837     DenseVector([1.0, 0.0])
2838     >>> model.transform(test0).select("features", "probability").show(10, False)
2839     +--------+------------------------------------------+
2840     |features|probability                               |
2841     +--------+------------------------------------------+
2842     |[-1.0]  |[0.9999999997574736,2.425264676902229E-10]|
2843     |[0.5]   |[0.47627851732981163,0.5237214826701884]  |
2844     |[1.0]   |[5.491554426243495E-4,0.9994508445573757] |
2845     |[2.0]   |[2.005766663870645E-10,0.9999999997994233]|
2846     +--------+------------------------------------------+
2847     ...
2848     >>> model.intercept
2849     -7.316665276826291
2850     >>> model.linear
2851     DenseVector([14.8232])
2852     >>> model.factors
2853     DenseMatrix(1, 2, [0.0163, -0.0051], 1)
2854 
2855     .. versionadded:: 3.0.0
2856     """
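    # A hedged example of the solver options listed above: switching from the
    # default adamW optimizer to plain mini-batch gradient descent (settings
    # illustrative, not tuned):
    #
    #   >>> fm_gd = FMClassifier(factorSize=2, solver="gd", stepSize=0.1,
    #   ...                      miniBatchFraction=0.5)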
2857 
2858     @keyword_only
2859     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2860                  probabilityCol="probability", rawPredictionCol="rawPrediction",
2861                  factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
2862                  miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
2863                  tol=1e-6, solver="adamW", thresholds=None, seed=None):
2864         """
2865         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2866                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
2867                  factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \
2868                  miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \
2869                  tol=1e-6, solver="adamW", thresholds=None, seed=None)
2870         """
2871         super(FMClassifier, self).__init__()
2872         self._java_obj = self._new_java_obj(
2873             "org.apache.spark.ml.classification.FMClassifier", self.uid)
2874         self._setDefault(factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
2875                          miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
2876                          tol=1e-6, solver="adamW")
2877         kwargs = self._input_kwargs
2878         self.setParams(**kwargs)
2879 
2880     @keyword_only
2881     @since("3.0.0")
2882     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
2883                   probabilityCol="probability", rawPredictionCol="rawPrediction",
2884                   factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
2885                   miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
2886                   tol=1e-6, solver="adamW", thresholds=None, seed=None):
2887         """
2888         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
2889                   probabilityCol="probability", rawPredictionCol="rawPrediction", \
2890                   factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \
2891                   miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \
2892                   tol=1e-6, solver="adamW", thresholds=None, seed=None)
2893         Sets Params for FMClassifier.
2894         """
2895         kwargs = self._input_kwargs
2896         return self._set(**kwargs)
2897 
2898     def _create_model(self, java_model):
2899         return FMClassificationModel(java_model)
2900 
2901     @since("3.0.0")
2902     def setFactorSize(self, value):
2903         """
2904         Sets the value of :py:attr:`factorSize`.
2905         """
2906         return self._set(factorSize=value)
2907 
2908     @since("3.0.0")
2909     def setFitLinear(self, value):
2910         """
2911         Sets the value of :py:attr:`fitLinear`.
2912         """
2913         return self._set(fitLinear=value)
2914 
2915     @since("3.0.0")
2916     def setMiniBatchFraction(self, value):
2917         """
2918         Sets the value of :py:attr:`miniBatchFraction`.
2919         """
2920         return self._set(miniBatchFraction=value)
2921 
2922     @since("3.0.0")
2923     def setInitStd(self, value):
2924         """
2925         Sets the value of :py:attr:`initStd`.
2926         """
2927         return self._set(initStd=value)
2928 
2929     @since("3.0.0")
2930     def setMaxIter(self, value):
2931         """
2932         Sets the value of :py:attr:`maxIter`.
2933         """
2934         return self._set(maxIter=value)
2935 
2936     @since("3.0.0")
2937     def setStepSize(self, value):
2938         """
2939         Sets the value of :py:attr:`stepSize`.
2940         """
2941         return self._set(stepSize=value)
2942 
2943     @since("3.0.0")
2944     def setTol(self, value):
2945         """
2946         Sets the value of :py:attr:`tol`.
2947         """
2948         return self._set(tol=value)
2949 
2950     @since("3.0.0")
2951     def setSolver(self, value):
2952         """
2953         Sets the value of :py:attr:`solver`.
2954         """
2955         return self._set(solver=value)
2956 
2957     @since("3.0.0")
2958     def setSeed(self, value):
2959         """
2960         Sets the value of :py:attr:`seed`.
2961         """
2962         return self._set(seed=value)
2963 
2964     @since("3.0.0")
2965     def setFitIntercept(self, value):
2966         """
2967         Sets the value of :py:attr:`fitIntercept`.
2968         """
2969         return self._set(fitIntercept=value)
2970 
2971     @since("3.0.0")
2972     def setRegParam(self, value):
2973         """
2974         Sets the value of :py:attr:`regParam`.
2975         """
2976         return self._set(regParam=value)
2977 
2978 
2979 class FMClassificationModel(JavaProbabilisticClassificationModel, _FactorizationMachinesParams,
2980                             JavaMLWritable, JavaMLReadable):
2981     """
2982     Model fitted by :class:`FMClassifier`.
2983 
2984     .. versionadded:: 3.0.0
2985     """
2986 
2987     @property
2988     @since("3.0.0")
2989     def intercept(self):
2990         """
2991         Model intercept.
2992         """
2993         return self._call_java("intercept")
2994 
2995     @property
2996     @since("3.0.0")
2997     def linear(self):
2998         """
2999         Model linear term.
3000         """
3001         return self._call_java("linear")
3002 
3003     @property
3004     @since("3.0.0")
3005     def factors(self):
3006         """
3007         Model factor term.
3008         """
3009         return self._call_java("factors")
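    # How intercept, linear, and factors combine into the raw score, sketched
    # in plain numpy via the standard factorization machines identity for the
    # pairwise term (values illustrative, loosely echoing the doctest above):
    #
    #   >>> import numpy as np
    #   >>> w0, w = -7.32, np.array([14.82])       # intercept, linear
    #   >>> V = np.array([[0.0163, -0.0051]])      # factors: numFeatures x factorSize
    #   >>> x = np.array([1.0])
    #   >>> pairwise = 0.5 * ((x @ V) ** 2 - (x ** 2) @ (V ** 2)).sum()
    #   >>> raw = float(w0 + w @ x + pairwise)
    #
    # With a single feature the pairwise term is exactly 0, since there are
    # no i < j interaction pairs.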
3010 
3011 
3012 if __name__ == "__main__":
3013     import doctest
3014     import pyspark.ml.classification
3015     from pyspark.sql import SparkSession
3016     globs = pyspark.ml.classification.__dict__.copy()
3017     # Build a small local SparkSession to run the doctests;
3018     # local[2] keeps these small examples fast while using two threads:
3019     spark = SparkSession.builder\
3020         .master("local[2]")\
3021         .appName("ml.classification tests")\
3022         .getOrCreate()
3023     sc = spark.sparkContext
3024     globs['sc'] = sc
3025     globs['spark'] = spark
3026     import tempfile
3027     temp_path = tempfile.mkdtemp()
3028     globs['temp_path'] = temp_path
3029     try:
3030         (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
3031         spark.stop()
3032     finally:
3033         from shutil import rmtree
3034         try:
3035             rmtree(temp_path)
3036         except OSError:
3037             pass
3038     if failure_count:
3039         sys.exit(-1)