#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
if sys.version > '3':
    basestring = str

from pyspark import since, keyword_only, SparkContext
from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.linalg import _convert_to_vector
from pyspark.ml.param.shared import *
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm
from pyspark.ml.common import inherit_doc

__all__ = ['Binarizer',
           'BucketedRandomProjectionLSH', 'BucketedRandomProjectionLSHModel',
           'Bucketizer',
           'ChiSqSelector', 'ChiSqSelectorModel',
           'CountVectorizer', 'CountVectorizerModel',
           'DCT',
           'ElementwiseProduct',
           'FeatureHasher',
           'HashingTF',
           'IDF', 'IDFModel',
           'Imputer', 'ImputerModel',
           'IndexToString',
           'Interaction',
           'MaxAbsScaler', 'MaxAbsScalerModel',
           'MinHashLSH', 'MinHashLSHModel',
           'MinMaxScaler', 'MinMaxScalerModel',
           'NGram',
           'Normalizer',
           'OneHotEncoder', 'OneHotEncoderModel',
           'PCA', 'PCAModel',
           'PolynomialExpansion',
           'QuantileDiscretizer',
           'RobustScaler', 'RobustScalerModel',
           'RegexTokenizer',
           'RFormula', 'RFormulaModel',
           'SQLTransformer',
           'StandardScaler', 'StandardScalerModel',
           'StopWordsRemover',
           'StringIndexer', 'StringIndexerModel',
           'Tokenizer',
           'VectorAssembler',
           'VectorIndexer', 'VectorIndexerModel',
           'VectorSizeHint',
           'VectorSlicer',
           'Word2Vec', 'Word2VecModel']


@inherit_doc
class Binarizer(JavaTransformer, HasThreshold, HasThresholds, HasInputCol, HasOutputCol,
                HasInputCols, HasOutputCols, JavaMLReadable, JavaMLWritable):
    """
    Binarize a column of continuous features given a threshold. Since 3.0.0,
    :py:class:`Binarizer` can map multiple columns at once by setting the :py:attr:`inputCols`
    parameter. Note that when both the :py:attr:`inputCol` and :py:attr:`inputCols` parameters
    are set, an Exception will be thrown. The :py:attr:`threshold` parameter is used for
    single column usage, and :py:attr:`thresholds` is for multiple columns.

    >>> df = spark.createDataFrame([(0.5,)], ["values"])
    >>> binarizer = Binarizer(threshold=1.0, inputCol="values", outputCol="features")
    >>> binarizer.setThreshold(1.0)
    Binarizer...
    >>> binarizer.setInputCol("values")
    Binarizer...
    >>> binarizer.setOutputCol("features")
    Binarizer...
    >>> binarizer.transform(df).head().features
    0.0
    >>> binarizer.setParams(outputCol="freqs").transform(df).head().freqs
    0.0
    >>> params = {binarizer.threshold: -0.5, binarizer.outputCol: "vector"}
    >>> binarizer.transform(df, params).head().vector
    1.0
    >>> binarizerPath = temp_path + "/binarizer"
    >>> binarizer.save(binarizerPath)
    >>> loadedBinarizer = Binarizer.load(binarizerPath)
    >>> loadedBinarizer.getThreshold() == binarizer.getThreshold()
    True
    >>> df2 = spark.createDataFrame([(0.5, 0.3)], ["values1", "values2"])
    >>> binarizer2 = Binarizer(thresholds=[0.0, 1.0])
    >>> binarizer2.setInputCols(["values1", "values2"]).setOutputCols(["output1", "output2"])
    Binarizer...
    >>> binarizer2.transform(df2).show()
    +-------+-------+-------+-------+
    |values1|values2|output1|output2|
    +-------+-------+-------+-------+
    |    0.5|    0.3|    1.0|    0.0|
    +-------+-------+-------+-------+
    ...

    .. versionadded:: 1.4.0
    """

    threshold = Param(Params._dummy(), "threshold",
                      "Param for threshold used to binarize continuous features. " +
                      "The features greater than the threshold will be binarized to 1.0. " +
                      "The features equal to or less than the threshold will be binarized to 0.0",
                      typeConverter=TypeConverters.toFloat)
    thresholds = Param(Params._dummy(), "thresholds",
                       "Param for array of threshold used to binarize continuous features. " +
                       "This is for multiple columns input. If transforming multiple columns " +
                       "and thresholds is not set, but threshold is set, then threshold will " +
                       "be applied across all columns.",
                       typeConverter=TypeConverters.toListFloat)

    @keyword_only
    def __init__(self, threshold=0.0, inputCol=None, outputCol=None, thresholds=None,
                 inputCols=None, outputCols=None):
        """
        __init__(self, threshold=0.0, inputCol=None, outputCol=None, thresholds=None, \
                 inputCols=None, outputCols=None)
        """
        super(Binarizer, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid)
        self._setDefault(threshold=0.0)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, threshold=0.0, inputCol=None, outputCol=None, thresholds=None,
                  inputCols=None, outputCols=None):
        """
        setParams(self, threshold=0.0, inputCol=None, outputCol=None, thresholds=None, \
                  inputCols=None, outputCols=None)
        Sets params for this Binarizer.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("1.4.0")
    def setThreshold(self, value):
        """
        Sets the value of :py:attr:`threshold`.
        """
        return self._set(threshold=value)

    @since("3.0.0")
    def setThresholds(self, value):
        """
        Sets the value of :py:attr:`thresholds`.
        """
        return self._set(thresholds=value)

    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    @since("3.0.0")
    def setInputCols(self, value):
        """
        Sets the value of :py:attr:`inputCols`.
        """
        return self._set(inputCols=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    @since("3.0.0")
    def setOutputCols(self, value):
        """
        Sets the value of :py:attr:`outputCols`.
        """
        return self._set(outputCols=value)


class _LSHParams(HasInputCol, HasOutputCol):
    """
    Mixin for Locality Sensitive Hashing (LSH) algorithm parameters.
    """

    numHashTables = Param(Params._dummy(), "numHashTables", "number of hash tables, where " +
                          "increasing number of hash tables lowers the false negative rate, " +
                          "and decreasing it improves the running performance.",
                          typeConverter=TypeConverters.toInt)

    def getNumHashTables(self):
        """
        Gets the value of numHashTables or its default value.
        """
        return self.getOrDefault(self.numHashTables)


class _LSH(JavaEstimator, _LSHParams, JavaMLReadable, JavaMLWritable):
    """
    Mixin for Locality Sensitive Hashing (LSH).
    """

    def setNumHashTables(self, value):
        """
        Sets the value of :py:attr:`numHashTables`.
        """
        return self._set(numHashTables=value)

    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)


class _LSHModel(JavaModel, _LSHParams):
    """
    Mixin for Locality Sensitive Hashing (LSH) models.
    """

    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    def approxNearestNeighbors(self, dataset, key, numNearestNeighbors, distCol="distCol"):
        """
        Given a large dataset and an item, approximately find at most k items which have the
        closest distance to the item. If the :py:attr:`outputCol` is missing, the method will
        transform the data; if the :py:attr:`outputCol` exists, it will use that. This allows
        caching of the transformed data when necessary.

        .. note:: This method is experimental and will likely change behavior in the next release.

        :param dataset: The dataset to search for nearest neighbors of the key.
        :param key: Feature vector representing the item to search for.
        :param numNearestNeighbors: The maximum number of nearest neighbors.
        :param distCol: Output column for storing the distance between each result row and the key.
                        Use "distCol" as default value if it's not specified.
        :return: A dataset containing at most k items closest to the key. A column "distCol" is
                 added to show the distance between each row and the key.
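
        A hedged usage sketch (illustrative, not a doctest; assumes ``model`` is a fitted
        LSH model and ``df`` has a vector column matching :py:attr:`inputCol`)::

            from pyspark.ml.linalg import Vectors

            key = Vectors.dense([1.0, 2.0])
            # at most the 2 rows of ``df`` closest to ``key``, plus a "distCol" column
            neighbors = model.approxNearestNeighbors(df, key, 2)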
0264 """
0265 return self._call_java("approxNearestNeighbors", dataset, key, numNearestNeighbors,
0266 distCol)
0267
0268 def approxSimilarityJoin(self, datasetA, datasetB, threshold, distCol="distCol"):
0269 """
0270 Join two datasets to approximately find all pairs of rows whose distance are smaller than
0271 the threshold. If the :py:attr:`outputCol` is missing, the method will transform the data;
0272 if the :py:attr:`outputCol` exists, it will use that. This allows caching of the
0273 transformed data when necessary.
0274
0275 :param datasetA: One of the datasets to join.
0276 :param datasetB: Another dataset to join.
0277 :param threshold: The threshold for the distance of row pairs.
0278 :param distCol: Output column for storing the distance between each pair of rows. Use
0279 "distCol" as default value if it's not specified.
0280 :return: A joined dataset containing pairs of rows. The original rows are in columns
0281 "datasetA" and "datasetB", and a column "distCol" is added to show the distance
0282 between each pair.
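
        A hedged usage sketch (illustrative, not a doctest; assumes ``model`` is a fitted
        LSH model and ``dfA``/``dfB`` each have a vector column matching :py:attr:`inputCol`
        plus an ``id`` column)::

            pairs = model.approxSimilarityJoin(dfA, dfB, 1.5, distCol="dist")
            # the original rows come back nested in struct columns
            pairs.select("datasetA.id", "datasetB.id", "dist")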
0283 """
0284 threshold = TypeConverters.toFloat(threshold)
0285 return self._call_java("approxSimilarityJoin", datasetA, datasetB, threshold, distCol)
0286
0287
class _BucketedRandomProjectionLSHParams:
    """
    Params for :py:class:`BucketedRandomProjectionLSH` and
    :py:class:`BucketedRandomProjectionLSHModel`.

    .. versionadded:: 3.0.0
    """

    bucketLength = Param(Params._dummy(), "bucketLength", "the length of each hash bucket, " +
                         "a larger bucket lowers the false negative rate.",
                         typeConverter=TypeConverters.toFloat)

    @since("2.2.0")
    def getBucketLength(self):
        """
        Gets the value of bucketLength or its default value.
        """
        return self.getOrDefault(self.bucketLength)


@inherit_doc
class BucketedRandomProjectionLSH(_LSH, _BucketedRandomProjectionLSHParams,
                                  HasSeed, JavaMLReadable, JavaMLWritable):
    """
    LSH class for Euclidean distance metrics.
    The input is dense or sparse vectors, each of which represents a point in the Euclidean
    distance space. The output will be vectors of configurable dimension. Hash values in the same
    dimension are calculated by the same hash function.

    .. seealso:: `Stable Distributions
        <https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Stable_distributions>`_
    .. seealso:: `Hashing for Similarity Search: A Survey <https://arxiv.org/abs/1408.2927>`_

    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.sql.functions import col
    >>> data = [(0, Vectors.dense([-1.0, -1.0]),),
    ...         (1, Vectors.dense([-1.0, 1.0]),),
    ...         (2, Vectors.dense([1.0, -1.0]),),
    ...         (3, Vectors.dense([1.0, 1.0]),)]
    >>> df = spark.createDataFrame(data, ["id", "features"])
    >>> brp = BucketedRandomProjectionLSH()
    >>> brp.setInputCol("features")
    BucketedRandomProjectionLSH...
    >>> brp.setOutputCol("hashes")
    BucketedRandomProjectionLSH...
    >>> brp.setSeed(12345)
    BucketedRandomProjectionLSH...
    >>> brp.setBucketLength(1.0)
    BucketedRandomProjectionLSH...
    >>> model = brp.fit(df)
    >>> model.getBucketLength()
    1.0
    >>> model.setOutputCol("hashes")
    BucketedRandomProjectionLSHModel...
    >>> model.transform(df).head()
    Row(id=0, features=DenseVector([-1.0, -1.0]), hashes=[DenseVector([-1.0])])
    >>> data2 = [(4, Vectors.dense([2.0, 2.0]),),
    ...          (5, Vectors.dense([2.0, 3.0]),),
    ...          (6, Vectors.dense([3.0, 2.0]),),
    ...          (7, Vectors.dense([3.0, 3.0]),)]
    >>> df2 = spark.createDataFrame(data2, ["id", "features"])
    >>> model.approxNearestNeighbors(df2, Vectors.dense([1.0, 2.0]), 1).collect()
    [Row(id=4, features=DenseVector([2.0, 2.0]), hashes=[DenseVector([1.0])], distCol=1.0)]
    >>> model.approxSimilarityJoin(df, df2, 3.0, distCol="EuclideanDistance").select(
    ...     col("datasetA.id").alias("idA"),
    ...     col("datasetB.id").alias("idB"),
    ...     col("EuclideanDistance")).show()
    +---+---+-----------------+
    |idA|idB|EuclideanDistance|
    +---+---+-----------------+
    |  3|  6| 2.23606797749979|
    +---+---+-----------------+
    ...
    >>> model.approxSimilarityJoin(df, df2, 3, distCol="EuclideanDistance").select(
    ...     col("datasetA.id").alias("idA"),
    ...     col("datasetB.id").alias("idB"),
    ...     col("EuclideanDistance")).show()
    +---+---+-----------------+
    |idA|idB|EuclideanDistance|
    +---+---+-----------------+
    |  3|  6| 2.23606797749979|
    +---+---+-----------------+
    ...
    >>> brpPath = temp_path + "/brp"
    >>> brp.save(brpPath)
    >>> brp2 = BucketedRandomProjectionLSH.load(brpPath)
    >>> brp2.getBucketLength() == brp.getBucketLength()
    True
    >>> modelPath = temp_path + "/brp-model"
    >>> model.save(modelPath)
    >>> model2 = BucketedRandomProjectionLSHModel.load(modelPath)
    >>> model.transform(df).head().hashes == model2.transform(df).head().hashes
    True

    .. versionadded:: 2.2.0
    """

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1,
                 bucketLength=None):
        """
        __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1, \
                 bucketLength=None)
        """
        super(BucketedRandomProjectionLSH, self).__init__()
        self._java_obj = \
            self._new_java_obj("org.apache.spark.ml.feature.BucketedRandomProjectionLSH", self.uid)
        self._setDefault(numHashTables=1)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("2.2.0")
    def setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1,
                  bucketLength=None):
        """
        setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1, \
                  bucketLength=None)
        Sets params for this BucketedRandomProjectionLSH.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.2.0")
    def setBucketLength(self, value):
        """
        Sets the value of :py:attr:`bucketLength`.
        """
        return self._set(bucketLength=value)

    def setSeed(self, value):
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)

    def _create_model(self, java_model):
        return BucketedRandomProjectionLSHModel(java_model)


class BucketedRandomProjectionLSHModel(_LSHModel, _BucketedRandomProjectionLSHParams,
                                       JavaMLReadable, JavaMLWritable):
    r"""
    Model fitted by :py:class:`BucketedRandomProjectionLSH`, where multiple random vectors are
    stored. The vectors are normalized to be unit vectors and each vector is used in a hash
    function: :math:`h_i(x) = floor(r_i \cdot x / bucketLength)` where :math:`r_i` is the
    i-th random unit vector. The number of buckets will be `(max L2 norm of input vectors) /
    bucketLength`.

    .. versionadded:: 2.2.0
    """


@inherit_doc
class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
                 HasHandleInvalid, JavaMLReadable, JavaMLWritable):
    """
    Maps a column of continuous features to a column of feature buckets. Since 3.0.0,
    :py:class:`Bucketizer` can map multiple columns at once by setting the :py:attr:`inputCols`
    parameter. Note that when both the :py:attr:`inputCol` and :py:attr:`inputCols` parameters
    are set, an Exception will be thrown. The :py:attr:`splits` parameter is only used for single
    column usage, and :py:attr:`splitsArray` is for multiple columns.

    >>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")),
    ...     (float("nan"), 1.0), (float("nan"), 0.0)]
    >>> df = spark.createDataFrame(values, ["values1", "values2"])
    >>> bucketizer = Bucketizer()
    >>> bucketizer.setSplits([-float("inf"), 0.5, 1.4, float("inf")])
    Bucketizer...
    >>> bucketizer.setInputCol("values1")
    Bucketizer...
    >>> bucketizer.setOutputCol("buckets")
    Bucketizer...
    >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect()
    >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1"))
    >>> bucketed.show(truncate=False)
    +-------+-------+
    |values1|buckets|
    +-------+-------+
    |0.1    |0.0    |
    |0.4    |0.0    |
    |1.2    |1.0    |
    |1.5    |2.0    |
    |NaN    |3.0    |
    |NaN    |3.0    |
    +-------+-------+
    ...
    >>> bucketizer.setParams(outputCol="b").transform(df).head().b
    0.0
    >>> bucketizerPath = temp_path + "/bucketizer"
    >>> bucketizer.save(bucketizerPath)
    >>> loadedBucketizer = Bucketizer.load(bucketizerPath)
    >>> loadedBucketizer.getSplits() == bucketizer.getSplits()
    True
    >>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect()
    >>> len(bucketed)
    4
    >>> bucketizer2 = Bucketizer(splitsArray=
    ...     [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]],
    ...     inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"])
    >>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df)
    >>> bucketed2.show(truncate=False)
    +-------+-------+--------+--------+
    |values1|values2|buckets1|buckets2|
    +-------+-------+--------+--------+
    |0.1    |0.0    |0.0     |0.0     |
    |0.4    |1.0    |0.0     |1.0     |
    |1.2    |1.3    |1.0     |1.0     |
    |1.5    |NaN    |2.0     |2.0     |
    |NaN    |1.0    |3.0     |1.0     |
    |NaN    |0.0    |3.0     |0.0     |
    +-------+-------+--------+--------+
    ...

    .. versionadded:: 1.4.0
    """

    splits = \
        Param(Params._dummy(), "splits",
              "Split points for mapping continuous features into buckets. With n+1 splits, " +
              "there are n buckets. A bucket defined by splits x,y holds values in the " +
              "range [x,y) except the last bucket, which also includes y. The splits " +
              "should be of length >= 3 and strictly increasing. Values at -inf, inf must be " +
              "explicitly provided to cover all Double values; otherwise, values outside the " +
              "splits specified will be treated as errors.",
              typeConverter=TypeConverters.toListFloat)

    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries "
                          "containing NaN values. Values outside the splits will always be "
                          "treated as errors. Options are 'skip' (filter out rows with invalid "
                          "values), 'error' (throw an error), or 'keep' (keep invalid values in "
                          "a special additional bucket). Note that in the multiple column case, "
                          "the invalid handling is applied to all columns. That said, for "
                          "'error' it will throw an error if any invalids are found in any "
                          "column, and for 'skip' it will skip rows with any invalids in any "
                          "columns, etc.",
                          typeConverter=TypeConverters.toString)

    splitsArray = Param(Params._dummy(), "splitsArray", "The array of split points for mapping " +
                        "continuous features into buckets for multiple columns. For each input " +
                        "column, with n+1 splits, there are n buckets. A bucket defined by " +
                        "splits x,y holds values in the range [x,y) except the last bucket, " +
                        "which also includes y. The splits should be of length >= 3 and " +
                        "strictly increasing. Values at -inf, inf must be explicitly provided " +
                        "to cover all Double values; otherwise, values outside the splits " +
                        "specified will be treated as errors.",
                        typeConverter=TypeConverters.toListListFloat)

    @keyword_only
    def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error",
                 splitsArray=None, inputCols=None, outputCols=None):
        """
        __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
                 splitsArray=None, inputCols=None, outputCols=None)
        """
        super(Bucketizer, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid)
        self._setDefault(handleInvalid="error")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error",
                  splitsArray=None, inputCols=None, outputCols=None):
        """
        setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
                  splitsArray=None, inputCols=None, outputCols=None)
        Sets params for this Bucketizer.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("1.4.0")
    def setSplits(self, value):
        """
        Sets the value of :py:attr:`splits`.
        """
        return self._set(splits=value)

    @since("1.4.0")
    def getSplits(self):
        """
        Gets the value of splits or its default value.
        """
        return self.getOrDefault(self.splits)

    @since("3.0.0")
    def setSplitsArray(self, value):
        """
        Sets the value of :py:attr:`splitsArray`.
        """
        return self._set(splitsArray=value)

    @since("3.0.0")
    def getSplitsArray(self):
        """
        Gets the array of split points or its default value.
        """
        return self.getOrDefault(self.splitsArray)

    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    @since("3.0.0")
    def setInputCols(self, value):
        """
        Sets the value of :py:attr:`inputCols`.
        """
        return self._set(inputCols=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    @since("3.0.0")
    def setOutputCols(self, value):
        """
        Sets the value of :py:attr:`outputCols`.
        """
        return self._set(outputCols=value)

    def setHandleInvalid(self, value):
        """
        Sets the value of :py:attr:`handleInvalid`.
        """
        return self._set(handleInvalid=value)


class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol):
    """
    Params for :py:class:`CountVectorizer` and :py:class:`CountVectorizerModel`.
    """

    minTF = Param(
        Params._dummy(), "minTF", "Filter to ignore rare words in" +
        " a document. For each document, terms with frequency/count less than the given" +
        " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
        " times the term must appear in the document); if this is a double in [0,1), then this " +
        "specifies a fraction (out of the document's token count). Note that the parameter is " +
        "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
        typeConverter=TypeConverters.toFloat)
    minDF = Param(
        Params._dummy(), "minDF", "Specifies the minimum number of" +
        " different documents a term must appear in to be included in the vocabulary." +
        " If this is an integer >= 1, this specifies the number of documents the term must" +
        " appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
        " Default 1.0", typeConverter=TypeConverters.toFloat)
    maxDF = Param(
        Params._dummy(), "maxDF", "Specifies the maximum number of" +
        " different documents a term could appear in to be included in the vocabulary." +
        " A term that appears more than the threshold will be ignored. If this is an" +
        " integer >= 1, this specifies the maximum number of documents the term could appear in;" +
        " if this is a double in [0,1), then this specifies the maximum" +
        " fraction of documents the term could appear in." +
        " Default (2^63) - 1", typeConverter=TypeConverters.toFloat)
    vocabSize = Param(
        Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
        typeConverter=TypeConverters.toInt)
    binary = Param(
        Params._dummy(), "binary", "Binary toggle to control the output vector values." +
        " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
        " for discrete probabilistic models that model binary events rather than integer counts." +
        " Default False", typeConverter=TypeConverters.toBoolean)

    def __init__(self, *args):
        super(_CountVectorizerParams, self).__init__(*args)
        self._setDefault(minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False)

    @since("1.6.0")
    def getMinTF(self):
        """
        Gets the value of minTF or its default value.
        """
        return self.getOrDefault(self.minTF)

    @since("1.6.0")
    def getMinDF(self):
        """
        Gets the value of minDF or its default value.
        """
        return self.getOrDefault(self.minDF)

    @since("2.4.0")
    def getMaxDF(self):
        """
        Gets the value of maxDF or its default value.
        """
        return self.getOrDefault(self.maxDF)

    @since("1.6.0")
    def getVocabSize(self):
        """
        Gets the value of vocabSize or its default value.
        """
        return self.getOrDefault(self.vocabSize)

    @since("2.0.0")
    def getBinary(self):
        """
        Gets the value of binary or its default value.
        """
        return self.getOrDefault(self.binary)


@inherit_doc
class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, JavaMLWritable):
    """
    Extracts a vocabulary from document collections and generates a
    :py:class:`CountVectorizerModel`.
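
    A hedged sketch of the document-frequency filters (illustrative, not a doctest; with
    these settings a term must appear in at least 2 documents and in at most half of them)::

        cv = CountVectorizer(inputCol="raw", outputCol="vectors", minDF=2.0, maxDF=0.5)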

    >>> df = spark.createDataFrame(
    ...    [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])],
    ...    ["label", "raw"])
    >>> cv = CountVectorizer()
    >>> cv.setInputCol("raw")
    CountVectorizer...
    >>> cv.setOutputCol("vectors")
    CountVectorizer...
    >>> model = cv.fit(df)
    >>> model.setInputCol("raw")
    CountVectorizerModel...
    >>> model.transform(df).show(truncate=False)
    +-----+---------------+-------------------------+
    |label|raw            |vectors                  |
    +-----+---------------+-------------------------+
    |0    |[a, b, c]      |(3,[0,1,2],[1.0,1.0,1.0])|
    |1    |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
    +-----+---------------+-------------------------+
    ...
    >>> sorted(model.vocabulary) == ['a', 'b', 'c']
    True
    >>> countVectorizerPath = temp_path + "/count-vectorizer"
    >>> cv.save(countVectorizerPath)
    >>> loadedCv = CountVectorizer.load(countVectorizerPath)
    >>> loadedCv.getMinDF() == cv.getMinDF()
    True
    >>> loadedCv.getMinTF() == cv.getMinTF()
    True
    >>> loadedCv.getVocabSize() == cv.getVocabSize()
    True
    >>> modelPath = temp_path + "/count-vectorizer-model"
    >>> model.save(modelPath)
    >>> loadedModel = CountVectorizerModel.load(modelPath)
    >>> loadedModel.vocabulary == model.vocabulary
    True
    >>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"],
    ...     inputCol="raw", outputCol="vectors")
    >>> fromVocabModel.transform(df).show(truncate=False)
    +-----+---------------+-------------------------+
    |label|raw            |vectors                  |
    +-----+---------------+-------------------------+
    |0    |[a, b, c]      |(3,[0,1,2],[1.0,1.0,1.0])|
    |1    |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
    +-----+---------------+-------------------------+
    ...

    .. versionadded:: 1.6.0
    """

    @keyword_only
    def __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,
                 inputCol=None, outputCol=None):
        """
        __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\
                 inputCol=None, outputCol=None)
        """
        super(CountVectorizer, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer",
                                            self.uid)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.6.0")
    def setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,
                  inputCol=None, outputCol=None):
        """
        setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\
                  inputCol=None, outputCol=None)
        Set the params for the CountVectorizer.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("1.6.0")
    def setMinTF(self, value):
        """
        Sets the value of :py:attr:`minTF`.
        """
        return self._set(minTF=value)

    @since("1.6.0")
    def setMinDF(self, value):
        """
        Sets the value of :py:attr:`minDF`.
        """
        return self._set(minDF=value)

    @since("2.4.0")
    def setMaxDF(self, value):
        """
        Sets the value of :py:attr:`maxDF`.
        """
        return self._set(maxDF=value)

    @since("1.6.0")
    def setVocabSize(self, value):
        """
        Sets the value of :py:attr:`vocabSize`.
        """
        return self._set(vocabSize=value)

    @since("2.0.0")
    def setBinary(self, value):
        """
        Sets the value of :py:attr:`binary`.
        """
        return self._set(binary=value)

    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    def _create_model(self, java_model):
        return CountVectorizerModel(java_model)


@inherit_doc
class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, JavaMLWritable):
    """
    Model fitted by :py:class:`CountVectorizer`.

    .. versionadded:: 1.6.0
    """

    @since("3.0.0")
    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    @since("3.0.0")
    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    @classmethod
    @since("2.4.0")
    def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None):
        """
        Construct the model directly from a vocabulary list of strings,
        requires an active SparkContext.
        """
        sc = SparkContext._active_spark_context
        java_class = sc._gateway.jvm.java.lang.String
        jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class)
        model = CountVectorizerModel._create_from_java_class(
            "org.apache.spark.ml.feature.CountVectorizerModel", jvocab)
        model.setInputCol(inputCol)
        if outputCol is not None:
            model.setOutputCol(outputCol)
        if minTF is not None:
            model.setMinTF(minTF)
        if binary is not None:
            model.setBinary(binary)
        model._set(vocabSize=len(vocabulary))
        return model

    @property
    @since("1.6.0")
    def vocabulary(self):
        """
        An array of terms in the vocabulary.
        """
        return self._call_java("vocabulary")

    @since("2.4.0")
    def setMinTF(self, value):
        """
        Sets the value of :py:attr:`minTF`.
        """
        return self._set(minTF=value)

    @since("2.4.0")
    def setBinary(self, value):
        """
        Sets the value of :py:attr:`binary`.
        """
        return self._set(binary=value)


@inherit_doc
class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
    """
    A feature transformer that takes the 1D discrete cosine transform
    of a real vector. No zero padding is performed on the input vector.
    It returns a real vector of the same length representing the DCT.
    The return vector is scaled such that the transform matrix is
    unitary (aka scaled DCT-II).

    .. seealso:: `More information on Wikipedia
        <https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II>`_.

    >>> from pyspark.ml.linalg import Vectors
    >>> df1 = spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"])
    >>> dct = DCT()
    >>> dct.setInverse(False)
    DCT...
    >>> dct.setInputCol("vec")
    DCT...
    >>> dct.setOutputCol("resultVec")
    DCT...
    >>> df2 = dct.transform(df1)
    >>> df2.head().resultVec
    DenseVector([10.969..., -0.707..., -2.041...])
    >>> df3 = DCT(inverse=True, inputCol="resultVec", outputCol="origVec").transform(df2)
    >>> df3.head().origVec
    DenseVector([5.0, 8.0, 6.0])
    >>> dctPath = temp_path + "/dct"
    >>> dct.save(dctPath)
    >>> loadedDct = DCT.load(dctPath)
    >>> loadedDct.getInverse()
    False

    .. versionadded:: 1.6.0
    """

    inverse = Param(Params._dummy(), "inverse", "Set transformer to perform inverse DCT, " +
                    "default False.", typeConverter=TypeConverters.toBoolean)

    @keyword_only
    def __init__(self, inverse=False, inputCol=None, outputCol=None):
        """
        __init__(self, inverse=False, inputCol=None, outputCol=None)
        """
        super(DCT, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.DCT", self.uid)
        self._setDefault(inverse=False)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.6.0")
    def setParams(self, inverse=False, inputCol=None, outputCol=None):
        """
        setParams(self, inverse=False, inputCol=None, outputCol=None)
        Sets params for this DCT.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("1.6.0")
    def setInverse(self, value):
        """
        Sets the value of :py:attr:`inverse`.
        """
        return self._set(inverse=value)

    @since("1.6.0")
    def getInverse(self):
        """
        Gets the value of inverse or its default value.
        """
        return self.getOrDefault(self.inverse)

    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)


@inherit_doc
class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
                         JavaMLWritable):
    """
    Outputs the Hadamard product (i.e., the element-wise product) of each input vector
    with a provided "weight" vector. In other words, it scales each column of the dataset
    by a scalar multiplier.

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([(Vectors.dense([2.0, 1.0, 3.0]),)], ["values"])
    >>> ep = ElementwiseProduct()
    >>> ep.setScalingVec(Vectors.dense([1.0, 2.0, 3.0]))
    ElementwiseProduct...
    >>> ep.setInputCol("values")
    ElementwiseProduct...
    >>> ep.setOutputCol("eprod")
    ElementwiseProduct...
    >>> ep.transform(df).head().eprod
    DenseVector([2.0, 2.0, 9.0])
    >>> ep.setParams(scalingVec=Vectors.dense([2.0, 3.0, 5.0])).transform(df).head().eprod
    DenseVector([4.0, 3.0, 15.0])
    >>> elementwiseProductPath = temp_path + "/elementwise-product"
    >>> ep.save(elementwiseProductPath)
    >>> loadedEp = ElementwiseProduct.load(elementwiseProductPath)
    >>> loadedEp.getScalingVec() == ep.getScalingVec()
    True

    .. versionadded:: 1.5.0
    """

    scalingVec = Param(Params._dummy(), "scalingVec", "Vector for hadamard product.",
                       typeConverter=TypeConverters.toVector)

    @keyword_only
    def __init__(self, scalingVec=None, inputCol=None, outputCol=None):
        """
        __init__(self, scalingVec=None, inputCol=None, outputCol=None)
        """
        super(ElementwiseProduct, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ElementwiseProduct",
                                            self.uid)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.5.0")
    def setParams(self, scalingVec=None, inputCol=None, outputCol=None):
        """
        setParams(self, scalingVec=None, inputCol=None, outputCol=None)
        Sets params for this ElementwiseProduct.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.0.0")
    def setScalingVec(self, value):
        """
        Sets the value of :py:attr:`scalingVec`.
        """
        return self._set(scalingVec=value)

    @since("2.0.0")
    def getScalingVec(self):
        """
        Gets the value of scalingVec or its default value.
        """
        return self.getOrDefault(self.scalingVec)

    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)


@inherit_doc
class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, JavaMLReadable,
                    JavaMLWritable):
    """
    Feature hashing projects a set of categorical or numerical features into a feature vector of
    specified dimension (typically substantially smaller than that of the original feature
    space). This is done using the hashing trick (https://en.wikipedia.org/wiki/Feature_hashing)
    to map features to indices in the feature vector.

    The FeatureHasher transformer operates on multiple columns. Each column may contain either
    numeric or categorical features. Behavior and handling of column data types is as follows:

    * Numeric columns:
        For numeric features, the hash value of the column name is used to map the
        feature value to its index in the feature vector. By default, numeric features
        are not treated as categorical (even when they are integers). To treat them
        as categorical, specify the relevant columns in `categoricalCols`.

    * String columns:
        For categorical features, the hash value of the string "column_name=value"
        is used to map to the vector index, with an indicator value of `1.0`.
        Thus, categorical features are "one-hot" encoded
        (similarly to using :py:class:`OneHotEncoder` with `dropLast=false`).

    * Boolean columns:
        Boolean values are treated in the same way as string columns. That is,
        boolean features are represented as "column_name=true" or "column_name=false",
        with an indicator value of `1.0`.

    Null (missing) values are ignored (implicitly zero in the resulting feature vector).

    Since a simple modulo is used to transform the hash function to a vector index,
    it is advisable to use a power of two as the `numFeatures` parameter;
    otherwise the features will not be mapped evenly to the vector indices.
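
    A hedged sketch of that advice (illustrative, not a doctest; any power of two works,
    and smaller values increase the chance of hash collisions between features)::

        hasher = FeatureHasher(inputCols=["real", "bool"], outputCol="features",
                               numFeatures=1 << 10)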

    >>> data = [(2.0, True, "1", "foo"), (3.0, False, "2", "bar")]
    >>> cols = ["real", "bool", "stringNum", "string"]
    >>> df = spark.createDataFrame(data, cols)
    >>> hasher = FeatureHasher()
    >>> hasher.setInputCols(cols)
    FeatureHasher...
    >>> hasher.setOutputCol("features")
    FeatureHasher...
    >>> hasher.transform(df).head().features
    SparseVector(262144, {174475: 2.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
    >>> hasher.setCategoricalCols(["real"]).transform(df).head().features
    SparseVector(262144, {171257: 1.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
    >>> hasherPath = temp_path + "/hasher"
    >>> hasher.save(hasherPath)
    >>> loadedHasher = FeatureHasher.load(hasherPath)
    >>> loadedHasher.getNumFeatures() == hasher.getNumFeatures()
    True
    >>> loadedHasher.transform(df).head().features == hasher.transform(df).head().features
    True

    .. versionadded:: 2.3.0
    """

    categoricalCols = Param(Params._dummy(), "categoricalCols",
                            "numeric columns to treat as categorical",
                            typeConverter=TypeConverters.toListString)

    @keyword_only
    def __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None):
        """
        __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None)
        """
        super(FeatureHasher, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.FeatureHasher", self.uid)
        self._setDefault(numFeatures=1 << 18)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("2.3.0")
    def setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None):
        """
        setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None)
        Sets params for this FeatureHasher.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.3.0")
    def setCategoricalCols(self, value):
        """
        Sets the value of :py:attr:`categoricalCols`.
        """
        return self._set(categoricalCols=value)

    @since("2.3.0")
    def getCategoricalCols(self):
        """
        Gets the value of categoricalCols or its default value.
        """
        return self.getOrDefault(self.categoricalCols)

    def setInputCols(self, value):
        """
        Sets the value of :py:attr:`inputCols`.
        """
        return self._set(inputCols=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    def setNumFeatures(self, value):
        """
        Sets the value of :py:attr:`numFeatures`.
        """
        return self._set(numFeatures=value)


@inherit_doc
class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable,
                JavaMLWritable):
    """
    Maps a sequence of terms to their term frequencies using the hashing trick.
    Currently we use Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32)
    to calculate the hash code value for the term object.
    Since a simple modulo is used to transform the hash function to a column index,
    it is advisable to use a power of two as the numFeatures parameter;
    otherwise the features will not be mapped evenly to the columns.
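
    A hedged sketch of that advice (illustrative, not a doctest; note the doctest below
    deliberately uses a tiny vector size of 10, which is not a power of two, so that its
    printed output stays short)::

        hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=1 << 10)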

    >>> df = spark.createDataFrame([(["a", "b", "c"],)], ["words"])
    >>> hashingTF = HashingTF(inputCol="words", outputCol="features")
    >>> hashingTF.setNumFeatures(10)
    HashingTF...
    >>> hashingTF.transform(df).head().features
    SparseVector(10, {5: 1.0, 7: 1.0, 8: 1.0})
    >>> hashingTF.setParams(outputCol="freqs").transform(df).head().freqs
    SparseVector(10, {5: 1.0, 7: 1.0, 8: 1.0})
    >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
    >>> hashingTF.transform(df, params).head().vector
    SparseVector(5, {0: 1.0, 2: 1.0, 3: 1.0})
    >>> hashingTFPath = temp_path + "/hashing-tf"
    >>> hashingTF.save(hashingTFPath)
    >>> loadedHashingTF = HashingTF.load(hashingTFPath)
    >>> loadedHashingTF.getNumFeatures() == hashingTF.getNumFeatures()
    True
    >>> hashingTF.indexOf("b")
    5

    .. versionadded:: 1.3.0
    """

    binary = Param(Params._dummy(), "binary", "If True, all non zero counts are set to 1. " +
                   "This is useful for discrete probabilistic models that model binary events " +
                   "rather than integer counts. Default False.",
                   typeConverter=TypeConverters.toBoolean)

    @keyword_only
    def __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None):
        """
        __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None)
        """
        super(HashingTF, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid)
        self._setDefault(numFeatures=1 << 18, binary=False)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.3.0")
    def setParams(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None):
        """
        setParams(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None)
        Sets params for this HashingTF.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.0.0")
    def setBinary(self, value):
        """
        Sets the value of :py:attr:`binary`.
        """
        return self._set(binary=value)

    @since("2.0.0")
    def getBinary(self):
        """
        Gets the value of binary or its default value.
        """
        return self.getOrDefault(self.binary)

    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    def setNumFeatures(self, value):
        """
        Sets the value of :py:attr:`numFeatures`.
        """
        return self._set(numFeatures=value)

    @since("3.0.0")
    def indexOf(self, term):
        """
        Returns the index of the input term.
        """
        self._transfer_params_to_java()
        return self._java_obj.indexOf(term)


class _IDFParams(HasInputCol, HasOutputCol):
    """
    Params for :py:class:`IDF` and :py:class:`IDFModel`.

    .. versionadded:: 3.0.0
    """

    minDocFreq = Param(Params._dummy(), "minDocFreq",
                       "minimum number of documents in which a term should appear for filtering",
                       typeConverter=TypeConverters.toInt)

    @since("1.4.0")
    def getMinDocFreq(self):
        """
        Gets the value of minDocFreq or its default value.
        """
        return self.getOrDefault(self.minDocFreq)


@inherit_doc
class IDF(JavaEstimator, _IDFParams, JavaMLReadable, JavaMLWritable):
    """
    Compute the Inverse Document Frequency (IDF) given a collection of documents.

    >>> from pyspark.ml.linalg import DenseVector
    >>> df = spark.createDataFrame([(DenseVector([1.0, 2.0]),),
    ...     (DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)], ["tf"])
    >>> idf = IDF(minDocFreq=3)
    >>> idf.setInputCol("tf")
    IDF...
    >>> idf.setOutputCol("idf")
    IDF...
    >>> model = idf.fit(df)
    >>> model.setOutputCol("idf")
    IDFModel...
    >>> model.getMinDocFreq()
    3
    >>> model.idf
    DenseVector([0.0, 0.0])
    >>> model.docFreq
    [0, 3]
    >>> model.numDocs == df.count()
    True
    >>> model.transform(df).head().idf
    DenseVector([0.0, 0.0])
    >>> idf.setParams(outputCol="freqs").fit(df).transform(df).collect()[1].freqs
    DenseVector([0.0, 0.0])
    >>> params = {idf.minDocFreq: 1, idf.outputCol: "vector"}
    >>> idf.fit(df, params).transform(df).head().vector
    DenseVector([0.2877, 0.0])
    >>> idfPath = temp_path + "/idf"
    >>> idf.save(idfPath)
    >>> loadedIdf = IDF.load(idfPath)
    >>> loadedIdf.getMinDocFreq() == idf.getMinDocFreq()
    True
    >>> modelPath = temp_path + "/idf-model"
    >>> model.save(modelPath)
    >>> loadedModel = IDFModel.load(modelPath)
    >>> loadedModel.transform(df).head().idf == model.transform(df).head().idf
    True

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, minDocFreq=0, inputCol=None, outputCol=None):
        """
        __init__(self, minDocFreq=0, inputCol=None, outputCol=None)
        """
        super(IDF, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid)
        self._setDefault(minDocFreq=0)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, minDocFreq=0, inputCol=None, outputCol=None):
        """
        setParams(self, minDocFreq=0, inputCol=None, outputCol=None)
        Sets params for this IDF.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("1.4.0")
    def setMinDocFreq(self, value):
        """
        Sets the value of :py:attr:`minDocFreq`.
        """
        return self._set(minDocFreq=value)

    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    def _create_model(self, java_model):
        return IDFModel(java_model)


class IDFModel(JavaModel, _IDFParams, JavaMLReadable, JavaMLWritable):
    """
    Model fitted by :py:class:`IDF`.

    .. versionadded:: 1.4.0
    """

    @since("3.0.0")
    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    @since("3.0.0")
    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    @property
    @since("2.0.0")
    def idf(self):
        """
        Returns the IDF vector.
        """
        return self._call_java("idf")

    @property
    @since("3.0.0")
    def docFreq(self):
        """
        Returns the document frequency.
        """
        return self._call_java("docFreq")

    @property
    @since("3.0.0")
    def numDocs(self):
        """
        Returns the number of documents evaluated to compute the IDF.
        """
        return self._call_java("numDocs")


class _ImputerParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, HasRelativeError):
    """
    Params for :py:class:`Imputer` and :py:class:`ImputerModel`.

    .. versionadded:: 3.0.0
    """

    strategy = Param(Params._dummy(), "strategy",
                     "strategy for imputation. If mean, then replace missing values using the "
                     "mean value of the feature. If median, then replace missing values using "
                     "the median value of the feature.",
                     typeConverter=TypeConverters.toString)

    missingValue = Param(Params._dummy(), "missingValue",
                         "The placeholder for the missing values. All occurrences of missingValue "
                         "will be imputed.", typeConverter=TypeConverters.toFloat)

    @since("2.2.0")
    def getStrategy(self):
        """
        Gets the value of :py:attr:`strategy` or its default value.
        """
        return self.getOrDefault(self.strategy)

    @since("2.2.0")
    def getMissingValue(self):
        """
        Gets the value of :py:attr:`missingValue` or its default value.
        """
        return self.getOrDefault(self.missingValue)


@inherit_doc
class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable):
    """
    Imputation estimator for completing missing values, either using the mean or the median
    of the columns in which the missing values are located. The input columns should be of
    numeric type. Currently Imputer does not support categorical features and
    possibly creates incorrect values for a categorical feature.

    Note that the mean/median value is computed after filtering out missing values.
    All Null values in the input columns are treated as missing, and so are also imputed. For
    computing median, :py:meth:`pyspark.sql.DataFrame.approxQuantile` is used with a
    relative error of `0.001`.
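
    A hedged sketch of tuning that quantile approximation (illustrative, not a doctest; a
    coarser :py:attr:`relativeError` trades median accuracy for speed on large datasets)::

        imputer = Imputer(strategy="median", relativeError=0.01,
                          inputCols=["a", "b"], outputCols=["out_a", "out_b"])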

    >>> df = spark.createDataFrame([(1.0, float("nan")), (2.0, float("nan")), (float("nan"), 3.0),
    ...     (4.0, 4.0), (5.0, 5.0)], ["a", "b"])
    >>> imputer = Imputer()
    >>> imputer.setInputCols(["a", "b"])
    Imputer...
    >>> imputer.setOutputCols(["out_a", "out_b"])
    Imputer...
    >>> imputer.getRelativeError()
    0.001
    >>> model = imputer.fit(df)
    >>> model.setInputCols(["a", "b"])
    ImputerModel...
    >>> model.getStrategy()
    'mean'
    >>> model.surrogateDF.show()
    +---+---+
    |  a|  b|
    +---+---+
    |3.0|4.0|
    +---+---+
    ...
    >>> model.transform(df).show()
    +---+---+-----+-----+
    |  a|  b|out_a|out_b|
    +---+---+-----+-----+
    |1.0|NaN|  1.0|  4.0|
    |2.0|NaN|  2.0|  4.0|
    |NaN|3.0|  3.0|  3.0|
    ...
    >>> imputer.setStrategy("median").setMissingValue(1.0).fit(df).transform(df).show()
    +---+---+-----+-----+
    |  a|  b|out_a|out_b|
    +---+---+-----+-----+
    |1.0|NaN|  4.0|  NaN|
    ...
    >>> df1 = spark.createDataFrame([(1.0,), (2.0,), (float("nan"),), (4.0,), (5.0,)], ["a"])
    >>> imputer1 = Imputer(inputCol="a", outputCol="out_a")
    >>> model1 = imputer1.fit(df1)
    >>> model1.surrogateDF.show()
    +---+
    |  a|
    +---+
    |3.0|
    +---+
    ...
    >>> model1.transform(df1).show()
    +---+-----+
    |  a|out_a|
    +---+-----+
    |1.0|  1.0|
    |2.0|  2.0|
    |NaN|  3.0|
    ...
    >>> imputer1.setStrategy("median").setMissingValue(1.0).fit(df1).transform(df1).show()
    +---+-----+
    |  a|out_a|
    +---+-----+
    |1.0|  4.0|
    ...
    >>> df2 = spark.createDataFrame([(float("nan"),), (float("nan"),), (3.0,), (4.0,), (5.0,)],
    ...     ["b"])
    >>> imputer2 = Imputer(inputCol="b", outputCol="out_b")
    >>> model2 = imputer2.fit(df2)
    >>> model2.surrogateDF.show()
    +---+
    |  b|
    +---+
    |4.0|
    +---+
    ...
    >>> model2.transform(df2).show()
    +---+-----+
    |  b|out_b|
    +---+-----+
    |NaN|  4.0|
    |NaN|  4.0|
    |3.0|  3.0|
    ...
    >>> imputer2.setStrategy("median").setMissingValue(1.0).fit(df2).transform(df2).show()
    +---+-----+
    |  b|out_b|
    +---+-----+
    |NaN|  NaN|
    ...
    >>> imputerPath = temp_path + "/imputer"
    >>> imputer.save(imputerPath)
    >>> loadedImputer = Imputer.load(imputerPath)
    >>> loadedImputer.getStrategy() == imputer.getStrategy()
    True
    >>> loadedImputer.getMissingValue()
    1.0
    >>> modelPath = temp_path + "/imputer-model"
    >>> model.save(modelPath)
    >>> loadedModel = ImputerModel.load(modelPath)
    >>> loadedModel.transform(df).head().out_a == model.transform(df).head().out_a
    True

    .. versionadded:: 2.2.0
    """

    @keyword_only
    def __init__(self, strategy="mean", missingValue=float("nan"), inputCols=None,
                 outputCols=None, inputCol=None, outputCol=None, relativeError=0.001):
        """
        __init__(self, strategy="mean", missingValue=float("nan"), inputCols=None, \
                 outputCols=None, inputCol=None, outputCol=None, relativeError=0.001)
        """
        super(Imputer, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Imputer", self.uid)
        self._setDefault(strategy="mean", missingValue=float("nan"), relativeError=0.001)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("2.2.0")
    def setParams(self, strategy="mean", missingValue=float("nan"), inputCols=None,
                  outputCols=None, inputCol=None, outputCol=None, relativeError=0.001):
        """
        setParams(self, strategy="mean", missingValue=float("nan"), inputCols=None, \
                  outputCols=None, inputCol=None, outputCol=None, relativeError=0.001)
        Sets params for this Imputer.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.2.0")
    def setStrategy(self, value):
        """
        Sets the value of :py:attr:`strategy`.
        """
        return self._set(strategy=value)

    @since("2.2.0")
    def setMissingValue(self, value):
        """
        Sets the value of :py:attr:`missingValue`.
        """
        return self._set(missingValue=value)

    @since("2.2.0")
    def setInputCols(self, value):
        """
        Sets the value of :py:attr:`inputCols`.
        """
        return self._set(inputCols=value)

    @since("2.2.0")
    def setOutputCols(self, value):
        """
        Sets the value of :py:attr:`outputCols`.
        """
        return self._set(outputCols=value)

    @since("3.0.0")
    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    @since("3.0.0")
    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    @since("3.0.0")
    def setRelativeError(self, value):
        """
        Sets the value of :py:attr:`relativeError`.
        """
        return self._set(relativeError=value)

    def _create_model(self, java_model):
        return ImputerModel(java_model)


1656 class ImputerModel(JavaModel, _ImputerParams, JavaMLReadable, JavaMLWritable):
1657 """
1658 Model fitted by :py:class:`Imputer`.
1659
1660 .. versionadded:: 2.2.0
1661 """
1662
1663 @since("3.0.0")
1664 def setInputCols(self, value):
1665 """
1666 Sets the value of :py:attr:`inputCols`.
1667 """
1668 return self._set(inputCols=value)
1669
1670 @since("3.0.0")
1671 def setOutputCols(self, value):
1672 """
1673 Sets the value of :py:attr:`outputCols`.
1674 """
1675 return self._set(outputCols=value)
1676
1677 @since("3.0.0")
1678 def setInputCol(self, value):
1679 """
1680 Sets the value of :py:attr:`inputCol`.
1681 """
1682 return self._set(inputCol=value)
1683
1684 @since("3.0.0")
1685 def setOutputCol(self, value):
1686 """
1687 Sets the value of :py:attr:`outputCol`.
1688 """
1689 return self._set(outputCol=value)
1690
1691 @property
1692 @since("2.2.0")
1693 def surrogateDF(self):
1694 """
1695 Returns a DataFrame containing inputCols and their corresponding surrogates,
1696 which are used to replace the missing values in the input DataFrame.
1697 """
1698 return self._call_java("surrogateDF")
1699
1700
1701 @inherit_doc
1702 class Interaction(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable):
1703 """
1704 Implements the feature interaction transform. This transformer takes in Double and Vector type
1705 columns and outputs a flattened vector of their feature interactions. To handle interaction,
1706 we first one-hot encode any nominal features. Then, a vector of the feature cross-products is
1707 produced.
1708
1709 For example, given the input feature values `Double(2)` and `Vector(3, 4)`, the output would be
1710 `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal
1711 with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`.
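
To see why: the nominal value 2 is one-hot encoded over the four categories as
`Vector(0, 0, 1, 0)`, and each of those entries is then crossed with the numeric
`Vector(3, 4)`, giving `(0, 0, 0, 0, 1 * 3, 1 * 4, 0, 0) = (0, 0, 0, 0, 3, 4, 0, 0)`.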
1712
1713 >>> df = spark.createDataFrame([(0.0, 1.0), (2.0, 3.0)], ["a", "b"])
1714 >>> interaction = Interaction()
1715 >>> interaction.setInputCols(["a", "b"])
1716 Interaction...
1717 >>> interaction.setOutputCol("ab")
1718 Interaction...
1719 >>> interaction.transform(df).show()
1720 +---+---+-----+
1721 | a| b| ab|
1722 +---+---+-----+
1723 |0.0|1.0|[0.0]|
1724 |2.0|3.0|[6.0]|
1725 +---+---+-----+
1726 ...
1727 >>> interactionPath = temp_path + "/interaction"
1728 >>> interaction.save(interactionPath)
1729 >>> loadedInteraction = Interaction.load(interactionPath)
1730 >>> loadedInteraction.transform(df).head().ab == interaction.transform(df).head().ab
1731 True
1732
1733 .. versionadded:: 3.0.0
1734 """
1735
1736 @keyword_only
1737 def __init__(self, inputCols=None, outputCol=None):
1738 """
__init__(self, inputCols=None, outputCol=None)
1740 """
1741 super(Interaction, self).__init__()
1742 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Interaction", self.uid)
1743 self._setDefault()
1744 kwargs = self._input_kwargs
1745 self.setParams(**kwargs)
1746
1747 @keyword_only
1748 @since("3.0.0")
1749 def setParams(self, inputCols=None, outputCol=None):
1750 """
1751 setParams(self, inputCols=None, outputCol=None)
1752 Sets params for this Interaction.
1753 """
1754 kwargs = self._input_kwargs
1755 return self._set(**kwargs)
1756
1757 @since("3.0.0")
1758 def setInputCols(self, value):
1759 """
1760 Sets the value of :py:attr:`inputCols`.
1761 """
1762 return self._set(inputCols=value)
1763
1764 @since("3.0.0")
1765 def setOutputCol(self, value):
1766 """
1767 Sets the value of :py:attr:`outputCol`.
1768 """
1769 return self._set(outputCol=value)
1770
1771
1772 class _MaxAbsScalerParams(HasInputCol, HasOutputCol):
1773 """
1774 Params for :py:class:`MaxAbsScaler` and :py:class:`MaxAbsScalerModel`.
1775
1776 .. versionadded:: 3.0.0
1777 """
1778 pass
1779
1780
1781 @inherit_doc
1782 class MaxAbsScaler(JavaEstimator, _MaxAbsScalerParams, JavaMLReadable, JavaMLWritable):
1783 """
Rescale each feature individually to the range [-1, 1] by dividing by the
maximum absolute value of that feature. It does not shift/center the data,
and thus does not destroy any sparsity.
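
For example, for a feature observed as {1.0, 2.0} during fitting, the maximum
absolute value is 2.0, so 1.0 is rescaled to 1.0 / 2.0 = 0.5 and 2.0 to 1.0, as
the doctest below shows.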
1787
1788 >>> from pyspark.ml.linalg import Vectors
1789 >>> df = spark.createDataFrame([(Vectors.dense([1.0]),), (Vectors.dense([2.0]),)], ["a"])
1790 >>> maScaler = MaxAbsScaler(outputCol="scaled")
1791 >>> maScaler.setInputCol("a")
1792 MaxAbsScaler...
1793 >>> model = maScaler.fit(df)
1794 >>> model.setOutputCol("scaledOutput")
1795 MaxAbsScalerModel...
1796 >>> model.transform(df).show()
1797 +-----+------------+
1798 | a|scaledOutput|
1799 +-----+------------+
1800 |[1.0]| [0.5]|
1801 |[2.0]| [1.0]|
1802 +-----+------------+
1803 ...
1804 >>> scalerPath = temp_path + "/max-abs-scaler"
1805 >>> maScaler.save(scalerPath)
1806 >>> loadedMAScaler = MaxAbsScaler.load(scalerPath)
1807 >>> loadedMAScaler.getInputCol() == maScaler.getInputCol()
1808 True
1809 >>> loadedMAScaler.getOutputCol() == maScaler.getOutputCol()
1810 True
1811 >>> modelPath = temp_path + "/max-abs-scaler-model"
1812 >>> model.save(modelPath)
1813 >>> loadedModel = MaxAbsScalerModel.load(modelPath)
1814 >>> loadedModel.maxAbs == model.maxAbs
1815 True
1816
1817 .. versionadded:: 2.0.0
1818 """
1819
1820 @keyword_only
1821 def __init__(self, inputCol=None, outputCol=None):
1822 """
1823 __init__(self, inputCol=None, outputCol=None)
1824 """
1825 super(MaxAbsScaler, self).__init__()
1826 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MaxAbsScaler", self.uid)
1827 self._setDefault()
1828 kwargs = self._input_kwargs
1829 self.setParams(**kwargs)
1830
1831 @keyword_only
1832 @since("2.0.0")
1833 def setParams(self, inputCol=None, outputCol=None):
1834 """
1835 setParams(self, inputCol=None, outputCol=None)
1836 Sets params for this MaxAbsScaler.
1837 """
1838 kwargs = self._input_kwargs
1839 return self._set(**kwargs)
1840
1841 def setInputCol(self, value):
1842 """
1843 Sets the value of :py:attr:`inputCol`.
1844 """
1845 return self._set(inputCol=value)
1846
1847 def setOutputCol(self, value):
1848 """
1849 Sets the value of :py:attr:`outputCol`.
1850 """
1851 return self._set(outputCol=value)
1852
1853 def _create_model(self, java_model):
1854 return MaxAbsScalerModel(java_model)
1855
1856
1857 class MaxAbsScalerModel(JavaModel, _MaxAbsScalerParams, JavaMLReadable, JavaMLWritable):
1858 """
1859 Model fitted by :py:class:`MaxAbsScaler`.
1860
1861 .. versionadded:: 2.0.0
1862 """
1863
1864 @since("3.0.0")
1865 def setInputCol(self, value):
1866 """
1867 Sets the value of :py:attr:`inputCol`.
1868 """
1869 return self._set(inputCol=value)
1870
1871 @since("3.0.0")
1872 def setOutputCol(self, value):
1873 """
1874 Sets the value of :py:attr:`outputCol`.
1875 """
1876 return self._set(outputCol=value)
1877
1878 @property
1879 @since("2.0.0")
1880 def maxAbs(self):
1881 """
1882 Max Abs vector.
1883 """
1884 return self._call_java("maxAbs")
1885
1886
1887 @inherit_doc
1888 class MinHashLSH(_LSH, HasInputCol, HasOutputCol, HasSeed, JavaMLReadable, JavaMLWritable):
"""
1891 LSH class for Jaccard distance.
1892 The input can be dense or sparse vectors, but it is more efficient if it is sparse.
For example, `Vectors.sparse(10, [(2, 1.0), (3, 1.0), (5, 1.0)])` means there are 10 elements
in the space, and the set contains elements 2, 3, and 5. Any input vector must have at
least one non-zero entry, and all non-zero values are treated as binary "1" values.
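
Under Jaccard distance, two such vectors are compared as the sets of their
non-zero indices: distance = 1 - (size of intersection) / (size of union).
For instance, in the doctest below the sets {0, 1, 2} (id 0) and {1, 2, 4}
(id 5) share 2 of their 4 distinct elements, so their distance is
1 - 2/4 = 0.5, which is why that pair appears in the similarity join output.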
1896
1897 .. seealso:: `Wikipedia on MinHash <https://en.wikipedia.org/wiki/MinHash>`_
1898
1899 >>> from pyspark.ml.linalg import Vectors
1900 >>> from pyspark.sql.functions import col
1901 >>> data = [(0, Vectors.sparse(6, [0, 1, 2], [1.0, 1.0, 1.0]),),
1902 ... (1, Vectors.sparse(6, [2, 3, 4], [1.0, 1.0, 1.0]),),
1903 ... (2, Vectors.sparse(6, [0, 2, 4], [1.0, 1.0, 1.0]),)]
1904 >>> df = spark.createDataFrame(data, ["id", "features"])
1905 >>> mh = MinHashLSH()
1906 >>> mh.setInputCol("features")
1907 MinHashLSH...
1908 >>> mh.setOutputCol("hashes")
1909 MinHashLSH...
1910 >>> mh.setSeed(12345)
1911 MinHashLSH...
1912 >>> model = mh.fit(df)
1913 >>> model.setInputCol("features")
1914 MinHashLSHModel...
1915 >>> model.transform(df).head()
1916 Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668...
1917 >>> data2 = [(3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]),),
1918 ... (4, Vectors.sparse(6, [2, 3, 5], [1.0, 1.0, 1.0]),),
1919 ... (5, Vectors.sparse(6, [1, 2, 4], [1.0, 1.0, 1.0]),)]
1920 >>> df2 = spark.createDataFrame(data2, ["id", "features"])
1921 >>> key = Vectors.sparse(6, [1, 2], [1.0, 1.0])
1922 >>> model.approxNearestNeighbors(df2, key, 1).collect()
1923 [Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([6179668...
1924 >>> model.approxSimilarityJoin(df, df2, 0.6, distCol="JaccardDistance").select(
1925 ... col("datasetA.id").alias("idA"),
1926 ... col("datasetB.id").alias("idB"),
1927 ... col("JaccardDistance")).show()
1928 +---+---+---------------+
1929 |idA|idB|JaccardDistance|
1930 +---+---+---------------+
1931 | 0| 5| 0.5|
1932 | 1| 4| 0.5|
1933 +---+---+---------------+
1934 ...
1935 >>> mhPath = temp_path + "/mh"
1936 >>> mh.save(mhPath)
1937 >>> mh2 = MinHashLSH.load(mhPath)
1938 >>> mh2.getOutputCol() == mh.getOutputCol()
1939 True
1940 >>> modelPath = temp_path + "/mh-model"
1941 >>> model.save(modelPath)
1942 >>> model2 = MinHashLSHModel.load(modelPath)
1943 >>> model.transform(df).head().hashes == model2.transform(df).head().hashes
1944 True
1945
1946 .. versionadded:: 2.2.0
1947 """
1948
1949 @keyword_only
1950 def __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1):
1951 """
1952 __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1)
1953 """
1954 super(MinHashLSH, self).__init__()
1955 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinHashLSH", self.uid)
1956 self._setDefault(numHashTables=1)
1957 kwargs = self._input_kwargs
1958 self.setParams(**kwargs)
1959
1960 @keyword_only
1961 @since("2.2.0")
1962 def setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1):
1963 """
1964 setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1)
1965 Sets params for this MinHashLSH.
1966 """
1967 kwargs = self._input_kwargs
1968 return self._set(**kwargs)
1969
1970 def setSeed(self, value):
1971 """
1972 Sets the value of :py:attr:`seed`.
1973 """
1974 return self._set(seed=value)
1975
1976 def _create_model(self, java_model):
1977 return MinHashLSHModel(java_model)
1978
1979
1980 class MinHashLSHModel(_LSHModel, JavaMLReadable, JavaMLWritable):
1981 r"""
Model produced by :py:class:`MinHashLSH`, where multiple hash functions are stored. Each
1983 hash function is picked from the following family of hash functions, where :math:`a_i` and
1984 :math:`b_i` are randomly chosen integers less than prime:
1985 :math:`h_i(x) = ((x \cdot a_i + b_i) \mod prime)` This hash family is approximately min-wise
1986 independent according to the reference.
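
As a toy illustration (with made-up constants chosen only for exposition, not
the ones Spark actually uses): with prime = 7, a_i = 3 and b_i = 4, the
non-zero index x = 2 hashes to ((2 * 3 + 4) mod 7) = 3, and the MinHash of a
set is the minimum such hash value over its elements.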
1987
1988 .. seealso:: Tom Bohman, Colin Cooper, and Alan Frieze. "Min-wise independent linear
1989 permutations." Electronic Journal of Combinatorics 7 (2000): R26.
1990
1991 .. versionadded:: 2.2.0
1992 """
1993
1994
1995 class _MinMaxScalerParams(HasInputCol, HasOutputCol):
1996 """
1997 Params for :py:class:`MinMaxScaler` and :py:class:`MinMaxScalerModel`.
1998
1999 .. versionadded:: 3.0.0
2000 """
2001
2002 min = Param(Params._dummy(), "min", "Lower bound of the output feature range",
2003 typeConverter=TypeConverters.toFloat)
2004 max = Param(Params._dummy(), "max", "Upper bound of the output feature range",
2005 typeConverter=TypeConverters.toFloat)
2006
2007 @since("1.6.0")
2008 def getMin(self):
2009 """
2010 Gets the value of min or its default value.
2011 """
2012 return self.getOrDefault(self.min)
2013
2014 @since("1.6.0")
2015 def getMax(self):
2016 """
2017 Gets the value of max or its default value.
2018 """
2019 return self.getOrDefault(self.max)
2020
2021
2022 @inherit_doc
2023 class MinMaxScaler(JavaEstimator, _MinMaxScalerParams, JavaMLReadable, JavaMLWritable):
2024 """
2025 Rescale each feature individually to a common range [min, max] linearly using column summary
statistics; this is also known as min-max normalization or rescaling. The rescaled value for
2027 feature E is calculated as,
2028
2029 Rescaled(e_i) = (e_i - E_min) / (E_max - E_min) * (max - min) + min
2030
2031 For the case E_max == E_min, Rescaled(e_i) = 0.5 * (max + min)
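
For example, with E_min = 0.0, E_max = 2.0 and the default output range
[0.0, 1.0], the value e_i = 0.5 is rescaled to
(0.5 - 0.0) / (2.0 - 0.0) * (1.0 - 0.0) + 0.0 = 0.25.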
2032
2033 .. note:: Since zero values will probably be transformed to non-zero values, output of the
2034 transformer will be DenseVector even for sparse input.
2035
2036 >>> from pyspark.ml.linalg import Vectors
2037 >>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"])
2038 >>> mmScaler = MinMaxScaler(outputCol="scaled")
2039 >>> mmScaler.setInputCol("a")
2040 MinMaxScaler...
2041 >>> model = mmScaler.fit(df)
2042 >>> model.setOutputCol("scaledOutput")
2043 MinMaxScalerModel...
2044 >>> model.originalMin
2045 DenseVector([0.0])
2046 >>> model.originalMax
2047 DenseVector([2.0])
2048 >>> model.transform(df).show()
2049 +-----+------------+
2050 | a|scaledOutput|
2051 +-----+------------+
2052 |[0.0]| [0.0]|
2053 |[2.0]| [1.0]|
2054 +-----+------------+
2055 ...
2056 >>> minMaxScalerPath = temp_path + "/min-max-scaler"
2057 >>> mmScaler.save(minMaxScalerPath)
2058 >>> loadedMMScaler = MinMaxScaler.load(minMaxScalerPath)
2059 >>> loadedMMScaler.getMin() == mmScaler.getMin()
2060 True
2061 >>> loadedMMScaler.getMax() == mmScaler.getMax()
2062 True
2063 >>> modelPath = temp_path + "/min-max-scaler-model"
2064 >>> model.save(modelPath)
2065 >>> loadedModel = MinMaxScalerModel.load(modelPath)
2066 >>> loadedModel.originalMin == model.originalMin
2067 True
2068 >>> loadedModel.originalMax == model.originalMax
2069 True
2070
2071 .. versionadded:: 1.6.0
2072 """
2073
2074 @keyword_only
2075 def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None):
2076 """
2077 __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None)
2078 """
2079 super(MinMaxScaler, self).__init__()
2080 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid)
2081 self._setDefault(min=0.0, max=1.0)
2082 kwargs = self._input_kwargs
2083 self.setParams(**kwargs)
2084
2085 @keyword_only
2086 @since("1.6.0")
2087 def setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None):
2088 """
2089 setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None)
2090 Sets params for this MinMaxScaler.
2091 """
2092 kwargs = self._input_kwargs
2093 return self._set(**kwargs)
2094
2095 @since("1.6.0")
2096 def setMin(self, value):
2097 """
2098 Sets the value of :py:attr:`min`.
2099 """
2100 return self._set(min=value)
2101
2102 @since("1.6.0")
2103 def setMax(self, value):
2104 """
2105 Sets the value of :py:attr:`max`.
2106 """
2107 return self._set(max=value)
2108
2109 def setInputCol(self, value):
2110 """
2111 Sets the value of :py:attr:`inputCol`.
2112 """
2113 return self._set(inputCol=value)
2114
2115 def setOutputCol(self, value):
2116 """
2117 Sets the value of :py:attr:`outputCol`.
2118 """
2119 return self._set(outputCol=value)
2120
2121 def _create_model(self, java_model):
2122 return MinMaxScalerModel(java_model)
2123
2124
2125 class MinMaxScalerModel(JavaModel, _MinMaxScalerParams, JavaMLReadable, JavaMLWritable):
2126 """
2127 Model fitted by :py:class:`MinMaxScaler`.
2128
2129 .. versionadded:: 1.6.0
2130 """
2131
2132 @since("3.0.0")
2133 def setInputCol(self, value):
2134 """
2135 Sets the value of :py:attr:`inputCol`.
2136 """
2137 return self._set(inputCol=value)
2138
2139 @since("3.0.0")
2140 def setOutputCol(self, value):
2141 """
2142 Sets the value of :py:attr:`outputCol`.
2143 """
2144 return self._set(outputCol=value)
2145
2146 @since("3.0.0")
2147 def setMin(self, value):
2148 """
2149 Sets the value of :py:attr:`min`.
2150 """
2151 return self._set(min=value)
2152
2153 @since("3.0.0")
2154 def setMax(self, value):
2155 """
2156 Sets the value of :py:attr:`max`.
2157 """
2158 return self._set(max=value)
2159
2160 @property
2161 @since("2.0.0")
2162 def originalMin(self):
2163 """
2164 Min value for each original column during fitting.
2165 """
2166 return self._call_java("originalMin")
2167
2168 @property
2169 @since("2.0.0")
2170 def originalMax(self):
2171 """
2172 Max value for each original column during fitting.
2173 """
2174 return self._call_java("originalMax")
2175
2176
2177 @inherit_doc
2178 @ignore_unicode_prefix
2179 class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
2180 """
2181 A feature transformer that converts the input array of strings into an array of n-grams. Null
2182 values in the input array are ignored.
2183 It returns an array of n-grams where each n-gram is represented by a space-separated string of
2184 words.
2185 When the input is empty, an empty array is returned.
2186 When the input array length is less than n (number of elements per n-gram), no n-grams are
2187 returned.
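
For example, with n=2 the input `["a", "b", "c"]` yields `["a b", "b c"]`,
while with n=4 the same input yields an empty array because its length is
less than n.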
2188
2189 >>> df = spark.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])])
2190 >>> ngram = NGram(n=2)
2191 >>> ngram.setInputCol("inputTokens")
2192 NGram...
2193 >>> ngram.setOutputCol("nGrams")
2194 NGram...
2195 >>> ngram.transform(df).head()
2196 Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e'])
2197 >>> # Change n-gram length
2198 >>> ngram.setParams(n=4).transform(df).head()
2199 Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e'])
2200 >>> # Temporarily modify output column.
2201 >>> ngram.transform(df, {ngram.outputCol: "output"}).head()
2202 Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], output=[u'a b c d', u'b c d e'])
2203 >>> ngram.transform(df).head()
2204 Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e'])
2205 >>> # Must use keyword arguments to specify params.
2206 >>> ngram.setParams("text")
2207 Traceback (most recent call last):
2208 ...
2209 TypeError: Method setParams forces keyword arguments.
2210 >>> ngramPath = temp_path + "/ngram"
2211 >>> ngram.save(ngramPath)
2212 >>> loadedNGram = NGram.load(ngramPath)
2213 >>> loadedNGram.getN() == ngram.getN()
2214 True
2215
2216 .. versionadded:: 1.5.0
2217 """
2218
2219 n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)",
2220 typeConverter=TypeConverters.toInt)
2221
2222 @keyword_only
2223 def __init__(self, n=2, inputCol=None, outputCol=None):
2224 """
2225 __init__(self, n=2, inputCol=None, outputCol=None)
2226 """
2227 super(NGram, self).__init__()
2228 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid)
2229 self._setDefault(n=2)
2230 kwargs = self._input_kwargs
2231 self.setParams(**kwargs)
2232
2233 @keyword_only
2234 @since("1.5.0")
2235 def setParams(self, n=2, inputCol=None, outputCol=None):
2236 """
2237 setParams(self, n=2, inputCol=None, outputCol=None)
2238 Sets params for this NGram.
2239 """
2240 kwargs = self._input_kwargs
2241 return self._set(**kwargs)
2242
2243 @since("1.5.0")
2244 def setN(self, value):
2245 """
2246 Sets the value of :py:attr:`n`.
2247 """
2248 return self._set(n=value)
2249
2250 @since("1.5.0")
2251 def getN(self):
2252 """
2253 Gets the value of n or its default value.
2254 """
2255 return self.getOrDefault(self.n)
2256
2257 def setInputCol(self, value):
2258 """
2259 Sets the value of :py:attr:`inputCol`.
2260 """
2261 return self._set(inputCol=value)
2262
2263 def setOutputCol(self, value):
2264 """
2265 Sets the value of :py:attr:`outputCol`.
2266 """
2267 return self._set(outputCol=value)
2268
2269
2270 @inherit_doc
2271 class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
2272 """
2273 Normalize a vector to have unit norm using the given p-norm.
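
For the dense vector `[3.0, -4.0]` in the example below, the 2-norm is
sqrt(3.0^2 + (-4.0)^2) = 5.0, giving `[0.6, -0.8]`; under the 1-norm the
denominator is |3.0| + |-4.0| = 7.0, giving roughly `[0.4286, -0.5714]`.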
2274
2275 >>> from pyspark.ml.linalg import Vectors
2276 >>> svec = Vectors.sparse(4, {1: 4.0, 3: 3.0})
2277 >>> df = spark.createDataFrame([(Vectors.dense([3.0, -4.0]), svec)], ["dense", "sparse"])
2278 >>> normalizer = Normalizer(p=2.0)
2279 >>> normalizer.setInputCol("dense")
2280 Normalizer...
2281 >>> normalizer.setOutputCol("features")
2282 Normalizer...
2283 >>> normalizer.transform(df).head().features
2284 DenseVector([0.6, -0.8])
2285 >>> normalizer.setParams(inputCol="sparse", outputCol="freqs").transform(df).head().freqs
2286 SparseVector(4, {1: 0.8, 3: 0.6})
2287 >>> params = {normalizer.p: 1.0, normalizer.inputCol: "dense", normalizer.outputCol: "vector"}
2288 >>> normalizer.transform(df, params).head().vector
2289 DenseVector([0.4286, -0.5714])
2290 >>> normalizerPath = temp_path + "/normalizer"
2291 >>> normalizer.save(normalizerPath)
2292 >>> loadedNormalizer = Normalizer.load(normalizerPath)
2293 >>> loadedNormalizer.getP() == normalizer.getP()
2294 True
2295
2296 .. versionadded:: 1.4.0
2297 """
2298
2299 p = Param(Params._dummy(), "p", "the p norm value.",
2300 typeConverter=TypeConverters.toFloat)
2301
2302 @keyword_only
2303 def __init__(self, p=2.0, inputCol=None, outputCol=None):
2304 """
2305 __init__(self, p=2.0, inputCol=None, outputCol=None)
2306 """
2307 super(Normalizer, self).__init__()
2308 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid)
2309 self._setDefault(p=2.0)
2310 kwargs = self._input_kwargs
2311 self.setParams(**kwargs)
2312
2313 @keyword_only
2314 @since("1.4.0")
2315 def setParams(self, p=2.0, inputCol=None, outputCol=None):
2316 """
2317 setParams(self, p=2.0, inputCol=None, outputCol=None)
2318 Sets params for this Normalizer.
2319 """
2320 kwargs = self._input_kwargs
2321 return self._set(**kwargs)
2322
2323 @since("1.4.0")
2324 def setP(self, value):
2325 """
2326 Sets the value of :py:attr:`p`.
2327 """
2328 return self._set(p=value)
2329
2330 @since("1.4.0")
2331 def getP(self):
2332 """
2333 Gets the value of p or its default value.
2334 """
2335 return self.getOrDefault(self.p)
2336
2337 def setInputCol(self, value):
2338 """
2339 Sets the value of :py:attr:`inputCol`.
2340 """
2341 return self._set(inputCol=value)
2342
2343 def setOutputCol(self, value):
2344 """
2345 Sets the value of :py:attr:`outputCol`.
2346 """
2347 return self._set(outputCol=value)
2348
2349
2350 class _OneHotEncoderParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols,
2351 HasHandleInvalid):
2352 """
2353 Params for :py:class:`OneHotEncoder` and :py:class:`OneHotEncoderModel`.
2354
2355 .. versionadded:: 3.0.0
2356 """
2357
2358 handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data during " +
2359 "transform(). Options are 'keep' (invalid data presented as an extra " +
2360 "categorical feature) or error (throw an error). Note that this Param " +
2361 "is only used during transform; during fitting, invalid data will " +
2362 "result in an error.",
2363 typeConverter=TypeConverters.toString)
2364
2365 dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category",
2366 typeConverter=TypeConverters.toBoolean)
2367
2368 @since("2.3.0")
2369 def getDropLast(self):
2370 """
2371 Gets the value of dropLast or its default value.
2372 """
2373 return self.getOrDefault(self.dropLast)
2374
2375
2376 @inherit_doc
2377 class OneHotEncoder(JavaEstimator, _OneHotEncoderParams, JavaMLReadable, JavaMLWritable):
2378 """
2379 A one-hot encoder that maps a column of category indices to a column of binary vectors, with
2380 at most a single one-value per row that indicates the input category index.
For example, with 5 categories, an input value of 2.0 would map to an output vector of
`[0.0, 0.0, 1.0, 0.0]`.
The last category is not included by default (configurable via :py:attr:`dropLast`),
because including it would make the vector entries sum to one, and hence linearly
dependent. So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
2386
2387 .. note:: This is different from scikit-learn's OneHotEncoder, which keeps all categories.
2388 The output vectors are sparse.
2389
2390 When :py:attr:`handleInvalid` is configured to 'keep', an extra "category" indicating invalid
2391 values is added as last category. So when :py:attr:`dropLast` is true, invalid values are
2392 encoded as all-zeros vector.
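
Concretely, with 5 categories and :py:attr:`handleInvalid` set to 'keep':
when :py:attr:`dropLast` is false, vectors have size 6 and an invalid input
maps to `[0.0, 0.0, 0.0, 0.0, 0.0, 1.0]`; when :py:attr:`dropLast` is true,
that extra last category is the one dropped, so an invalid input maps to the
all-zeros vector of size 5.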
2393
2394 .. note:: When encoding multi-column by using :py:attr:`inputCols` and
2395 :py:attr:`outputCols` params, input/output cols come in pairs, specified by the order in
2396 the arrays, and each pair is treated independently.
2397
2398 .. seealso:: :py:class:`StringIndexer` for converting categorical values into category indices
2399
2400 >>> from pyspark.ml.linalg import Vectors
2401 >>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"])
2402 >>> ohe = OneHotEncoder()
2403 >>> ohe.setInputCols(["input"])
2404 OneHotEncoder...
2405 >>> ohe.setOutputCols(["output"])
2406 OneHotEncoder...
2407 >>> model = ohe.fit(df)
2408 >>> model.setOutputCols(["output"])
2409 OneHotEncoderModel...
2410 >>> model.getHandleInvalid()
2411 'error'
2412 >>> model.transform(df).head().output
2413 SparseVector(2, {0: 1.0})
2414 >>> single_col_ohe = OneHotEncoder(inputCol="input", outputCol="output")
2415 >>> single_col_model = single_col_ohe.fit(df)
2416 >>> single_col_model.transform(df).head().output
2417 SparseVector(2, {0: 1.0})
2418 >>> ohePath = temp_path + "/ohe"
2419 >>> ohe.save(ohePath)
2420 >>> loadedOHE = OneHotEncoder.load(ohePath)
2421 >>> loadedOHE.getInputCols() == ohe.getInputCols()
2422 True
2423 >>> modelPath = temp_path + "/ohe-model"
2424 >>> model.save(modelPath)
2425 >>> loadedModel = OneHotEncoderModel.load(modelPath)
2426 >>> loadedModel.categorySizes == model.categorySizes
2427 True
2428
2429 .. versionadded:: 2.3.0
2430 """
2431
2432 @keyword_only
2433 def __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True,
2434 inputCol=None, outputCol=None):
2435 """
2436 __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True, \
2437 inputCol=None, outputCol=None)
2438 """
2439 super(OneHotEncoder, self).__init__()
2440 self._java_obj = self._new_java_obj(
2441 "org.apache.spark.ml.feature.OneHotEncoder", self.uid)
2442 self._setDefault(handleInvalid="error", dropLast=True)
2443 kwargs = self._input_kwargs
2444 self.setParams(**kwargs)
2445
2446 @keyword_only
2447 @since("2.3.0")
2448 def setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True,
2449 inputCol=None, outputCol=None):
2450 """
2451 setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True, \
2452 inputCol=None, outputCol=None)
2453 Sets params for this OneHotEncoder.
2454 """
2455 kwargs = self._input_kwargs
2456 return self._set(**kwargs)
2457
2458 @since("2.3.0")
2459 def setDropLast(self, value):
2460 """
2461 Sets the value of :py:attr:`dropLast`.
2462 """
2463 return self._set(dropLast=value)
2464
2465 @since("3.0.0")
2466 def setInputCols(self, value):
2467 """
2468 Sets the value of :py:attr:`inputCols`.
2469 """
2470 return self._set(inputCols=value)
2471
2472 @since("3.0.0")
2473 def setOutputCols(self, value):
2474 """
2475 Sets the value of :py:attr:`outputCols`.
2476 """
2477 return self._set(outputCols=value)
2478
2479 @since("3.0.0")
2480 def setHandleInvalid(self, value):
2481 """
2482 Sets the value of :py:attr:`handleInvalid`.
2483 """
2484 return self._set(handleInvalid=value)
2485
2486 @since("3.0.0")
2487 def setInputCol(self, value):
2488 """
2489 Sets the value of :py:attr:`inputCol`.
2490 """
2491 return self._set(inputCol=value)
2492
2493 @since("3.0.0")
2494 def setOutputCol(self, value):
2495 """
2496 Sets the value of :py:attr:`outputCol`.
2497 """
2498 return self._set(outputCol=value)
2499
2500 def _create_model(self, java_model):
2501 return OneHotEncoderModel(java_model)
2502
2503
2504 class OneHotEncoderModel(JavaModel, _OneHotEncoderParams, JavaMLReadable, JavaMLWritable):
2505 """
2506 Model fitted by :py:class:`OneHotEncoder`.
2507
2508 .. versionadded:: 2.3.0
2509 """
2510
2511 @since("3.0.0")
2512 def setDropLast(self, value):
2513 """
2514 Sets the value of :py:attr:`dropLast`.
2515 """
2516 return self._set(dropLast=value)
2517
2518 @since("3.0.0")
2519 def setInputCols(self, value):
2520 """
2521 Sets the value of :py:attr:`inputCols`.
2522 """
2523 return self._set(inputCols=value)
2524
2525 @since("3.0.0")
2526 def setOutputCols(self, value):
2527 """
2528 Sets the value of :py:attr:`outputCols`.
2529 """
2530 return self._set(outputCols=value)
2531
2532 @since("3.0.0")
2533 def setInputCol(self, value):
2534 """
2535 Sets the value of :py:attr:`inputCol`.
2536 """
2537 return self._set(inputCol=value)
2538
2539 @since("3.0.0")
2540 def setOutputCol(self, value):
2541 """
2542 Sets the value of :py:attr:`outputCol`.
2543 """
2544 return self._set(outputCol=value)
2545
2546 @since("3.0.0")
2547 def setHandleInvalid(self, value):
2548 """
2549 Sets the value of :py:attr:`handleInvalid`.
2550 """
2551 return self._set(handleInvalid=value)
2552
2553 @property
2554 @since("2.3.0")
2555 def categorySizes(self):
2556 """
2557 Original number of categories for each feature being encoded.
2558 The array contains one value for each input column, in order.
2559 """
2560 return self._call_java("categorySizes")
2561
2562
2563 @inherit_doc
2564 class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
2565 JavaMLWritable):
2566 """
Perform feature expansion in a polynomial space. As described in the `Wikipedia article on
Polynomial Expansion <http://en.wikipedia.org/wiki/Polynomial_expansion>`_, "In mathematics, an
2569 expansion of a product of sums expresses it as a sum of products by using the fact that
2570 multiplication distributes over addition". Take a 2-variable feature vector as an example:
2571 `(x, y)`, if we want to expand it with degree 2, then we get `(x, x * x, y, x * y, y * y)`.
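
The doctest below follows this ordering: expanding `(x, y) = (0.5, 2.0)` with
degree 2 gives `(x, x * x, y, x * y, y * y) = (0.5, 0.25, 2.0, 1.0, 4.0)`.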
2572
2573 >>> from pyspark.ml.linalg import Vectors
2574 >>> df = spark.createDataFrame([(Vectors.dense([0.5, 2.0]),)], ["dense"])
2575 >>> px = PolynomialExpansion(degree=2)
2576 >>> px.setInputCol("dense")
2577 PolynomialExpansion...
2578 >>> px.setOutputCol("expanded")
2579 PolynomialExpansion...
2580 >>> px.transform(df).head().expanded
2581 DenseVector([0.5, 0.25, 2.0, 1.0, 4.0])
2582 >>> px.setParams(outputCol="test").transform(df).head().test
2583 DenseVector([0.5, 0.25, 2.0, 1.0, 4.0])
2584 >>> polyExpansionPath = temp_path + "/poly-expansion"
2585 >>> px.save(polyExpansionPath)
2586 >>> loadedPx = PolynomialExpansion.load(polyExpansionPath)
2587 >>> loadedPx.getDegree() == px.getDegree()
2588 True
2589
2590 .. versionadded:: 1.4.0
2591 """
2592
2593 degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)",
2594 typeConverter=TypeConverters.toInt)
2595
2596 @keyword_only
2597 def __init__(self, degree=2, inputCol=None, outputCol=None):
2598 """
2599 __init__(self, degree=2, inputCol=None, outputCol=None)
2600 """
2601 super(PolynomialExpansion, self).__init__()
2602 self._java_obj = self._new_java_obj(
2603 "org.apache.spark.ml.feature.PolynomialExpansion", self.uid)
2604 self._setDefault(degree=2)
2605 kwargs = self._input_kwargs
2606 self.setParams(**kwargs)
2607
2608 @keyword_only
2609 @since("1.4.0")
2610 def setParams(self, degree=2, inputCol=None, outputCol=None):
2611 """
2612 setParams(self, degree=2, inputCol=None, outputCol=None)
2613 Sets params for this PolynomialExpansion.
2614 """
2615 kwargs = self._input_kwargs
2616 return self._set(**kwargs)
2617
2618 @since("1.4.0")
2619 def setDegree(self, value):
2620 """
2621 Sets the value of :py:attr:`degree`.
2622 """
2623 return self._set(degree=value)
2624
2625 @since("1.4.0")
2626 def getDegree(self):
2627 """
2628 Gets the value of degree or its default value.
2629 """
2630 return self.getOrDefault(self.degree)
2631
2632 def setInputCol(self, value):
2633 """
2634 Sets the value of :py:attr:`inputCol`.
2635 """
2636 return self._set(inputCol=value)
2637
2638 def setOutputCol(self, value):
2639 """
2640 Sets the value of :py:attr:`outputCol`.
2641 """
2642 return self._set(outputCol=value)
2643
2644
2645 @inherit_doc
2646 class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
2647 HasHandleInvalid, HasRelativeError, JavaMLReadable, JavaMLWritable):
2648 """
2649 :py:class:`QuantileDiscretizer` takes a column with continuous features and outputs a column
2650 with binned categorical features. The number of bins can be set using the :py:attr:`numBuckets`
2651 parameter. It is possible that the number of buckets used will be less than this value, for
2652 example, if there are too few distinct values of the input to create enough distinct quantiles.
2653 Since 3.0.0, :py:class:`QuantileDiscretizer` can map multiple columns at once by setting the
2654 :py:attr:`inputCols` parameter. If both of the :py:attr:`inputCol` and :py:attr:`inputCols`
2655 parameters are set, an Exception will be thrown. To specify the number of buckets for each
2656 column, the :py:attr:`numBucketsArray` parameter can be set, or if the number of buckets
2657 should be the same across columns, :py:attr:`numBuckets` can be set as a convenience.
2658
NaN handling: :py:class:`QuantileDiscretizer` will raise an error when it finds NaN values
in the dataset, but the user can also choose to keep or remove NaN values within the dataset
by setting the :py:attr:`handleInvalid` parameter. If the user chooses to keep NaN values,
they will be handled specially and placed into their own bucket; for example, if 4 buckets
are used, then non-NaN data will be put into buckets[0-3], but NaNs will be counted in a
special bucket[4].
2665
2666 Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for
2667 :py:meth:`~.DataFrameStatFunctions.approxQuantile` for a detailed description).
2668 The precision of the approximation can be controlled with the
2669 :py:attr:`relativeError` parameter.
2670 The lower and upper bin bounds will be `-Infinity` and `+Infinity`, covering all real values.
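
In the single-column doctest below, fitting with numBuckets=2 on the non-NaN
values {0.1, 0.4, 1.2, 1.5} produces splits of roughly (-inf, 0.4, +inf), so
0.1 falls into bucket 0.0 while 0.4, 1.2 and 1.5 fall into bucket 1.0.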
2671
2672 >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
2673 >>> df1 = spark.createDataFrame(values, ["values"])
2674 >>> qds1 = QuantileDiscretizer(inputCol="values", outputCol="buckets")
2675 >>> qds1.setNumBuckets(2)
2676 QuantileDiscretizer...
2677 >>> qds1.setRelativeError(0.01)
2678 QuantileDiscretizer...
2679 >>> qds1.setHandleInvalid("error")
2680 QuantileDiscretizer...
2681 >>> qds1.getRelativeError()
2682 0.01
2683 >>> bucketizer = qds1.fit(df1)
2684 >>> qds1.setHandleInvalid("keep").fit(df1).transform(df1).count()
2685 6
2686 >>> qds1.setHandleInvalid("skip").fit(df1).transform(df1).count()
2687 4
2688 >>> splits = bucketizer.getSplits()
2689 >>> splits[0]
2690 -inf
2691 >>> print("%2.1f" % round(splits[1], 1))
2692 0.4
2693 >>> bucketed = bucketizer.transform(df1).head()
2694 >>> bucketed.buckets
2695 0.0
2696 >>> quantileDiscretizerPath = temp_path + "/quantile-discretizer"
2697 >>> qds1.save(quantileDiscretizerPath)
2698 >>> loadedQds = QuantileDiscretizer.load(quantileDiscretizerPath)
2699 >>> loadedQds.getNumBuckets() == qds1.getNumBuckets()
2700 True
2701 >>> inputs = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, 1.5),
2702 ... (float("nan"), float("nan")), (float("nan"), float("nan"))]
2703 >>> df2 = spark.createDataFrame(inputs, ["input1", "input2"])
2704 >>> qds2 = QuantileDiscretizer(relativeError=0.01, handleInvalid="error", numBuckets=2,
2705 ... inputCols=["input1", "input2"], outputCols=["output1", "output2"])
2706 >>> qds2.getRelativeError()
2707 0.01
2708 >>> qds2.setHandleInvalid("keep").fit(df2).transform(df2).show()
2709 +------+------+-------+-------+
2710 |input1|input2|output1|output2|
2711 +------+------+-------+-------+
2712 | 0.1| 0.0| 0.0| 0.0|
2713 | 0.4| 1.0| 1.0| 1.0|
2714 | 1.2| 1.3| 1.0| 1.0|
2715 | 1.5| 1.5| 1.0| 1.0|
2716 | NaN| NaN| 2.0| 2.0|
2717 | NaN| NaN| 2.0| 2.0|
2718 +------+------+-------+-------+
2719 ...
2720 >>> qds3 = QuantileDiscretizer(relativeError=0.01, handleInvalid="error",
2721 ... numBucketsArray=[5, 10], inputCols=["input1", "input2"],
2722 ... outputCols=["output1", "output2"])
2723 >>> qds3.setHandleInvalid("skip").fit(df2).transform(df2).show()
2724 +------+------+-------+-------+
2725 |input1|input2|output1|output2|
2726 +------+------+-------+-------+
2727 | 0.1| 0.0| 1.0| 1.0|
2728 | 0.4| 1.0| 2.0| 2.0|
2729 | 1.2| 1.3| 3.0| 3.0|
2730 | 1.5| 1.5| 4.0| 4.0|
2731 +------+------+-------+-------+
2732 ...
2733
2734 .. versionadded:: 2.0.0
2735 """
2736
2737 numBuckets = Param(Params._dummy(), "numBuckets",
2738 "Maximum number of buckets (quantiles, or " +
2739 "categories) into which data points are grouped. Must be >= 2.",
2740 typeConverter=TypeConverters.toInt)
2741
handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
"Options are 'skip' (filter out rows with invalid values), " +
"'error' (throw an error), or 'keep' (keep invalid values in a special " +
"additional bucket). Note that in the multiple columns " +
"case, the invalid handling is applied to all columns. That is, " +
"for 'error' it will throw an error if any invalid values are found in " +
"any column, and for 'skip' it will skip rows with any invalid values in " +
"any column.",
2750 typeConverter=TypeConverters.toString)
2751
2752 numBucketsArray = Param(Params._dummy(), "numBucketsArray", "Array of number of buckets " +
2753 "(quantiles, or categories) into which data points are grouped. " +
2754 "This is for multiple columns input. If transforming multiple " +
2755 "columns and numBucketsArray is not set, but numBuckets is set, " +
2756 "then numBuckets will be applied across all columns.",
2757 typeConverter=TypeConverters.toListInt)
2758
2759 @keyword_only
2760 def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
2761 handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None):
2762 """
2763 __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
2764 handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None)
2765 """
2766 super(QuantileDiscretizer, self).__init__()
2767 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer",
2768 self.uid)
2769 self._setDefault(numBuckets=2, relativeError=0.001, handleInvalid="error")
2770 kwargs = self._input_kwargs
2771 self.setParams(**kwargs)
2772
2773 @keyword_only
2774 @since("2.0.0")
2775 def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
2776 handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None):
2777 """
2778 setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
2779 handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None)
2780 Set the params for the QuantileDiscretizer
2781 """
2782 kwargs = self._input_kwargs
2783 return self._set(**kwargs)
2784
2785 @since("2.0.0")
2786 def setNumBuckets(self, value):
2787 """
2788 Sets the value of :py:attr:`numBuckets`.
2789 """
2790 return self._set(numBuckets=value)
2791
2792 @since("2.0.0")
2793 def getNumBuckets(self):
2794 """
2795 Gets the value of numBuckets or its default value.
2796 """
2797 return self.getOrDefault(self.numBuckets)
2798
2799 @since("3.0.0")
2800 def setNumBucketsArray(self, value):
2801 """
2802 Sets the value of :py:attr:`numBucketsArray`.
2803 """
2804 return self._set(numBucketsArray=value)
2805
2806 @since("3.0.0")
2807 def getNumBucketsArray(self):
2808 """
2809 Gets the value of numBucketsArray or its default value.
2810 """
2811 return self.getOrDefault(self.numBucketsArray)
2812
2813 @since("2.0.0")
2814 def setRelativeError(self, value):
2815 """
2816 Sets the value of :py:attr:`relativeError`.
2817 """
2818 return self._set(relativeError=value)
2819
2820 def setInputCol(self, value):
2821 """
2822 Sets the value of :py:attr:`inputCol`.
2823 """
2824 return self._set(inputCol=value)
2825
2826 @since("3.0.0")
2827 def setInputCols(self, value):
2828 """
2829 Sets the value of :py:attr:`inputCols`.
2830 """
2831 return self._set(inputCols=value)
2832
2833 def setOutputCol(self, value):
2834 """
2835 Sets the value of :py:attr:`outputCol`.
2836 """
2837 return self._set(outputCol=value)
2838
2839 @since("3.0.0")
2840 def setOutputCols(self, value):
2841 """
2842 Sets the value of :py:attr:`outputCols`.
2843 """
2844 return self._set(outputCols=value)
2845
2846 def setHandleInvalid(self, value):
2847 """
2848 Sets the value of :py:attr:`handleInvalid`.
2849 """
2850 return self._set(handleInvalid=value)
2851
2852 def _create_model(self, java_model):
2853 """
2854 Private method to convert the java_model to a Python model.
2855 """
if self.isSet(self.inputCol):
2857 return Bucketizer(splits=list(java_model.getSplits()),
2858 inputCol=self.getInputCol(),
2859 outputCol=self.getOutputCol(),
2860 handleInvalid=self.getHandleInvalid())
2861 else:
2862 splitsArrayList = [list(x) for x in list(java_model.getSplitsArray())]
2863 return Bucketizer(splitsArray=splitsArrayList,
2864 inputCols=self.getInputCols(),
2865 outputCols=self.getOutputCols(),
2866 handleInvalid=self.getHandleInvalid())
2867
2868
2869 class _RobustScalerParams(HasInputCol, HasOutputCol, HasRelativeError):
2870 """
2871 Params for :py:class:`RobustScaler` and :py:class:`RobustScalerModel`.
2872
2873 .. versionadded:: 3.0.0
2874 """
2875
2876 lower = Param(Params._dummy(), "lower", "Lower quantile to calculate quantile range",
2877 typeConverter=TypeConverters.toFloat)
2878 upper = Param(Params._dummy(), "upper", "Upper quantile to calculate quantile range",
2879 typeConverter=TypeConverters.toFloat)
2880 withCentering = Param(Params._dummy(), "withCentering", "Whether to center data with median",
2881 typeConverter=TypeConverters.toBoolean)
withScaling = Param(Params._dummy(), "withScaling", "Whether to scale the data to the "
"quantile range", typeConverter=TypeConverters.toBoolean)
2884
2885 @since("3.0.0")
2886 def getLower(self):
2887 """
2888 Gets the value of lower or its default value.
2889 """
2890 return self.getOrDefault(self.lower)
2891
2892 @since("3.0.0")
2893 def getUpper(self):
2894 """
2895 Gets the value of upper or its default value.
2896 """
2897 return self.getOrDefault(self.upper)
2898
2899 @since("3.0.0")
2900 def getWithCentering(self):
2901 """
2902 Gets the value of withCentering or its default value.
2903 """
2904 return self.getOrDefault(self.withCentering)
2905
2906 @since("3.0.0")
2907 def getWithScaling(self):
2908 """
2909 Gets the value of withScaling or its default value.
2910 """
2911 return self.getOrDefault(self.withScaling)
2912
2913
2914 @inherit_doc
2915 class RobustScaler(JavaEstimator, _RobustScalerParams, JavaMLReadable, JavaMLWritable):
2916 """
2917 RobustScaler removes the median and scales the data according to the quantile range.
The quantile range is by default the IQR (Interquartile Range, the range between the
1st quartile, i.e. the 25th percentile, and the 3rd quartile, i.e. the 75th percentile),
but it can be configured.
2920 Centering and scaling happen independently on each feature by computing the relevant
2921 statistics on the samples in the training set. Median and quantile range are then
2922 stored to be used on later data using the transform method.
2923 Note that NaN values are ignored in the computation of medians and ranges.
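
In the example below, the first feature takes the values {0.0, 1.0, 2.0, 3.0,
4.0}, so its median is 2.0 and its default quantile range (75th minus 25th
percentile) is 3.0 - 1.0 = 2.0; with the default withCentering=False, the
value 1.0 is therefore scaled to 1.0 / 2.0 = 0.5.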
2924
2925 >>> from pyspark.ml.linalg import Vectors
2926 >>> data = [(0, Vectors.dense([0.0, 0.0]),),
2927 ... (1, Vectors.dense([1.0, -1.0]),),
2928 ... (2, Vectors.dense([2.0, -2.0]),),
2929 ... (3, Vectors.dense([3.0, -3.0]),),
2930 ... (4, Vectors.dense([4.0, -4.0]),),]
2931 >>> df = spark.createDataFrame(data, ["id", "features"])
2932 >>> scaler = RobustScaler()
2933 >>> scaler.setInputCol("features")
2934 RobustScaler...
2935 >>> scaler.setOutputCol("scaled")
2936 RobustScaler...
2937 >>> model = scaler.fit(df)
2938 >>> model.setOutputCol("output")
2939 RobustScalerModel...
2940 >>> model.median
2941 DenseVector([2.0, -2.0])
2942 >>> model.range
2943 DenseVector([2.0, 2.0])
2944 >>> model.transform(df).collect()[1].output
2945 DenseVector([0.5, -0.5])
2946 >>> scalerPath = temp_path + "/robust-scaler"
2947 >>> scaler.save(scalerPath)
2948 >>> loadedScaler = RobustScaler.load(scalerPath)
2949 >>> loadedScaler.getWithCentering() == scaler.getWithCentering()
2950 True
2951 >>> loadedScaler.getWithScaling() == scaler.getWithScaling()
2952 True
2953 >>> modelPath = temp_path + "/robust-scaler-model"
2954 >>> model.save(modelPath)
2955 >>> loadedModel = RobustScalerModel.load(modelPath)
2956 >>> loadedModel.median == model.median
2957 True
2958 >>> loadedModel.range == model.range
2959 True
2960
2961 .. versionadded:: 3.0.0
2962 """
2963
2964 @keyword_only
2965 def __init__(self, lower=0.25, upper=0.75, withCentering=False, withScaling=True,
2966 inputCol=None, outputCol=None, relativeError=0.001):
2967 """
2968 __init__(self, lower=0.25, upper=0.75, withCentering=False, withScaling=True, \
2969 inputCol=None, outputCol=None, relativeError=0.001)
2970 """
2971 super(RobustScaler, self).__init__()
2972 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RobustScaler", self.uid)
2973 self._setDefault(lower=0.25, upper=0.75, withCentering=False, withScaling=True,
2974 relativeError=0.001)
2975 kwargs = self._input_kwargs
2976 self.setParams(**kwargs)
2977
2978 @keyword_only
2979 @since("3.0.0")
2980 def setParams(self, lower=0.25, upper=0.75, withCentering=False, withScaling=True,
2981 inputCol=None, outputCol=None, relativeError=0.001):
2982 """
2983 setParams(self, lower=0.25, upper=0.75, withCentering=False, withScaling=True, \
2984 inputCol=None, outputCol=None, relativeError=0.001)
2985 Sets params for this RobustScaler.
2986 """
2987 kwargs = self._input_kwargs
2988 return self._set(**kwargs)
2989
2990 @since("3.0.0")
2991 def setLower(self, value):
2992 """
2993 Sets the value of :py:attr:`lower`.
2994 """
2995 return self._set(lower=value)
2996
2997 @since("3.0.0")
2998 def setUpper(self, value):
2999 """
3000 Sets the value of :py:attr:`upper`.
3001 """
3002 return self._set(upper=value)
3003
3004 @since("3.0.0")
3005 def setWithCentering(self, value):
3006 """
3007 Sets the value of :py:attr:`withCentering`.
3008 """
3009 return self._set(withCentering=value)
3010
3011 @since("3.0.0")
3012 def setWithScaling(self, value):
3013 """
3014 Sets the value of :py:attr:`withScaling`.
3015 """
3016 return self._set(withScaling=value)
3017
3018 @since("3.0.0")
3019 def setInputCol(self, value):
3020 """
3021 Sets the value of :py:attr:`inputCol`.
3022 """
3023 return self._set(inputCol=value)
3024
3025 @since("3.0.0")
3026 def setOutputCol(self, value):
3027 """
3028 Sets the value of :py:attr:`outputCol`.
3029 """
3030 return self._set(outputCol=value)
3031
3032 @since("3.0.0")
3033 def setRelativeError(self, value):
3034 """
3035 Sets the value of :py:attr:`relativeError`.
3036 """
3037 return self._set(relativeError=value)
3038
3039 def _create_model(self, java_model):
3040 return RobustScalerModel(java_model)
3041
3042
3043 class RobustScalerModel(JavaModel, _RobustScalerParams, JavaMLReadable, JavaMLWritable):
3044 """
3045 Model fitted by :py:class:`RobustScaler`.
3046
3047 .. versionadded:: 3.0.0
3048 """
3049
3050 @since("3.0.0")
3051 def setInputCol(self, value):
3052 """
3053 Sets the value of :py:attr:`inputCol`.
3054 """
3055 return self._set(inputCol=value)
3056
3057 @since("3.0.0")
3058 def setOutputCol(self, value):
3059 """
3060 Sets the value of :py:attr:`outputCol`.
3061 """
3062 return self._set(outputCol=value)
3063
3064 @property
3065 @since("3.0.0")
3066 def median(self):
3067 """
3068 Median of the RobustScalerModel.
3069 """
3070 return self._call_java("median")
3071
3072 @property
3073 @since("3.0.0")
3074 def range(self):
3075 """
3076 Quantile range of the RobustScalerModel.
3077 """
3078 return self._call_java("range")
3079
3080
3081 @inherit_doc
3082 @ignore_unicode_prefix
3083 class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
3084 """
A regex-based tokenizer that extracts tokens either by using the
provided regex pattern (in Java dialect) to split the text
(the default, when gaps is true) or by repeatedly matching the regex
(when gaps is false).
Optional parameters also allow filtering tokens using a minimum
length.
3090 It returns an array of strings that can be empty.
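
For example, with the default pattern "\\s+" and gaps=True, the text "A B c"
in the doctest below is split on runs of whitespace into `["a", "b", "c"]`
(lowercased by default); the same tokens could instead be matched directly
with gaps=False and a pattern such as "\\w+".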
3091
3092 >>> df = spark.createDataFrame([("A B c",)], ["text"])
3093 >>> reTokenizer = RegexTokenizer()
3094 >>> reTokenizer.setInputCol("text")
3095 RegexTokenizer...
3096 >>> reTokenizer.setOutputCol("words")
3097 RegexTokenizer...
3098 >>> reTokenizer.transform(df).head()
3099 Row(text=u'A B c', words=[u'a', u'b', u'c'])
3100 >>> # Change a parameter.
3101 >>> reTokenizer.setParams(outputCol="tokens").transform(df).head()
3102 Row(text=u'A B c', tokens=[u'a', u'b', u'c'])
3103 >>> # Temporarily modify a parameter.
3104 >>> reTokenizer.transform(df, {reTokenizer.outputCol: "words"}).head()
3105 Row(text=u'A B c', words=[u'a', u'b', u'c'])
3106 >>> reTokenizer.transform(df).head()
3107 Row(text=u'A B c', tokens=[u'a', u'b', u'c'])
3108 >>> # Must use keyword arguments to specify params.
3109 >>> reTokenizer.setParams("text")
3110 Traceback (most recent call last):
3111 ...
3112 TypeError: Method setParams forces keyword arguments.
3113 >>> regexTokenizerPath = temp_path + "/regex-tokenizer"
3114 >>> reTokenizer.save(regexTokenizerPath)
3115 >>> loadedReTokenizer = RegexTokenizer.load(regexTokenizerPath)
3116 >>> loadedReTokenizer.getMinTokenLength() == reTokenizer.getMinTokenLength()
3117 True
3118 >>> loadedReTokenizer.getGaps() == reTokenizer.getGaps()
3119 True
3120
3121 .. versionadded:: 1.4.0
3122 """
3123
3124 minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)",
3125 typeConverter=TypeConverters.toInt)
gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens " +
"(False)", typeConverter=TypeConverters.toBoolean)
3128 pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing",
3129 typeConverter=TypeConverters.toString)
3130 toLowercase = Param(Params._dummy(), "toLowercase", "whether to convert all characters to " +
3131 "lowercase before tokenizing", typeConverter=TypeConverters.toBoolean)
3132
3133 @keyword_only
3134 def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None,
3135 outputCol=None, toLowercase=True):
3136 """
3137 __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, \
3138 outputCol=None, toLowercase=True)
3139 """
3140 super(RegexTokenizer, self).__init__()
3141 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid)
3142 self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+", toLowercase=True)
3143 kwargs = self._input_kwargs
3144 self.setParams(**kwargs)
3145
3146 @keyword_only
3147 @since("1.4.0")
3148 def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None,
3149 outputCol=None, toLowercase=True):
3150 """
3151 setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, \
3152 outputCol=None, toLowercase=True)
3153 Sets params for this RegexTokenizer.
3154 """
3155 kwargs = self._input_kwargs
3156 return self._set(**kwargs)
3157
3158 @since("1.4.0")
3159 def setMinTokenLength(self, value):
3160 """
3161 Sets the value of :py:attr:`minTokenLength`.
3162 """
3163 return self._set(minTokenLength=value)
3164
3165 @since("1.4.0")
3166 def getMinTokenLength(self):
3167 """
3168 Gets the value of minTokenLength or its default value.
3169 """
3170 return self.getOrDefault(self.minTokenLength)
3171
3172 @since("1.4.0")
3173 def setGaps(self, value):
3174 """
3175 Sets the value of :py:attr:`gaps`.
3176 """
3177 return self._set(gaps=value)
3178
3179 @since("1.4.0")
3180 def getGaps(self):
3181 """
3182 Gets the value of gaps or its default value.
3183 """
3184 return self.getOrDefault(self.gaps)
3185
3186 @since("1.4.0")
3187 def setPattern(self, value):
3188 """
3189 Sets the value of :py:attr:`pattern`.
3190 """
3191 return self._set(pattern=value)
3192
3193 @since("1.4.0")
3194 def getPattern(self):
3195 """
3196 Gets the value of pattern or its default value.
3197 """
3198 return self.getOrDefault(self.pattern)
3199
3200 @since("2.0.0")
3201 def setToLowercase(self, value):
3202 """
3203 Sets the value of :py:attr:`toLowercase`.
3204 """
3205 return self._set(toLowercase=value)
3206
3207 @since("2.0.0")
3208 def getToLowercase(self):
3209 """
3210 Gets the value of toLowercase or its default value.
3211 """
3212 return self.getOrDefault(self.toLowercase)
3213
3214 def setInputCol(self, value):
3215 """
3216 Sets the value of :py:attr:`inputCol`.
3217 """
3218 return self._set(inputCol=value)
3219
3220 def setOutputCol(self, value):
3221 """
3222 Sets the value of :py:attr:`outputCol`.
3223 """
3224 return self._set(outputCol=value)
3225
3226
3227 @inherit_doc
3228 class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable):
3229 """
Implements the transformations that are defined by a SQL statement.
Currently, only SQL syntax like 'SELECT ... FROM __THIS__' is supported,
where '__THIS__' represents the underlying table of the input dataset.
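
In the doctest below, the statement computes two derived columns, `v3 = v1 + v2`
and `v4 = v1 * v2`, alongside the original columns selected via `SELECT *`.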
3233
3234 >>> df = spark.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", "v1", "v2"])
3235 >>> sqlTrans = SQLTransformer(
3236 ... statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
3237 >>> sqlTrans.transform(df).head()
3238 Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0)
3239 >>> sqlTransformerPath = temp_path + "/sql-transformer"
3240 >>> sqlTrans.save(sqlTransformerPath)
3241 >>> loadedSqlTrans = SQLTransformer.load(sqlTransformerPath)
3242 >>> loadedSqlTrans.getStatement() == sqlTrans.getStatement()
3243 True
3244
3245 .. versionadded:: 1.6.0
3246 """
3247
3248 statement = Param(Params._dummy(), "statement", "SQL statement",
3249 typeConverter=TypeConverters.toString)
3250
3251 @keyword_only
3252 def __init__(self, statement=None):
3253 """
3254 __init__(self, statement=None)
3255 """
3256 super(SQLTransformer, self).__init__()
3257 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid)
3258 kwargs = self._input_kwargs
3259 self.setParams(**kwargs)
3260
3261 @keyword_only
3262 @since("1.6.0")
3263 def setParams(self, statement=None):
3264 """
3265 setParams(self, statement=None)
3266 Sets params for this SQLTransformer.
3267 """
3268 kwargs = self._input_kwargs
3269 return self._set(**kwargs)
3270
3271 @since("1.6.0")
3272 def setStatement(self, value):
3273 """
3274 Sets the value of :py:attr:`statement`.
3275 """
3276 return self._set(statement=value)
3277
3278 @since("1.6.0")
3279 def getStatement(self):
3280 """
3281 Gets the value of statement or its default value.
3282 """
3283 return self.getOrDefault(self.statement)
3284
3285
3286 class _StandardScalerParams(HasInputCol, HasOutputCol):
3287 """
3288 Params for :py:class:`StandardScaler` and :py:class:`StandardScalerModel`.
3289
3290 .. versionadded:: 3.0.0
3291 """
3292
3293 withMean = Param(Params._dummy(), "withMean", "Center data with mean",
3294 typeConverter=TypeConverters.toBoolean)
3295 withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation",
3296 typeConverter=TypeConverters.toBoolean)
3297
3298 @since("1.4.0")
3299 def getWithMean(self):
3300 """
3301 Gets the value of withMean or its default value.
3302 """
3303 return self.getOrDefault(self.withMean)
3304
3305 @since("1.4.0")
3306 def getWithStd(self):
3307 """
3308 Gets the value of withStd or its default value.
3309 """
3310 return self.getOrDefault(self.withStd)
3311
3312
3313 @inherit_doc
3314 class StandardScaler(JavaEstimator, _StandardScalerParams, JavaMLReadable, JavaMLWritable):
3315 """
3316 Standardizes features by removing the mean and scaling to unit variance using column summary
3317 statistics on the samples in the training set.
3318
3319 The "unit std" is computed using the `corrected sample standard deviation \
3320 <https://en.wikipedia.org/wiki/Standard_deviation#Corrected_sample_standard_deviation>`_,
3321 which is computed as the square root of the unbiased sample variance.
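
In the example below, the column holds {0.0, 2.0}: the mean is 1.0, the
unbiased sample variance is ((0.0 - 1.0)^2 + (2.0 - 1.0)^2) / (2 - 1) = 2.0,
and the std is sqrt(2.0) ~= 1.4142; with the default withMean=False, the
value 2.0 is scaled to 2.0 / 1.4142 ~= 1.4142.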
3322
3323 >>> from pyspark.ml.linalg import Vectors
3324 >>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"])
3325 >>> standardScaler = StandardScaler()
3326 >>> standardScaler.setInputCol("a")
3327 StandardScaler...
3328 >>> standardScaler.setOutputCol("scaled")
3329 StandardScaler...
3330 >>> model = standardScaler.fit(df)
3331 >>> model.getInputCol()
3332 'a'
3333 >>> model.setOutputCol("output")
3334 StandardScalerModel...
3335 >>> model.mean
3336 DenseVector([1.0])
3337 >>> model.std
3338 DenseVector([1.4142])
3339 >>> model.transform(df).collect()[1].output
3340 DenseVector([1.4142])
3341 >>> standardScalerPath = temp_path + "/standard-scaler"
3342 >>> standardScaler.save(standardScalerPath)
3343 >>> loadedStandardScaler = StandardScaler.load(standardScalerPath)
3344 >>> loadedStandardScaler.getWithMean() == standardScaler.getWithMean()
3345 True
3346 >>> loadedStandardScaler.getWithStd() == standardScaler.getWithStd()
3347 True
3348 >>> modelPath = temp_path + "/standard-scaler-model"
3349 >>> model.save(modelPath)
3350 >>> loadedModel = StandardScalerModel.load(modelPath)
3351 >>> loadedModel.std == model.std
3352 True
3353 >>> loadedModel.mean == model.mean
3354 True
3355
3356 .. versionadded:: 1.4.0
3357 """
3358
3359 @keyword_only
3360 def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None):
3361 """
3362 __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None)
3363 """
3364 super(StandardScaler, self).__init__()
3365 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid)
3366 self._setDefault(withMean=False, withStd=True)
3367 kwargs = self._input_kwargs
3368 self.setParams(**kwargs)
3369
3370 @keyword_only
3371 @since("1.4.0")
3372 def setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None):
3373 """
3374 setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None)
3375 Sets params for this StandardScaler.
3376 """
3377 kwargs = self._input_kwargs
3378 return self._set(**kwargs)
3379
3380 @since("1.4.0")
3381 def setWithMean(self, value):
3382 """
3383 Sets the value of :py:attr:`withMean`.
3384 """
3385 return self._set(withMean=value)
3386
3387 @since("1.4.0")
3388 def setWithStd(self, value):
3389 """
3390 Sets the value of :py:attr:`withStd`.
3391 """
3392 return self._set(withStd=value)
3393
3394 def setInputCol(self, value):
3395 """
3396 Sets the value of :py:attr:`inputCol`.
3397 """
3398 return self._set(inputCol=value)
3399
3400 def setOutputCol(self, value):
3401 """
3402 Sets the value of :py:attr:`outputCol`.
3403 """
3404 return self._set(outputCol=value)
3405
3406 def _create_model(self, java_model):
3407 return StandardScalerModel(java_model)
3408
3409
3410 class StandardScalerModel(JavaModel, _StandardScalerParams, JavaMLReadable, JavaMLWritable):
3411 """
3412 Model fitted by :py:class:`StandardScaler`.
3413
3414 .. versionadded:: 1.4.0
3415 """
3416
3417 def setInputCol(self, value):
3418 """
3419 Sets the value of :py:attr:`inputCol`.
3420 """
3421 return self._set(inputCol=value)
3422
3423 def setOutputCol(self, value):
3424 """
3425 Sets the value of :py:attr:`outputCol`.
3426 """
3427 return self._set(outputCol=value)
3428
3429 @property
3430 @since("2.0.0")
3431 def std(self):
3432 """
3433 Standard deviation of the StandardScalerModel.
3434 """
3435 return self._call_java("std")
3436
3437 @property
3438 @since("2.0.0")
3439 def mean(self):
3440 """
3441 Mean of the StandardScalerModel.
3442 """
3443 return self._call_java("mean")
3444
3445
3446 class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol,
3447 HasInputCols, HasOutputCols):
3448 """
3449 Params for :py:class:`StringIndexer` and :py:class:`StringIndexerModel`.
3450 """
3451
3452 stringOrderType = Param(Params._dummy(), "stringOrderType",
3453 "How to order labels of string column. The first label after " +
3454 "ordering is assigned an index of 0. Supported options: " +
3455 "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc. " +
                            "Default is frequencyDesc. When frequencies are equal under " +
                            "frequencyDesc or frequencyAsc, the strings are further sorted " +
                            "alphabetically",
3459 typeConverter=TypeConverters.toString)
3460
3461 handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
3462 "or NULL values) in features and label column of string type. " +
3463 "Options are 'skip' (filter out rows with invalid data), " +
                          "'error' (throw an error), or 'keep' (put invalid data " +
3465 "in a special additional bucket, at index numLabels).",
3466 typeConverter=TypeConverters.toString)
3467
3468 def __init__(self, *args):
3469 super(_StringIndexerParams, self).__init__(*args)
3470 self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc")
3471
3472 @since("2.3.0")
3473 def getStringOrderType(self):
3474 """
3475 Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'.
3476 """
3477 return self.getOrDefault(self.stringOrderType)
3478
3479
3480 @inherit_doc
3481 class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLWritable):
3482 """
3483 A label indexer that maps a string column of labels to an ML column of label indices.
3484 If the input column is numeric, we cast it to string and index the string values.
3485 The indices are in [0, numLabels). By default, this is ordered by label frequencies
3486 so the most frequent label gets index 0. The ordering behavior is controlled by
3487 setting :py:attr:`stringOrderType`. Its default value is 'frequencyDesc'.
3488
3489 >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed",
3490 ... stringOrderType="frequencyDesc")
3491 >>> stringIndexer.setHandleInvalid("error")
3492 StringIndexer...
3493 >>> model = stringIndexer.fit(stringIndDf)
3494 >>> model.setHandleInvalid("error")
3495 StringIndexerModel...
3496 >>> td = model.transform(stringIndDf)
3497 >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
3498 ... key=lambda x: x[0])
3499 [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
3500 >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels)
3501 >>> itd = inverter.transform(td)
3502 >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
3503 ... key=lambda x: x[0])
3504 [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
3505 >>> stringIndexerPath = temp_path + "/string-indexer"
3506 >>> stringIndexer.save(stringIndexerPath)
3507 >>> loadedIndexer = StringIndexer.load(stringIndexerPath)
3508 >>> loadedIndexer.getHandleInvalid() == stringIndexer.getHandleInvalid()
3509 True
3510 >>> modelPath = temp_path + "/string-indexer-model"
3511 >>> model.save(modelPath)
3512 >>> loadedModel = StringIndexerModel.load(modelPath)
3513 >>> loadedModel.labels == model.labels
3514 True
3515 >>> indexToStringPath = temp_path + "/index-to-string"
3516 >>> inverter.save(indexToStringPath)
3517 >>> loadedInverter = IndexToString.load(indexToStringPath)
3518 >>> loadedInverter.getLabels() == inverter.getLabels()
3519 True
3520 >>> stringIndexer.getStringOrderType()
3521 'frequencyDesc'
3522 >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error",
3523 ... stringOrderType="alphabetDesc")
3524 >>> model = stringIndexer.fit(stringIndDf)
3525 >>> td = model.transform(stringIndDf)
3526 >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
3527 ... key=lambda x: x[0])
3528 [(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)]
3529 >>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"],
3530 ... inputCol="label", outputCol="indexed", handleInvalid="error")
3531 >>> result = fromlabelsModel.transform(stringIndDf)
3532 >>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]),
3533 ... key=lambda x: x[0])
3534 [(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)]
3535 >>> testData = sc.parallelize([Row(id=0, label1="a", label2="e"),
3536 ... Row(id=1, label1="b", label2="f"),
3537 ... Row(id=2, label1="c", label2="e"),
3538 ... Row(id=3, label1="a", label2="f"),
3539 ... Row(id=4, label1="a", label2="f"),
3540 ... Row(id=5, label1="c", label2="f")], 3)
3541 >>> multiRowDf = spark.createDataFrame(testData)
3542 >>> inputs = ["label1", "label2"]
3543 >>> outputs = ["index1", "index2"]
3544 >>> stringIndexer = StringIndexer(inputCols=inputs, outputCols=outputs)
3545 >>> model = stringIndexer.fit(multiRowDf)
3546 >>> result = model.transform(multiRowDf)
3547 >>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1,
3548 ... result.index2).collect()]), key=lambda x: x[0])
3549 [(0, 0.0, 1.0), (1, 2.0, 0.0), (2, 1.0, 1.0), (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0)]
3550 >>> fromlabelsModel = StringIndexerModel.from_arrays_of_labels([["a", "b", "c"], ["e", "f"]],
3551 ... inputCols=inputs, outputCols=outputs)
3552 >>> result = fromlabelsModel.transform(multiRowDf)
3553 >>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1,
3554 ... result.index2).collect()]), key=lambda x: x[0])
3555 [(0, 0.0, 0.0), (1, 1.0, 1.0), (2, 2.0, 0.0), (3, 0.0, 1.0), (4, 0.0, 1.0), (5, 2.0, 1.0)]
3556
3557 .. versionadded:: 1.4.0
3558 """
3559
3560 @keyword_only
3561 def __init__(self, inputCol=None, outputCol=None, inputCols=None, outputCols=None,
3562 handleInvalid="error", stringOrderType="frequencyDesc"):
3563 """
3564 __init__(self, inputCol=None, outputCol=None, inputCols=None, outputCols=None, \
3565 handleInvalid="error", stringOrderType="frequencyDesc")
3566 """
3567 super(StringIndexer, self).__init__()
3568 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
3569 kwargs = self._input_kwargs
3570 self.setParams(**kwargs)
3571
3572 @keyword_only
3573 @since("1.4.0")
3574 def setParams(self, inputCol=None, outputCol=None, inputCols=None, outputCols=None,
3575 handleInvalid="error", stringOrderType="frequencyDesc"):
3576 """
3577 setParams(self, inputCol=None, outputCol=None, inputCols=None, outputCols=None, \
3578 handleInvalid="error", stringOrderType="frequencyDesc")
3579 Sets params for this StringIndexer.
3580 """
3581 kwargs = self._input_kwargs
3582 return self._set(**kwargs)
3583
3584 def _create_model(self, java_model):
3585 return StringIndexerModel(java_model)
3586
3587 @since("2.3.0")
3588 def setStringOrderType(self, value):
3589 """
3590 Sets the value of :py:attr:`stringOrderType`.
3591 """
3592 return self._set(stringOrderType=value)
3593
3594 def setInputCol(self, value):
3595 """
3596 Sets the value of :py:attr:`inputCol`.
3597 """
3598 return self._set(inputCol=value)
3599
3600 @since("3.0.0")
3601 def setInputCols(self, value):
3602 """
3603 Sets the value of :py:attr:`inputCols`.
3604 """
3605 return self._set(inputCols=value)
3606
3607 def setOutputCol(self, value):
3608 """
3609 Sets the value of :py:attr:`outputCol`.
3610 """
3611 return self._set(outputCol=value)
3612
3613 @since("3.0.0")
3614 def setOutputCols(self, value):
3615 """
3616 Sets the value of :py:attr:`outputCols`.
3617 """
3618 return self._set(outputCols=value)
3619
3620 def setHandleInvalid(self, value):
3621 """
3622 Sets the value of :py:attr:`handleInvalid`.
3623 """
3624 return self._set(handleInvalid=value)
3625
3626
3627 class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaMLWritable):
3628 """
3629 Model fitted by :py:class:`StringIndexer`.
3630
3631 .. versionadded:: 1.4.0
3632 """
3633
3634 def setInputCol(self, value):
3635 """
3636 Sets the value of :py:attr:`inputCol`.
3637 """
3638 return self._set(inputCol=value)
3639
3640 @since("3.0.0")
3641 def setInputCols(self, value):
3642 """
3643 Sets the value of :py:attr:`inputCols`.
3644 """
3645 return self._set(inputCols=value)
3646
3647 def setOutputCol(self, value):
3648 """
3649 Sets the value of :py:attr:`outputCol`.
3650 """
3651 return self._set(outputCol=value)
3652
3653 @since("3.0.0")
3654 def setOutputCols(self, value):
3655 """
3656 Sets the value of :py:attr:`outputCols`.
3657 """
3658 return self._set(outputCols=value)
3659
3660 @since("2.4.0")
3661 def setHandleInvalid(self, value):
3662 """
3663 Sets the value of :py:attr:`handleInvalid`.
3664 """
3665 return self._set(handleInvalid=value)
3666
3667 @classmethod
3668 @since("2.4.0")
3669 def from_labels(cls, labels, inputCol, outputCol=None, handleInvalid=None):
3670 """
3671 Construct the model directly from an array of label strings,
3672 requires an active SparkContext.
3673 """
3674 sc = SparkContext._active_spark_context
3675 java_class = sc._gateway.jvm.java.lang.String
3676 jlabels = StringIndexerModel._new_java_array(labels, java_class)
3677 model = StringIndexerModel._create_from_java_class(
3678 "org.apache.spark.ml.feature.StringIndexerModel", jlabels)
3679 model.setInputCol(inputCol)
3680 if outputCol is not None:
3681 model.setOutputCol(outputCol)
3682 if handleInvalid is not None:
3683 model.setHandleInvalid(handleInvalid)
3684 return model
3685
3686 @classmethod
3687 @since("3.0.0")
3688 def from_arrays_of_labels(cls, arrayOfLabels, inputCols, outputCols=None,
3689 handleInvalid=None):
3690 """
        Construct the model directly from an array of arrays of label strings,
        requires an active SparkContext.
3693 """
3694 sc = SparkContext._active_spark_context
3695 java_class = sc._gateway.jvm.java.lang.String
3696 jlabels = StringIndexerModel._new_java_array(arrayOfLabels, java_class)
3697 model = StringIndexerModel._create_from_java_class(
3698 "org.apache.spark.ml.feature.StringIndexerModel", jlabels)
3699 model.setInputCols(inputCols)
3700 if outputCols is not None:
3701 model.setOutputCols(outputCols)
3702 if handleInvalid is not None:
3703 model.setHandleInvalid(handleInvalid)
3704 return model
3705
3706 @property
3707 @since("1.5.0")
3708 def labels(self):
3709 """
3710 Ordered list of labels, corresponding to indices to be assigned.
3711 """
3712 return self._call_java("labels")
3713
3714
3715 @inherit_doc
3716 class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
3717 """
3718 A :py:class:`Transformer` that maps a column of indices back to a new column of
3719 corresponding string values.
3720 The index-string mapping is either from the ML attributes of the input column,
3721 or from user-supplied labels (which take precedence over ML attributes).
3722 See :class:`StringIndexer` for converting strings into indices.
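
    A minimal sketch with user-supplied labels (the column and label names here
    are arbitrary):

    >>> df = spark.createDataFrame([(0.0,), (2.0,)], ["indexed"])
    >>> its = IndexToString(inputCol="indexed", outputCol="labelStr", labels=["a", "b", "c"])
    >>> its.transform(df).head().labelStr == 'a'
    True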
3723
3724 .. versionadded:: 1.6.0
3725 """
3726
3727 labels = Param(Params._dummy(), "labels",
3728 "Optional array of labels specifying index-string mapping." +
3729 " If not provided or if empty, then metadata from inputCol is used instead.",
3730 typeConverter=TypeConverters.toListString)
3731
3732 @keyword_only
3733 def __init__(self, inputCol=None, outputCol=None, labels=None):
3734 """
3735 __init__(self, inputCol=None, outputCol=None, labels=None)
3736 """
3737 super(IndexToString, self).__init__()
3738 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString",
3739 self.uid)
3740 kwargs = self._input_kwargs
3741 self.setParams(**kwargs)
3742
3743 @keyword_only
3744 @since("1.6.0")
3745 def setParams(self, inputCol=None, outputCol=None, labels=None):
3746 """
3747 setParams(self, inputCol=None, outputCol=None, labels=None)
3748 Sets params for this IndexToString.
3749 """
3750 kwargs = self._input_kwargs
3751 return self._set(**kwargs)
3752
3753 @since("1.6.0")
3754 def setLabels(self, value):
3755 """
3756 Sets the value of :py:attr:`labels`.
3757 """
3758 return self._set(labels=value)
3759
3760 @since("1.6.0")
3761 def getLabels(self):
3762 """
3763 Gets the value of :py:attr:`labels` or its default value.
3764 """
3765 return self.getOrDefault(self.labels)
3766
3767 def setInputCol(self, value):
3768 """
3769 Sets the value of :py:attr:`inputCol`.
3770 """
3771 return self._set(inputCol=value)
3772
3773 def setOutputCol(self, value):
3774 """
3775 Sets the value of :py:attr:`outputCol`.
3776 """
3777 return self._set(outputCol=value)
3778
3779
3780 class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
3781 JavaMLReadable, JavaMLWritable):
3782 """
3783 A feature transformer that filters out stop words from input.
3784 Since 3.0.0, :py:class:`StopWordsRemover` can filter out multiple columns at once by setting
3785 the :py:attr:`inputCols` parameter. Note that when both the :py:attr:`inputCol` and
3786 :py:attr:`inputCols` parameters are set, an Exception will be thrown.
3787
    .. note:: null values from the input array are preserved unless null is
        explicitly added to stopWords.
3789
3790 >>> df = spark.createDataFrame([(["a", "b", "c"],)], ["text"])
3791 >>> remover = StopWordsRemover(stopWords=["b"])
3792 >>> remover.setInputCol("text")
3793 StopWordsRemover...
3794 >>> remover.setOutputCol("words")
3795 StopWordsRemover...
3796 >>> remover.transform(df).head().words == ['a', 'c']
3797 True
3798 >>> stopWordsRemoverPath = temp_path + "/stopwords-remover"
3799 >>> remover.save(stopWordsRemoverPath)
3800 >>> loadedRemover = StopWordsRemover.load(stopWordsRemoverPath)
3801 >>> loadedRemover.getStopWords() == remover.getStopWords()
3802 True
3803 >>> loadedRemover.getCaseSensitive() == remover.getCaseSensitive()
3804 True
3805 >>> df2 = spark.createDataFrame([(["a", "b", "c"], ["a", "b"])], ["text1", "text2"])
3806 >>> remover2 = StopWordsRemover(stopWords=["b"])
3807 >>> remover2.setInputCols(["text1", "text2"]).setOutputCols(["words1", "words2"])
3808 StopWordsRemover...
3809 >>> remover2.transform(df2).show()
3810 +---------+------+------+------+
3811 | text1| text2|words1|words2|
3812 +---------+------+------+------+
3813 |[a, b, c]|[a, b]|[a, c]| [a]|
3814 +---------+------+------+------+
3815 ...
3816
3817 .. versionadded:: 1.6.0
3818 """
3819
3820 stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out",
3821 typeConverter=TypeConverters.toListString)
3822 caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
3823 "comparison over the stop words", typeConverter=TypeConverters.toBoolean)
    locale = Param(Params._dummy(), "locale", "Locale of the input. Ignored when " +
                   "caseSensitive is true.", typeConverter=TypeConverters.toString)
3826
3827 @keyword_only
3828 def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
3829 locale=None, inputCols=None, outputCols=None):
3830 """
        __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, \
3832 locale=None, inputCols=None, outputCols=None)
3833 """
3834 super(StopWordsRemover, self).__init__()
3835 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
3836 self.uid)
3837 self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"),
3838 caseSensitive=False, locale=self._java_obj.getLocale())
3839 kwargs = self._input_kwargs
3840 self.setParams(**kwargs)
3841
3842 @keyword_only
3843 @since("1.6.0")
3844 def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
3845 locale=None, inputCols=None, outputCols=None):
3846 """
        setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, \
                  locale=None, inputCols=None, outputCols=None)
        Sets params for this StopWordsRemover.
3850 """
3851 kwargs = self._input_kwargs
3852 return self._set(**kwargs)
3853
3854 @since("1.6.0")
3855 def setStopWords(self, value):
3856 """
3857 Sets the value of :py:attr:`stopWords`.
3858 """
3859 return self._set(stopWords=value)
3860
3861 @since("1.6.0")
3862 def getStopWords(self):
3863 """
3864 Gets the value of :py:attr:`stopWords` or its default value.
3865 """
3866 return self.getOrDefault(self.stopWords)
3867
3868 @since("1.6.0")
3869 def setCaseSensitive(self, value):
3870 """
3871 Sets the value of :py:attr:`caseSensitive`.
3872 """
3873 return self._set(caseSensitive=value)
3874
3875 @since("1.6.0")
3876 def getCaseSensitive(self):
3877 """
3878 Gets the value of :py:attr:`caseSensitive` or its default value.
3879 """
3880 return self.getOrDefault(self.caseSensitive)
3881
3882 @since("2.4.0")
3883 def setLocale(self, value):
3884 """
3885 Sets the value of :py:attr:`locale`.
3886 """
3887 return self._set(locale=value)
3888
3889 @since("2.4.0")
3890 def getLocale(self):
3891 """
        Gets the value of :py:attr:`locale` or its default value.
3893 """
3894 return self.getOrDefault(self.locale)
3895
3896 def setInputCol(self, value):
3897 """
3898 Sets the value of :py:attr:`inputCol`.
3899 """
3900 return self._set(inputCol=value)
3901
3902 def setOutputCol(self, value):
3903 """
3904 Sets the value of :py:attr:`outputCol`.
3905 """
3906 return self._set(outputCol=value)
3907
3908 @since("3.0.0")
3909 def setInputCols(self, value):
3910 """
3911 Sets the value of :py:attr:`inputCols`.
3912 """
3913 return self._set(inputCols=value)
3914
3915 @since("3.0.0")
3916 def setOutputCols(self, value):
3917 """
3918 Sets the value of :py:attr:`outputCols`.
3919 """
3920 return self._set(outputCols=value)
3921
3922 @staticmethod
3923 @since("2.0.0")
3924 def loadDefaultStopWords(language):
3925 """
3926 Loads the default stop words for the given language.
3927 Supported languages: danish, dutch, english, finnish, french, german, hungarian,
3928 italian, norwegian, portuguese, russian, spanish, swedish, turkish
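
        A sketch of extending the defaults with extra words ("foo" is an arbitrary
        example):

        >>> words = StopWordsRemover.loadDefaultStopWords("english") + ["foo"]
        >>> remover3 = StopWordsRemover(inputCol="text", outputCol="out", stopWords=words)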
3929 """
3930 stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWordsRemover
3931 return list(stopWordsObj.loadDefaultStopWords(language))
3932
3933
3934 @inherit_doc
3935 @ignore_unicode_prefix
3936 class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
3937 """
3938 A tokenizer that converts the input string to lowercase and then
    splits it by whitespace.
3940
3941 >>> df = spark.createDataFrame([("a b c",)], ["text"])
3942 >>> tokenizer = Tokenizer(outputCol="words")
3943 >>> tokenizer.setInputCol("text")
3944 Tokenizer...
3945 >>> tokenizer.transform(df).head()
3946 Row(text=u'a b c', words=[u'a', u'b', u'c'])
3947 >>> # Change a parameter.
3948 >>> tokenizer.setParams(outputCol="tokens").transform(df).head()
3949 Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
3950 >>> # Temporarily modify a parameter.
3951 >>> tokenizer.transform(df, {tokenizer.outputCol: "words"}).head()
3952 Row(text=u'a b c', words=[u'a', u'b', u'c'])
3953 >>> tokenizer.transform(df).head()
3954 Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
3955 >>> # Must use keyword arguments to specify params.
3956 >>> tokenizer.setParams("text")
3957 Traceback (most recent call last):
3958 ...
3959 TypeError: Method setParams forces keyword arguments.
3960 >>> tokenizerPath = temp_path + "/tokenizer"
3961 >>> tokenizer.save(tokenizerPath)
3962 >>> loadedTokenizer = Tokenizer.load(tokenizerPath)
3963 >>> loadedTokenizer.transform(df).head().tokens == tokenizer.transform(df).head().tokens
3964 True
3965
3966 .. versionadded:: 1.3.0
3967 """
3968
3969 @keyword_only
3970 def __init__(self, inputCol=None, outputCol=None):
3971 """
3972 __init__(self, inputCol=None, outputCol=None)
3973 """
3974 super(Tokenizer, self).__init__()
3975 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Tokenizer", self.uid)
3976 kwargs = self._input_kwargs
3977 self.setParams(**kwargs)
3978
3979 @keyword_only
3980 @since("1.3.0")
3981 def setParams(self, inputCol=None, outputCol=None):
3982 """
3983 setParams(self, inputCol=None, outputCol=None)
3984 Sets params for this Tokenizer.
3985 """
3986 kwargs = self._input_kwargs
3987 return self._set(**kwargs)
3988
3989 def setInputCol(self, value):
3990 """
3991 Sets the value of :py:attr:`inputCol`.
3992 """
3993 return self._set(inputCol=value)
3994
3995 def setOutputCol(self, value):
3996 """
3997 Sets the value of :py:attr:`outputCol`.
3998 """
3999 return self._set(outputCol=value)
4000
4001
4002 @inherit_doc
4003 class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable,
4004 JavaMLWritable):
4005 """
4006 A feature transformer that merges multiple columns into a vector column.
4007
4008 >>> df = spark.createDataFrame([(1, 0, 3)], ["a", "b", "c"])
4009 >>> vecAssembler = VectorAssembler(outputCol="features")
4010 >>> vecAssembler.setInputCols(["a", "b", "c"])
4011 VectorAssembler...
4012 >>> vecAssembler.transform(df).head().features
4013 DenseVector([1.0, 0.0, 3.0])
4014 >>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
4015 DenseVector([1.0, 0.0, 3.0])
4016 >>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
4017 >>> vecAssembler.transform(df, params).head().vector
4018 DenseVector([0.0, 1.0])
4019 >>> vectorAssemblerPath = temp_path + "/vector-assembler"
4020 >>> vecAssembler.save(vectorAssemblerPath)
4021 >>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath)
4022 >>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs
4023 True
4024 >>> dfWithNullsAndNaNs = spark.createDataFrame(
4025 ... [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], ["a", "b", "c"])
4026 >>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features",
4027 ... handleInvalid="keep")
4028 >>> vecAssembler2.transform(dfWithNullsAndNaNs).show()
4029 +---+---+----+-------------+
4030 | a| b| c| features|
4031 +---+---+----+-------------+
4032 |1.0|2.0|null|[1.0,2.0,NaN]|
4033 |3.0|NaN| 4.0|[3.0,NaN,4.0]|
4034 |5.0|6.0| 7.0|[5.0,6.0,7.0]|
4035 +---+---+----+-------------+
4036 ...
4037 >>> vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show()
4038 +---+---+---+-------------+
4039 | a| b| c| features|
4040 +---+---+---+-------------+
4041 |5.0|6.0|7.0|[5.0,6.0,7.0]|
4042 +---+---+---+-------------+
4043 ...
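
    When 'keep' is used and a vector input column has invalid values, the output
    length for that column must be known; a sketch of supplying it upstream with
    :py:class:`VectorSizeHint` in a :py:class:`~pyspark.ml.Pipeline` (the column
    names here are hypothetical):

    >>> from pyspark.ml import Pipeline
    >>> sizeHint = VectorSizeHint(inputCol="vec", size=2, handleInvalid="skip")
    >>> assembler = VectorAssembler(inputCols=["vec", "a"], outputCol="features",
    ...     handleInvalid="keep")
    >>> pipeline = Pipeline(stages=[sizeHint, assembler])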
4044
4045 .. versionadded:: 1.4.0
4046 """
4047
4048 handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data (NULL " +
4049 "and NaN values). Options are 'skip' (filter out rows with invalid " +
                          "data), 'error' (throw an error), or 'keep' (keep the rows, " +
                          "outputting NaN for the invalid values). Column lengths are " +
                          "taken from the size of the ML Attribute Group, which can be " +
                          "set using `VectorSizeHint` in a pipeline before " +
                          "`VectorAssembler`. Column lengths can also be inferred from " +
                          "the first rows of the data, since it is safe to do so, but " +
                          "only in case of 'error' or 'skip'.",
4056 typeConverter=TypeConverters.toString)
4057
4058 @keyword_only
4059 def __init__(self, inputCols=None, outputCol=None, handleInvalid="error"):
4060 """
4061 __init__(self, inputCols=None, outputCol=None, handleInvalid="error")
4062 """
4063 super(VectorAssembler, self).__init__()
4064 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid)
4065 self._setDefault(handleInvalid="error")
4066 kwargs = self._input_kwargs
4067 self.setParams(**kwargs)
4068
4069 @keyword_only
4070 @since("1.4.0")
4071 def setParams(self, inputCols=None, outputCol=None, handleInvalid="error"):
4072 """
4073 setParams(self, inputCols=None, outputCol=None, handleInvalid="error")
4074 Sets params for this VectorAssembler.
4075 """
4076 kwargs = self._input_kwargs
4077 return self._set(**kwargs)
4078
4079 def setInputCols(self, value):
4080 """
4081 Sets the value of :py:attr:`inputCols`.
4082 """
4083 return self._set(inputCols=value)
4084
4085 def setOutputCol(self, value):
4086 """
4087 Sets the value of :py:attr:`outputCol`.
4088 """
4089 return self._set(outputCol=value)
4090
4091 def setHandleInvalid(self, value):
4092 """
4093 Sets the value of :py:attr:`handleInvalid`.
4094 """
4095 return self._set(handleInvalid=value)
4096
4097
4098 class _VectorIndexerParams(HasInputCol, HasOutputCol, HasHandleInvalid):
4099 """
4100 Params for :py:class:`VectorIndexer` and :py:class:`VectorIndexerModel`.
4101
4102 .. versionadded:: 3.0.0
4103 """
4104
4105 maxCategories = Param(Params._dummy(), "maxCategories",
4106 "Threshold for the number of values a categorical feature can take " +
4107 "(>= 2). If a feature is found to have > maxCategories values, then " +
4108 "it is declared continuous.", typeConverter=TypeConverters.toInt)
4109
4110 handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data " +
4111 "(unseen labels or NULL values). Options are 'skip' (filter out " +
4112 "rows with invalid data), 'error' (throw an error), or 'keep' (put " +
4113 "invalid data in a special additional bucket, at index of the number " +
4114 "of categories of the feature).",
4115 typeConverter=TypeConverters.toString)
4116
4117 @since("1.4.0")
4118 def getMaxCategories(self):
4119 """
4120 Gets the value of maxCategories or its default value.
4121 """
4122 return self.getOrDefault(self.maxCategories)
4123
4124
4125 @inherit_doc
4126 class VectorIndexer(JavaEstimator, _VectorIndexerParams, JavaMLReadable, JavaMLWritable):
4127 """
4128 Class for indexing categorical feature columns in a dataset of `Vector`.
4129
4130 This has 2 usage modes:
4131 - Automatically identify categorical features (default behavior)
4132 - This helps process a dataset of unknown vectors into a dataset with some continuous
4133 features and some categorical features. The choice between continuous and categorical
4134 is based upon a maxCategories parameter.
         - Set maxCategories to the maximum number of categories any categorical feature
           should have.
4137 - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}.
4138 If maxCategories = 2, then feature 0 will be declared categorical and use indices {0, 1},
4139 and feature 1 will be declared continuous.
4140 - Index all features, if all features are categorical
4141 - If maxCategories is set to be very large, then this will build an index of unique
4142 values for all features.
4143 - Warning: This can cause problems if features are continuous since this will collect ALL
4144 unique values to the driver.
4145 - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}.
4146 If maxCategories >= 3, then both features will be declared categorical.
4147
4148 This returns a model which can transform categorical features to use 0-based indices.
4149
4150 Index stability:
4151 - This is not guaranteed to choose the same category index across multiple runs.
4152 - If a categorical feature includes value 0, then this is guaranteed to map value 0 to
4153 index 0. This maintains vector sparsity.
4154 - More stability may be added in the future.
4155
    TODO: The following functionality is planned for future extensions:
4157 - Preserve metadata in transform; if a feature's metadata is already present,
4158 do not recompute.
4159 - Specify certain features to not index, either via a parameter or via existing metadata.
4160 - Add warning if a categorical feature has only 1 category.
4161
4162 >>> from pyspark.ml.linalg import Vectors
4163 >>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
4164 ... (Vectors.dense([0.0, 1.0]),), (Vectors.dense([0.0, 2.0]),)], ["a"])
4165 >>> indexer = VectorIndexer(maxCategories=2, inputCol="a")
4166 >>> indexer.setOutputCol("indexed")
4167 VectorIndexer...
4168 >>> model = indexer.fit(df)
4169 >>> indexer.getHandleInvalid()
4170 'error'
4171 >>> model.setOutputCol("output")
4172 VectorIndexerModel...
4173 >>> model.transform(df).head().output
4174 DenseVector([1.0, 0.0])
4175 >>> model.numFeatures
4176 2
4177 >>> model.categoryMaps
4178 {0: {0.0: 0, -1.0: 1}}
4179 >>> indexer.setParams(outputCol="test").fit(df).transform(df).collect()[1].test
4180 DenseVector([0.0, 1.0])
4181 >>> params = {indexer.maxCategories: 3, indexer.outputCol: "vector"}
4182 >>> model2 = indexer.fit(df, params)
4183 >>> model2.transform(df).head().vector
4184 DenseVector([1.0, 0.0])
4185 >>> vectorIndexerPath = temp_path + "/vector-indexer"
4186 >>> indexer.save(vectorIndexerPath)
4187 >>> loadedIndexer = VectorIndexer.load(vectorIndexerPath)
4188 >>> loadedIndexer.getMaxCategories() == indexer.getMaxCategories()
4189 True
4190 >>> modelPath = temp_path + "/vector-indexer-model"
4191 >>> model.save(modelPath)
4192 >>> loadedModel = VectorIndexerModel.load(modelPath)
4193 >>> loadedModel.numFeatures == model.numFeatures
4194 True
4195 >>> loadedModel.categoryMaps == model.categoryMaps
4196 True
4197 >>> dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"])
4198 >>> indexer.getHandleInvalid()
4199 'error'
4200 >>> model3 = indexer.setHandleInvalid("skip").fit(df)
4201 >>> model3.transform(dfWithInvalid).count()
4202 0
4203 >>> model4 = indexer.setParams(handleInvalid="keep", outputCol="indexed").fit(df)
4204 >>> model4.transform(dfWithInvalid).head().indexed
4205 DenseVector([2.0, 1.0])
4206
4207 .. versionadded:: 1.4.0
4208 """
4209
4210 @keyword_only
4211 def __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
4212 """
4213 __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
4214 """
4215 super(VectorIndexer, self).__init__()
4216 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid)
4217 self._setDefault(maxCategories=20, handleInvalid="error")
4218 kwargs = self._input_kwargs
4219 self.setParams(**kwargs)
4220
4221 @keyword_only
4222 @since("1.4.0")
4223 def setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
4224 """
4225 setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
4226 Sets params for this VectorIndexer.
4227 """
4228 kwargs = self._input_kwargs
4229 return self._set(**kwargs)
4230
4231 @since("1.4.0")
4232 def setMaxCategories(self, value):
4233 """
4234 Sets the value of :py:attr:`maxCategories`.
4235 """
4236 return self._set(maxCategories=value)
4237
4238 def setInputCol(self, value):
4239 """
4240 Sets the value of :py:attr:`inputCol`.
4241 """
4242 return self._set(inputCol=value)
4243
4244 def setOutputCol(self, value):
4245 """
4246 Sets the value of :py:attr:`outputCol`.
4247 """
4248 return self._set(outputCol=value)
4249
4250 def setHandleInvalid(self, value):
4251 """
4252 Sets the value of :py:attr:`handleInvalid`.
4253 """
4254 return self._set(handleInvalid=value)
4255
4256 def _create_model(self, java_model):
4257 return VectorIndexerModel(java_model)
4258
4259
4260 class VectorIndexerModel(JavaModel, _VectorIndexerParams, JavaMLReadable, JavaMLWritable):
4261 """
4262 Model fitted by :py:class:`VectorIndexer`.
4263
4264 Transform categorical features to use 0-based indices instead of their original values.
4265 - Categorical features are mapped to indices.
4266 - Continuous features (columns) are left unchanged.
4267
4268 This also appends metadata to the output column, marking features as Numeric (continuous),
4269 Nominal (categorical), or Binary (either continuous or categorical).
4270 Non-ML metadata is not carried over from the input to the output column.
4271
4272 This maintains vector sparsity.
4273
4274 .. versionadded:: 1.4.0
4275 """
4276
4277 @since("3.0.0")
4278 def setInputCol(self, value):
4279 """
4280 Sets the value of :py:attr:`inputCol`.
4281 """
4282 return self._set(inputCol=value)
4283
4284 @since("3.0.0")
4285 def setOutputCol(self, value):
4286 """
4287 Sets the value of :py:attr:`outputCol`.
4288 """
4289 return self._set(outputCol=value)
4290
4291 @property
4292 @since("1.4.0")
4293 def numFeatures(self):
4294 """
4295 Number of features, i.e., length of Vectors which this transforms.
4296 """
4297 return self._call_java("numFeatures")
4298
4299 @property
4300 @since("1.4.0")
4301 def categoryMaps(self):
4302 """
4303 Feature value index. Keys are categorical feature indices (column indices).
        Values are maps from original feature values to 0-based category indices.
4305 If a feature is not in this map, it is treated as continuous.
4306 """
4307 return self._call_java("javaCategoryMaps")
4308
4309
4310 @inherit_doc
4311 class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
4312 """
4313 This class takes a feature vector and outputs a new feature vector with a subarray
4314 of the original features.
4315
4316 The subset of features can be specified with either indices (`setIndices()`)
4317 or names (`setNames()`). At least one feature must be selected. Duplicate features
4318 are not allowed, so there can be no overlap between selected indices and names.
4319
4320 The output vector will order features with the selected indices first (in the order given),
4321 followed by the selected names (in the order given).
4322
4323 >>> from pyspark.ml.linalg import Vectors
4324 >>> df = spark.createDataFrame([
4325 ... (Vectors.dense([-2.0, 2.3, 0.0, 0.0, 1.0]),),
4326 ... (Vectors.dense([0.0, 0.0, 0.0, 0.0, 0.0]),),
4327 ... (Vectors.dense([0.6, -1.1, -3.0, 4.5, 3.3]),)], ["features"])
4328 >>> vs = VectorSlicer(outputCol="sliced", indices=[1, 4])
4329 >>> vs.setInputCol("features")
4330 VectorSlicer...
4331 >>> vs.transform(df).head().sliced
4332 DenseVector([2.3, 1.0])
4333 >>> vectorSlicerPath = temp_path + "/vector-slicer"
4334 >>> vs.save(vectorSlicerPath)
4335 >>> loadedVs = VectorSlicer.load(vectorSlicerPath)
4336 >>> loadedVs.getIndices() == vs.getIndices()
4337 True
4338 >>> loadedVs.getNames() == vs.getNames()
4339 True
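
    Selecting by name is also possible, though it requires ML attributes on the
    input column; a construction-only sketch (the name "f4" is hypothetical):

    >>> vsByName = VectorSlicer(inputCol="features", outputCol="sliced", names=["f4"])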
4340
4341 .. versionadded:: 1.6.0
4342 """
4343
4344 indices = Param(Params._dummy(), "indices", "An array of indices to select features from " +
4345 "a vector column. There can be no overlap with names.",
4346 typeConverter=TypeConverters.toListInt)
4347 names = Param(Params._dummy(), "names", "An array of feature names to select features from " +
4348 "a vector column. These names must be specified by ML " +
4349 "org.apache.spark.ml.attribute.Attribute. There can be no overlap with " +
4350 "indices.", typeConverter=TypeConverters.toListString)
4351
4352 @keyword_only
4353 def __init__(self, inputCol=None, outputCol=None, indices=None, names=None):
4354 """
4355 __init__(self, inputCol=None, outputCol=None, indices=None, names=None)
4356 """
4357 super(VectorSlicer, self).__init__()
4358 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid)
4359 self._setDefault(indices=[], names=[])
4360 kwargs = self._input_kwargs
4361 self.setParams(**kwargs)
4362
4363 @keyword_only
4364 @since("1.6.0")
4365 def setParams(self, inputCol=None, outputCol=None, indices=None, names=None):
4366 """
        setParams(self, inputCol=None, outputCol=None, indices=None, names=None)
4368 Sets params for this VectorSlicer.
4369 """
4370 kwargs = self._input_kwargs
4371 return self._set(**kwargs)
4372
4373 @since("1.6.0")
4374 def setIndices(self, value):
4375 """
4376 Sets the value of :py:attr:`indices`.
4377 """
4378 return self._set(indices=value)
4379
4380 @since("1.6.0")
4381 def getIndices(self):
4382 """
4383 Gets the value of indices or its default value.
4384 """
4385 return self.getOrDefault(self.indices)
4386
4387 @since("1.6.0")
4388 def setNames(self, value):
4389 """
4390 Sets the value of :py:attr:`names`.
4391 """
4392 return self._set(names=value)
4393
4394 @since("1.6.0")
4395 def getNames(self):
4396 """
4397 Gets the value of names or its default value.
4398 """
4399 return self.getOrDefault(self.names)
4400
4401 def setInputCol(self, value):
4402 """
4403 Sets the value of :py:attr:`inputCol`.
4404 """
4405 return self._set(inputCol=value)
4406
4407 def setOutputCol(self, value):
4408 """
4409 Sets the value of :py:attr:`outputCol`.
4410 """
4411 return self._set(outputCol=value)
4412
4413
4414 class _Word2VecParams(HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol):
4415 """
4416 Params for :py:class:`Word2Vec` and :py:class:`Word2VecModel`.
4417
4418 .. versionadded:: 3.0.0
4419 """
4420
4421 vectorSize = Param(Params._dummy(), "vectorSize",
4422 "the dimension of codes after transforming from words",
4423 typeConverter=TypeConverters.toInt)
4424 numPartitions = Param(Params._dummy(), "numPartitions",
4425 "number of partitions for sentences of words",
4426 typeConverter=TypeConverters.toInt)
4427 minCount = Param(Params._dummy(), "minCount",
4428 "the minimum number of times a token must appear to be included in the " +
4429 "word2vec model's vocabulary", typeConverter=TypeConverters.toInt)
4430 windowSize = Param(Params._dummy(), "windowSize",
4431 "the window size (context words from [-window, window]). Default value is 5",
4432 typeConverter=TypeConverters.toInt)
4433 maxSentenceLength = Param(Params._dummy(), "maxSentenceLength",
4434 "Maximum length (in words) of each sentence in the input data. " +
                              "Any sentence longer than this threshold will " +
                              "be divided into chunks of up to this size.",
4437 typeConverter=TypeConverters.toInt)
4438
4439 @since("1.4.0")
4440 def getVectorSize(self):
4441 """
4442 Gets the value of vectorSize or its default value.
4443 """
4444 return self.getOrDefault(self.vectorSize)
4445
4446 @since("1.4.0")
4447 def getNumPartitions(self):
4448 """
4449 Gets the value of numPartitions or its default value.
4450 """
4451 return self.getOrDefault(self.numPartitions)
4452
4453 @since("1.4.0")
4454 def getMinCount(self):
4455 """
4456 Gets the value of minCount or its default value.
4457 """
4458 return self.getOrDefault(self.minCount)
4459
4460 @since("2.0.0")
4461 def getWindowSize(self):
4462 """
4463 Gets the value of windowSize or its default value.
4464 """
4465 return self.getOrDefault(self.windowSize)
4466
4467 @since("2.0.0")
4468 def getMaxSentenceLength(self):
4469 """
4470 Gets the value of maxSentenceLength or its default value.
4471 """
4472 return self.getOrDefault(self.maxSentenceLength)
4473
4474
4475 @inherit_doc
4476 @ignore_unicode_prefix
4477 class Word2Vec(JavaEstimator, _Word2VecParams, JavaMLReadable, JavaMLWritable):
4478 """
    Word2Vec trains a model of `Map(String, Vector)`, i.e. it transforms each word into a
    vector (code) for use in further natural language processing or machine learning.
4481
4482 >>> sent = ("a b " * 100 + "a c " * 10).split(" ")
4483 >>> doc = spark.createDataFrame([(sent,), (sent,)], ["sentence"])
4484 >>> word2Vec = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model")
4485 >>> word2Vec.setMaxIter(10)
4486 Word2Vec...
4487 >>> word2Vec.getMaxIter()
4488 10
4489 >>> word2Vec.clear(word2Vec.maxIter)
4490 >>> model = word2Vec.fit(doc)
4491 >>> model.getMinCount()
4492 5
4493 >>> model.setInputCol("sentence")
4494 Word2VecModel...
4495 >>> model.getVectors().show()
4496 +----+--------------------+
4497 |word| vector|
4498 +----+--------------------+
4499 | a|[0.09511678665876...|
4500 | b|[-1.2028766870498...|
4501 | c|[0.30153277516365...|
4502 +----+--------------------+
4503 ...
4504 >>> model.findSynonymsArray("a", 2)
4505 [(u'b', 0.015859870240092278), (u'c', -0.5680795907974243)]
4506 >>> from pyspark.sql.functions import format_number as fmt
4507 >>> model.findSynonyms("a", 2).select("word", fmt("similarity", 5).alias("similarity")).show()
4508 +----+----------+
4509 |word|similarity|
4510 +----+----------+
4511 | b| 0.01586|
4512 | c| -0.56808|
4513 +----+----------+
4514 ...
4515 >>> model.transform(doc).head().model
4516 DenseVector([-0.4833, 0.1855, -0.273, -0.0509, -0.4769])
4517 >>> word2vecPath = temp_path + "/word2vec"
4518 >>> word2Vec.save(word2vecPath)
4519 >>> loadedWord2Vec = Word2Vec.load(word2vecPath)
4520 >>> loadedWord2Vec.getVectorSize() == word2Vec.getVectorSize()
4521 True
4522 >>> loadedWord2Vec.getNumPartitions() == word2Vec.getNumPartitions()
4523 True
4524 >>> loadedWord2Vec.getMinCount() == word2Vec.getMinCount()
4525 True
4526 >>> modelPath = temp_path + "/word2vec-model"
4527 >>> model.save(modelPath)
4528 >>> loadedModel = Word2VecModel.load(modelPath)
4529 >>> loadedModel.getVectors().first().word == model.getVectors().first().word
4530 True
4531 >>> loadedModel.getVectors().first().vector == model.getVectors().first().vector
4532 True
4533
4534 .. versionadded:: 1.4.0
4535 """
4536
4537 @keyword_only
4538 def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
4539 seed=None, inputCol=None, outputCol=None, windowSize=5, maxSentenceLength=1000):
4540 """
4541 __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, \
4542 seed=None, inputCol=None, outputCol=None, windowSize=5, maxSentenceLength=1000)
4543 """
4544 super(Word2Vec, self).__init__()
4545 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid)
4546 self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
4547 windowSize=5, maxSentenceLength=1000)
4548 kwargs = self._input_kwargs
4549 self.setParams(**kwargs)
4550
4551 @keyword_only
4552 @since("1.4.0")
4553 def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
4554 seed=None, inputCol=None, outputCol=None, windowSize=5, maxSentenceLength=1000):
4555 """
        setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, \
                  maxIter=1, seed=None, inputCol=None, outputCol=None, windowSize=5, \
                  maxSentenceLength=1000)
4558 Sets params for this Word2Vec.
4559 """
4560 kwargs = self._input_kwargs
4561 return self._set(**kwargs)
4562
4563 @since("1.4.0")
4564 def setVectorSize(self, value):
4565 """
4566 Sets the value of :py:attr:`vectorSize`.
4567 """
4568 return self._set(vectorSize=value)
4569
4570 @since("1.4.0")
4571 def setNumPartitions(self, value):
4572 """
4573 Sets the value of :py:attr:`numPartitions`.
4574 """
4575 return self._set(numPartitions=value)
4576
4577 @since("1.4.0")
4578 def setMinCount(self, value):
4579 """
4580 Sets the value of :py:attr:`minCount`.
4581 """
4582 return self._set(minCount=value)
4583
4584 @since("2.0.0")
4585 def setWindowSize(self, value):
4586 """
4587 Sets the value of :py:attr:`windowSize`.
4588 """
4589 return self._set(windowSize=value)
4590
4591 @since("2.0.0")
4592 def setMaxSentenceLength(self, value):
4593 """
4594 Sets the value of :py:attr:`maxSentenceLength`.
4595 """
4596 return self._set(maxSentenceLength=value)
4597
4598 def setMaxIter(self, value):
4599 """
4600 Sets the value of :py:attr:`maxIter`.
4601 """
4602 return self._set(maxIter=value)
4603
4604 def setInputCol(self, value):
4605 """
4606 Sets the value of :py:attr:`inputCol`.
4607 """
4608 return self._set(inputCol=value)
4609
4610 def setOutputCol(self, value):
4611 """
4612 Sets the value of :py:attr:`outputCol`.
4613 """
4614 return self._set(outputCol=value)
4615
4616 def setSeed(self, value):
4617 """
4618 Sets the value of :py:attr:`seed`.
4619 """
4620 return self._set(seed=value)
4621
4622 @since("1.4.0")
4623 def setStepSize(self, value):
4624 """
4625 Sets the value of :py:attr:`stepSize`.
4626 """
4627 return self._set(stepSize=value)
4628
4629 def _create_model(self, java_model):
4630 return Word2VecModel(java_model)
4631
4632
4633 class Word2VecModel(JavaModel, _Word2VecParams, JavaMLReadable, JavaMLWritable):
4634 """
4635 Model fitted by :py:class:`Word2Vec`.
4636
4637 .. versionadded:: 1.4.0
4638 """
4639
4640 @since("1.5.0")
4641 def getVectors(self):
4642 """
4643 Returns the vector representation of the words as a dataframe
4644 with two fields, word and vector.
4645 """
4646 return self._call_java("getVectors")
4647
4648 def setInputCol(self, value):
4649 """
4650 Sets the value of :py:attr:`inputCol`.
4651 """
4652 return self._set(inputCol=value)
4653
4654 def setOutputCol(self, value):
4655 """
4656 Sets the value of :py:attr:`outputCol`.
4657 """
4658 return self._set(outputCol=value)
4659
4660 @since("1.5.0")
4661 def findSynonyms(self, word, num):
4662 """
4663 Find "num" number of words closest in similarity to "word".
4664 word can be a string or vector representation.
        Returns a dataframe with two fields, word and similarity, where
        similarity is the cosine similarity to the given word.
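
        For example (a sketch; ``model`` stands for a fitted
        :py:class:`Word2VecModel`, like the one in the :py:class:`Word2Vec`
        example above):

        >>> from pyspark.ml.linalg import Vectors
        >>> model.findSynonyms(Vectors.dense([0.1, 0.1, 0.1, 0.1, 0.1]), 2)  # doctest: +SKIP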
4667 """
4668 if not isinstance(word, basestring):
4669 word = _convert_to_vector(word)
4670 return self._call_java("findSynonyms", word, num)
4671
4672 @since("2.3.0")
4673 def findSynonymsArray(self, word, num):
4674 """
4675 Find "num" number of words closest in similarity to "word".
4676 word can be a string or vector representation.
        Returns an array of (word, similarity) tuples, where similarity is
        the cosine similarity to the given word.
4679 """
4680 if not isinstance(word, basestring):
4681 word = _convert_to_vector(word)
4682 tuples = self._java_obj.findSynonymsArray(word, num)
4683 return list(map(lambda st: (st._1(), st._2()), list(tuples)))
4684
4685
4686 class _PCAParams(HasInputCol, HasOutputCol):
4687 """
4688 Params for :py:class:`PCA` and :py:class:`PCAModel`.
4689
4690 .. versionadded:: 3.0.0
4691 """
4692
4693 k = Param(Params._dummy(), "k", "the number of principal components",
4694 typeConverter=TypeConverters.toInt)
4695
4696 @since("1.5.0")
4697 def getK(self):
4698 """
4699 Gets the value of k or its default value.
4700 """
4701 return self.getOrDefault(self.k)
4702
4703
4704 @inherit_doc
4705 class PCA(JavaEstimator, _PCAParams, JavaMLReadable, JavaMLWritable):
4706 """
4707 PCA trains a model to project vectors to a lower dimensional space of the
4708 top :py:attr:`k` principal components.
4709
4710 >>> from pyspark.ml.linalg import Vectors
4711 >>> data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),),
4712 ... (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),),
4713 ... (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)]
4714 >>> df = spark.createDataFrame(data,["features"])
4715 >>> pca = PCA(k=2, inputCol="features")
4716 >>> pca.setOutputCol("pca_features")
4717 PCA...
4718 >>> model = pca.fit(df)
4719 >>> model.getK()
4720 2
4721 >>> model.setOutputCol("output")
4722 PCAModel...
4723 >>> model.transform(df).collect()[0].output
4724 DenseVector([1.648..., -4.013...])
4725 >>> model.explainedVariance
4726 DenseVector([0.794..., 0.205...])
4727 >>> pcaPath = temp_path + "/pca"
4728 >>> pca.save(pcaPath)
4729 >>> loadedPca = PCA.load(pcaPath)
4730 >>> loadedPca.getK() == pca.getK()
4731 True
4732 >>> modelPath = temp_path + "/pca-model"
4733 >>> model.save(modelPath)
4734 >>> loadedModel = PCAModel.load(modelPath)
4735 >>> loadedModel.pc == model.pc
4736 True
4737 >>> loadedModel.explainedVariance == model.explainedVariance
4738 True
4739
4740 .. versionadded:: 1.5.0
4741 """
4742
4743 @keyword_only
4744 def __init__(self, k=None, inputCol=None, outputCol=None):
4745 """
4746 __init__(self, k=None, inputCol=None, outputCol=None)
4747 """
4748 super(PCA, self).__init__()
4749 self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.PCA", self.uid)
4750 kwargs = self._input_kwargs
4751 self.setParams(**kwargs)
4752
4753 @keyword_only
4754 @since("1.5.0")
4755 def setParams(self, k=None, inputCol=None, outputCol=None):
4756 """
4757 setParams(self, k=None, inputCol=None, outputCol=None)
4758 Set params for this PCA.
4759 """
4760 kwargs = self._input_kwargs
4761 return self._set(**kwargs)
4762
4763 @since("1.5.0")
4764 def setK(self, value):
4765 """
4766 Sets the value of :py:attr:`k`.
4767 """
4768 return self._set(k=value)
4769
4770 def setInputCol(self, value):
4771 """
4772 Sets the value of :py:attr:`inputCol`.
4773 """
4774 return self._set(inputCol=value)
4775
4776 def setOutputCol(self, value):
4777 """
4778 Sets the value of :py:attr:`outputCol`.
4779 """
4780 return self._set(outputCol=value)
4781
4782 def _create_model(self, java_model):
4783 return PCAModel(java_model)
4784
4785
4786 class PCAModel(JavaModel, _PCAParams, JavaMLReadable, JavaMLWritable):
4787 """
4788 Model fitted by :py:class:`PCA`. Transforms vectors to a lower dimensional space.
4789
4790 .. versionadded:: 1.5.0
4791 """
4792
4793 @since("3.0.0")
4794 def setInputCol(self, value):
4795 """
4796 Sets the value of :py:attr:`inputCol`.
4797 """
4798 return self._set(inputCol=value)
4799
4800 @since("3.0.0")
4801 def setOutputCol(self, value):
4802 """
4803 Sets the value of :py:attr:`outputCol`.
4804 """
4805 return self._set(outputCol=value)
4806
4807 @property
4808 @since("2.0.0")
4809 def pc(self):
4810 """
4811 Returns a principal components Matrix.
4812 Each column is one principal component.
4813 """
4814 return self._call_java("pc")
4815
4816 @property
4817 @since("2.0.0")
4818 def explainedVariance(self):
4819 """
4820 Returns a vector of proportions of variance
4821 explained by each principal component.
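
        The proportions are relative to the total variance of the input, so
        the values for the top k components sum to at most 1.0.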
4822 """
4823 return self._call_java("explainedVariance")
4824
4825
4826 class _RFormulaParams(HasFeaturesCol, HasLabelCol, HasHandleInvalid):
4827 """
    Params for :py:class:`RFormula` and :py:class:`RFormulaModel`.
4829
4830 .. versionadded:: 3.0.0
4831 """
4832
4833 formula = Param(Params._dummy(), "formula", "R model formula",
4834 typeConverter=TypeConverters.toString)
4835
4836 forceIndexLabel = Param(Params._dummy(), "forceIndexLabel",
                            "Force the label to be indexed, whether it is numeric or string",
4838 typeConverter=TypeConverters.toBoolean)
4839
4840 stringIndexerOrderType = Param(Params._dummy(), "stringIndexerOrderType",
4841 "How to order categories of a string feature column used by " +
4842 "StringIndexer. The last category after ordering is dropped " +
4843 "when encoding strings. Supported options: frequencyDesc, " +
4844 "frequencyAsc, alphabetDesc, alphabetAsc. The default value " +
4845 "is frequencyDesc. When the ordering is set to alphabetDesc, " +
4846 "RFormula drops the same category as R when encoding strings.",
4847 typeConverter=TypeConverters.toString)
4848
4849 handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
4850 "Options are 'skip' (filter out rows with invalid values), " +
4851 "'error' (throw an error), or 'keep' (put invalid data in a special " +
4852 "additional bucket, at index numLabels).",
4853 typeConverter=TypeConverters.toString)
4854
4855 @since("1.5.0")
4856 def getFormula(self):
4857 """
4858 Gets the value of :py:attr:`formula`.
4859 """
4860 return self.getOrDefault(self.formula)
4861
4862 @since("2.1.0")
4863 def getForceIndexLabel(self):
4864 """
4865 Gets the value of :py:attr:`forceIndexLabel`.
4866 """
4867 return self.getOrDefault(self.forceIndexLabel)
4868
4869 @since("2.3.0")
4870 def getStringIndexerOrderType(self):
4871 """
4872 Gets the value of :py:attr:`stringIndexerOrderType` or its default value 'frequencyDesc'.
4873 """
4874 return self.getOrDefault(self.stringIndexerOrderType)
4875
4876
4877 @inherit_doc
4878 class RFormula(JavaEstimator, _RFormulaParams, JavaMLReadable, JavaMLWritable):
4879 """
4880 Implements the transforms required for fitting a dataset against an
4881 R model formula. Currently we support a limited subset of the R
4882 operators, including '~', '.', ':', '+', '-', '*', and '^'.
4883 Also see the `R formula docs
4884 <http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html>`_.
4885
4886 >>> df = spark.createDataFrame([
4887 ... (1.0, 1.0, "a"),
4888 ... (0.0, 2.0, "b"),
4889 ... (0.0, 0.0, "a")
4890 ... ], ["y", "x", "s"])
4891 >>> rf = RFormula(formula="y ~ x + s")
4892 >>> model = rf.fit(df)
4893 >>> model.getLabelCol()
4894 'label'
4895 >>> model.transform(df).show()
4896 +---+---+---+---------+-----+
4897 | y| x| s| features|label|
4898 +---+---+---+---------+-----+
4899 |1.0|1.0| a|[1.0,1.0]| 1.0|
4900 |0.0|2.0| b|[2.0,0.0]| 0.0|
4901 |0.0|0.0| a|[0.0,1.0]| 0.0|
4902 +---+---+---+---------+-----+
4903 ...
4904 >>> rf.fit(df, {rf.formula: "y ~ . - s"}).transform(df).show()
4905 +---+---+---+--------+-----+
4906 | y| x| s|features|label|
4907 +---+---+---+--------+-----+
4908 |1.0|1.0| a| [1.0]| 1.0|
4909 |0.0|2.0| b| [2.0]| 0.0|
4910 |0.0|0.0| a| [0.0]| 0.0|
4911 +---+---+---+--------+-----+
4912 ...
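
    An interaction term can be written with ':' (a construction-only sketch; the
    fitted output depends on the encoded factor levels, so it is not shown):

    >>> rfInteract = RFormula(formula="y ~ x:s")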
4913 >>> rFormulaPath = temp_path + "/rFormula"
4914 >>> rf.save(rFormulaPath)
4915 >>> loadedRF = RFormula.load(rFormulaPath)
4916 >>> loadedRF.getFormula() == rf.getFormula()
4917 True
4918 >>> loadedRF.getFeaturesCol() == rf.getFeaturesCol()
4919 True
4920 >>> loadedRF.getLabelCol() == rf.getLabelCol()
4921 True
4922 >>> loadedRF.getHandleInvalid() == rf.getHandleInvalid()
4923 True
4924 >>> str(loadedRF)
4925 'RFormula(y ~ x + s) (uid=...)'
4926 >>> modelPath = temp_path + "/rFormulaModel"
4927 >>> model.save(modelPath)
4928 >>> loadedModel = RFormulaModel.load(modelPath)
4929 >>> loadedModel.uid == model.uid
4930 True
4931 >>> loadedModel.transform(df).show()
4932 +---+---+---+---------+-----+
4933 | y| x| s| features|label|
4934 +---+---+---+---------+-----+
4935 |1.0|1.0| a|[1.0,1.0]| 1.0|
4936 |0.0|2.0| b|[2.0,0.0]| 0.0|
4937 |0.0|0.0| a|[0.0,1.0]| 0.0|
4938 +---+---+---+---------+-----+
4939 ...
4940 >>> str(loadedModel)
4941 'RFormulaModel(ResolvedRFormula(label=y, terms=[x,s], hasIntercept=true)) (uid=...)'
4942
4943 .. versionadded:: 1.5.0
4944 """

    @keyword_only
    def __init__(self, formula=None, featuresCol="features", labelCol="label",
                 forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
                 handleInvalid="error"):
        """
        __init__(self, formula=None, featuresCol="features", labelCol="label", \
                 forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \
                 handleInvalid="error")
        """
        super(RFormula, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
        self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
                         handleInvalid="error")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.5.0")
    def setParams(self, formula=None, featuresCol="features", labelCol="label",
                  forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
                  handleInvalid="error"):
        """
        setParams(self, formula=None, featuresCol="features", labelCol="label", \
                  forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \
                  handleInvalid="error")
        Sets params for RFormula.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("1.5.0")
    def setFormula(self, value):
        """
        Sets the value of :py:attr:`formula`.
        """
        return self._set(formula=value)

    @since("2.1.0")
    def setForceIndexLabel(self, value):
        """
        Sets the value of :py:attr:`forceIndexLabel`.
        """
        return self._set(forceIndexLabel=value)

    @since("2.3.0")
    def setStringIndexerOrderType(self, value):
        """
        Sets the value of :py:attr:`stringIndexerOrderType`.
        """
        return self._set(stringIndexerOrderType=value)

    def setFeaturesCol(self, value):
        """
        Sets the value of :py:attr:`featuresCol`.
        """
        return self._set(featuresCol=value)

    def setLabelCol(self, value):
        """
        Sets the value of :py:attr:`labelCol`.
        """
        return self._set(labelCol=value)

    def setHandleInvalid(self, value):
        """
        Sets the value of :py:attr:`handleInvalid`.
        """
        return self._set(handleInvalid=value)

    def _create_model(self, java_model):
        return RFormulaModel(java_model)

    def __str__(self):
        formulaStr = self.getFormula() if self.isDefined(self.formula) else ""
        return "RFormula(%s) (uid=%s)" % (formulaStr, self.uid)


class RFormulaModel(JavaModel, _RFormulaParams, JavaMLReadable, JavaMLWritable):
    """
    Model fitted by :py:class:`RFormula`. Fitting is required to determine the
    factor levels of formula terms.
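
    Because the factor levels are fixed when the model is fit, the same encoding is
    applied to rows the model has not seen. A minimal sketch, reusing the ``model``
    fitted in the :py:class:`RFormula` example above (not executed here):

    >>> newDF = spark.createDataFrame([(1.0, 2.0, "b")], ["y", "x", "s"])  # doctest: +SKIP
    >>> model.transform(newDF).head().features  # doctest: +SKIP
    DenseVector([2.0, 0.0])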

    .. versionadded:: 1.5.0
    """

    def __str__(self):
        resolvedFormula = self._call_java("resolvedFormula")
        return "RFormulaModel(%s) (uid=%s)" % (resolvedFormula, self.uid)


class _ChiSqSelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol):
    """
    Params for :py:class:`ChiSqSelector` and :py:class:`ChiSqSelectorModel`.

    .. versionadded:: 3.0.0
    """

    selectorType = Param(Params._dummy(), "selectorType",
                         "The selector type of the ChiSqSelector. " +
                         "Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.",
                         typeConverter=TypeConverters.toString)

    numTopFeatures = \
        Param(Params._dummy(), "numTopFeatures",
              "Number of features that selector will select, ordered by ascending p-value. " +
              "If the number of features is < numTopFeatures, then this will select " +
              "all features.", typeConverter=TypeConverters.toInt)

    percentile = Param(Params._dummy(), "percentile", "Percentile of features that selector " +
                       "will select, ordered by ascending p-value.",
                       typeConverter=TypeConverters.toFloat)

    fpr = Param(Params._dummy(), "fpr", "The highest p-value for features to be kept.",
                typeConverter=TypeConverters.toFloat)

    fdr = Param(Params._dummy(), "fdr", "The upper bound of the expected false discovery rate.",
                typeConverter=TypeConverters.toFloat)

    fwe = Param(Params._dummy(), "fwe", "The upper bound of the expected family-wise error rate.",
                typeConverter=TypeConverters.toFloat)

    @since("2.1.0")
    def getSelectorType(self):
        """
        Gets the value of selectorType or its default value.
        """
        return self.getOrDefault(self.selectorType)

    @since("2.0.0")
    def getNumTopFeatures(self):
        """
        Gets the value of numTopFeatures or its default value.
        """
        return self.getOrDefault(self.numTopFeatures)

    @since("2.1.0")
    def getPercentile(self):
        """
        Gets the value of percentile or its default value.
        """
        return self.getOrDefault(self.percentile)

    @since("2.1.0")
    def getFpr(self):
        """
        Gets the value of fpr or its default value.
        """
        return self.getOrDefault(self.fpr)

    @since("2.2.0")
    def getFdr(self):
        """
        Gets the value of fdr or its default value.
        """
        return self.getOrDefault(self.fdr)

    @since("2.2.0")
    def getFwe(self):
        """
        Gets the value of fwe or its default value.
        """
        return self.getOrDefault(self.fwe)


@inherit_doc
class ChiSqSelector(JavaEstimator, _ChiSqSelectorParams, JavaMLReadable, JavaMLWritable):
    """
    Chi-Squared feature selection, which selects categorical features to use for predicting a
    categorical label.
    The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`,
    `fdr`, `fwe`.

    * `numTopFeatures` chooses a fixed number of top features according to a chi-squared test.

    * `percentile` is similar but chooses a fraction of all features
      instead of a fixed number.

    * `fpr` chooses all features whose p-values are below a threshold,
      thus controlling the false positive rate of selection.

    * `fdr` uses the `Benjamini-Hochberg procedure <https://en.wikipedia.org/wiki/
      False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure>`_
      to choose all features whose false discovery rate is below a threshold.

    * `fwe` chooses all features whose p-values are below a threshold. The threshold is scaled by
      1/numFeatures, thus controlling the family-wise error rate of selection.

    By default, the selection method is `numTopFeatures`, with the default number of top features
    set to 50.
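
    For instance, to keep every feature whose false discovery rate stays below 5%
    using the `fdr` method (a sketch; constructing the selector produces no output):

    >>> fdrSelector = ChiSqSelector(selectorType="fdr", fdr=0.05, outputCol="selectedFeatures")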

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame(
    ...     [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
    ...      (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
    ...      (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)],
    ...     ["features", "label"])
    >>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures")
    >>> model = selector.fit(df)
    >>> model.getFeaturesCol()
    'features'
    >>> model.setFeaturesCol("features")
    ChiSqSelectorModel...
    >>> model.transform(df).head().selectedFeatures
    DenseVector([18.0])
    >>> model.selectedFeatures
    [2]
    >>> chiSqSelectorPath = temp_path + "/chi-sq-selector"
    >>> selector.save(chiSqSelectorPath)
    >>> loadedSelector = ChiSqSelector.load(chiSqSelectorPath)
    >>> loadedSelector.getNumTopFeatures() == selector.getNumTopFeatures()
    True
    >>> modelPath = temp_path + "/chi-sq-selector-model"
    >>> model.save(modelPath)
    >>> loadedModel = ChiSqSelectorModel.load(modelPath)
    >>> loadedModel.selectedFeatures == model.selectedFeatures
    True

    .. versionadded:: 2.0.0
    """

    @keyword_only
    def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None,
                 labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05,
                 fdr=0.05, fwe=0.05):
        """
        __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, \
                 labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, \
                 fdr=0.05, fwe=0.05)
        """
        super(ChiSqSelector, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
        self._setDefault(numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1,
                         fpr=0.05, fdr=0.05, fwe=0.05)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("2.0.0")
    def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,
                  labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05,
                  fdr=0.05, fwe=0.05):
        """
        setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, \
                  labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, \
                  fdr=0.05, fwe=0.05)
        Sets params for this ChiSqSelector.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.1.0")
    def setSelectorType(self, value):
        """
        Sets the value of :py:attr:`selectorType`.
        """
        return self._set(selectorType=value)

    @since("2.0.0")
    def setNumTopFeatures(self, value):
        """
        Sets the value of :py:attr:`numTopFeatures`.
        Only applicable when selectorType = "numTopFeatures".
        """
        return self._set(numTopFeatures=value)

    @since("2.1.0")
    def setPercentile(self, value):
        """
        Sets the value of :py:attr:`percentile`.
        Only applicable when selectorType = "percentile".
        """
        return self._set(percentile=value)

    @since("2.1.0")
    def setFpr(self, value):
        """
        Sets the value of :py:attr:`fpr`.
        Only applicable when selectorType = "fpr".
        """
        return self._set(fpr=value)

    @since("2.2.0")
    def setFdr(self, value):
        """
        Sets the value of :py:attr:`fdr`.
        Only applicable when selectorType = "fdr".
        """
        return self._set(fdr=value)

    @since("2.2.0")
    def setFwe(self, value):
        """
        Sets the value of :py:attr:`fwe`.
        Only applicable when selectorType = "fwe".
        """
        return self._set(fwe=value)

    def setFeaturesCol(self, value):
        """
        Sets the value of :py:attr:`featuresCol`.
        """
        return self._set(featuresCol=value)

    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    def setLabelCol(self, value):
        """
        Sets the value of :py:attr:`labelCol`.
        """
        return self._set(labelCol=value)

    def _create_model(self, java_model):
        return ChiSqSelectorModel(java_model)


class ChiSqSelectorModel(JavaModel, _ChiSqSelectorParams, JavaMLReadable, JavaMLWritable):
    """
    Model fitted by :py:class:`ChiSqSelector`.

    .. versionadded:: 2.0.0
    """

    @since("3.0.0")
    def setFeaturesCol(self, value):
        """
        Sets the value of :py:attr:`featuresCol`.
        """
        return self._set(featuresCol=value)

    @since("3.0.0")
    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    @property
    @since("2.0.0")
    def selectedFeatures(self):
        """
        List of indices to select (filter).
        """
        return self._call_java("selectedFeatures")


@inherit_doc
class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReadable,
                     JavaMLWritable):
    """
    A feature transformer that adds size information to the metadata of a vector column.
    VectorAssembler needs size information for its input columns and cannot be used on streaming
    dataframes without this metadata.

    .. note:: VectorSizeHint modifies `inputCol` to include size metadata and does not have an
        outputCol.
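
    For example, to attach the size hint without validating vector sizes at all (a
    sketch; `optimistic` keeps every row and skips the check, so constructing the
    transformer produces no output):

    >>> optimisticHint = VectorSizeHint(inputCol="vector", size=3, handleInvalid="optimistic")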

    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml import Pipeline, PipelineModel
    >>> data = [(Vectors.dense([1., 2., 3.]), 4.)]
    >>> df = spark.createDataFrame(data, ["vector", "float"])
    >>>
    >>> sizeHint = VectorSizeHint(inputCol="vector", size=3, handleInvalid="skip")
    >>> vecAssembler = VectorAssembler(inputCols=["vector", "float"], outputCol="assembled")
    >>> pipeline = Pipeline(stages=[sizeHint, vecAssembler])
    >>>
    >>> pipelineModel = pipeline.fit(df)
    >>> pipelineModel.transform(df).head().assembled
    DenseVector([1.0, 2.0, 3.0, 4.0])
    >>> vectorSizeHintPath = temp_path + "/vector-size-hint-pipeline"
    >>> pipelineModel.save(vectorSizeHintPath)
    >>> loadedPipeline = PipelineModel.load(vectorSizeHintPath)
    >>> loaded = loadedPipeline.transform(df).head().assembled
    >>> expected = pipelineModel.transform(df).head().assembled
    >>> loaded == expected
    True

    .. versionadded:: 2.3.0
    """

    size = Param(Params._dummy(), "size", "Size of vectors in column.",
                 typeConverter=TypeConverters.toInt)

    handleInvalid = Param(Params._dummy(), "handleInvalid",
                          "How to handle invalid vectors in inputCol. Invalid vectors include "
                          "nulls and vectors with the wrong size. The options are `skip` (filter "
                          "out rows with invalid vectors), `error` (throw an error) and "
                          "`optimistic` (do not check the vector size, and keep all rows). "
                          "`error` by default.",
                          typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, inputCol=None, size=None, handleInvalid="error"):
        """
        __init__(self, inputCol=None, size=None, handleInvalid="error")
        """
        super(VectorSizeHint, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSizeHint", self.uid)
        self._setDefault(handleInvalid="error")
        self.setParams(**self._input_kwargs)

    @keyword_only
    @since("2.3.0")
    def setParams(self, inputCol=None, size=None, handleInvalid="error"):
        """
        setParams(self, inputCol=None, size=None, handleInvalid="error")
        Sets params for this VectorSizeHint.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.3.0")
    def getSize(self):
        """Gets size param, the size of vectors in `inputCol`."""
        return self.getOrDefault(self.size)

    @since("2.3.0")
    def setSize(self, value):
        """Sets size param, the size of vectors in `inputCol`."""
        return self._set(size=value)

    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    def setHandleInvalid(self, value):
        """
        Sets the value of :py:attr:`handleInvalid`.
        """
        return self._set(handleInvalid=value)


if __name__ == "__main__":
    import doctest
    import tempfile

    import pyspark.ml.feature
    from pyspark.sql import Row, SparkSession

    # Run this module's doctests with every feature class available by name.
    globs = globals().copy()
    features = pyspark.ml.feature.__dict__.copy()
    globs.update(features)

    spark = SparkSession.builder\
        .master("local[2]")\
        .appName("ml.feature tests")\
        .getOrCreate()
    sc = spark.sparkContext
    globs['sc'] = sc
    globs['spark'] = spark
    testData = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="b"),
                               Row(id=2, label="c"), Row(id=3, label="a"),
                               Row(id=4, label="a"), Row(id=5, label="c")], 2)
    globs['stringIndDf'] = spark.createDataFrame(testData)
    # The doctests save and load estimators and models under a temporary directory.
    temp_path = tempfile.mkdtemp()
    globs['temp_path'] = temp_path
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
        spark.stop()
    finally:
        # Always remove the temporary directory, even if the doctests fail.
        from shutil import rmtree
        try:
            rmtree(temp_path)
        except OSError:
            pass
    if failure_count:
        sys.exit(-1)