0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import sys
0019
0020 from pyspark import since, keyword_only
0021 from pyspark.ml.util import *
0022 from pyspark.ml.wrapper import JavaEstimator, JavaModel
0023 from pyspark.ml.param.shared import *
0024 from pyspark.ml.common import inherit_doc
0025
0026
0027 __all__ = ['ALS', 'ALSModel']
0028
0029
@inherit_doc
class _ALSModelParams(HasPredictionCol, HasBlockSize):
    """
    Params for :py:class:`ALS` and :py:class:`ALSModel`.

    .. versionadded:: 3.0.0
    """

    # NOTE: the doc strings below are user-visible Param documentation; their
    # text must stay in sync with the Scala-side ALS params.
    userCol = Param(
        Params._dummy(), "userCol",
        "column name for user ids. Ids must be within the integer value range.",
        typeConverter=TypeConverters.toString)
    itemCol = Param(
        Params._dummy(), "itemCol",
        "column name for item ids. Ids must be within the integer value range.",
        typeConverter=TypeConverters.toString)
    coldStartStrategy = Param(
        Params._dummy(), "coldStartStrategy",
        "strategy for dealing with unknown or new users/items at prediction time. "
        "This may be useful in cross-validation or production scenarios, for "
        "handling user/item ids the model has not seen in the training data. "
        "Supported values: 'nan', 'drop'.",
        typeConverter=TypeConverters.toString)

    @since("1.4.0")
    def getUserCol(self):
        """Gets the value of userCol or its default value."""
        return self.getOrDefault(self.userCol)

    @since("1.4.0")
    def getItemCol(self):
        """Gets the value of itemCol or its default value."""
        return self.getOrDefault(self.itemCol)

    @since("2.2.0")
    def getColdStartStrategy(self):
        """Gets the value of coldStartStrategy or its default value."""
        return self.getOrDefault(self.coldStartStrategy)
0069
0070
@inherit_doc
class _ALSParams(_ALSModelParams, HasMaxIter, HasRegParam, HasCheckpointInterval, HasSeed):
    """
    Params for :py:class:`ALS`.

    .. versionadded:: 3.0.0
    """

    # Factorization-shape and solver params.
    rank = Param(
        Params._dummy(), "rank", "rank of the factorization",
        typeConverter=TypeConverters.toInt)
    numUserBlocks = Param(
        Params._dummy(), "numUserBlocks", "number of user blocks",
        typeConverter=TypeConverters.toInt)
    numItemBlocks = Param(
        Params._dummy(), "numItemBlocks", "number of item blocks",
        typeConverter=TypeConverters.toInt)
    implicitPrefs = Param(
        Params._dummy(), "implicitPrefs", "whether to use implicit preference",
        typeConverter=TypeConverters.toBoolean)
    alpha = Param(
        Params._dummy(), "alpha", "alpha for implicit preference",
        typeConverter=TypeConverters.toFloat)

    # Input-column and storage params.
    ratingCol = Param(
        Params._dummy(), "ratingCol", "column name for ratings",
        typeConverter=TypeConverters.toString)
    nonnegative = Param(
        Params._dummy(), "nonnegative",
        "whether to use nonnegative constraint for least squares",
        typeConverter=TypeConverters.toBoolean)
    intermediateStorageLevel = Param(
        Params._dummy(), "intermediateStorageLevel",
        "StorageLevel for intermediate datasets. Cannot be 'NONE'.",
        typeConverter=TypeConverters.toString)
    finalStorageLevel = Param(
        Params._dummy(), "finalStorageLevel",
        "StorageLevel for ALS model factors.",
        typeConverter=TypeConverters.toString)

    @since("1.4.0")
    def getRank(self):
        """Gets the value of rank or its default value."""
        return self.getOrDefault(self.rank)

    @since("1.4.0")
    def getNumUserBlocks(self):
        """Gets the value of numUserBlocks or its default value."""
        return self.getOrDefault(self.numUserBlocks)

    @since("1.4.0")
    def getNumItemBlocks(self):
        """Gets the value of numItemBlocks or its default value."""
        return self.getOrDefault(self.numItemBlocks)

    @since("1.4.0")
    def getImplicitPrefs(self):
        """Gets the value of implicitPrefs or its default value."""
        return self.getOrDefault(self.implicitPrefs)

    @since("1.4.0")
    def getAlpha(self):
        """Gets the value of alpha or its default value."""
        return self.getOrDefault(self.alpha)

    @since("1.4.0")
    def getRatingCol(self):
        """Gets the value of ratingCol or its default value."""
        return self.getOrDefault(self.ratingCol)

    @since("1.4.0")
    def getNonnegative(self):
        """Gets the value of nonnegative or its default value."""
        return self.getOrDefault(self.nonnegative)

    @since("2.0.0")
    def getIntermediateStorageLevel(self):
        """Gets the value of intermediateStorageLevel or its default value."""
        return self.getOrDefault(self.intermediateStorageLevel)

    @since("2.0.0")
    def getFinalStorageLevel(self):
        """Gets the value of finalStorageLevel or its default value."""
        return self.getOrDefault(self.finalStorageLevel)
0164
0165
@inherit_doc
class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable):
    """
    Alternating Least Squares (ALS) matrix factorization.

    ALS attempts to estimate the ratings matrix `R` as the product of
    two lower-rank matrices, `X` and `Y`, i.e. `X * Yt = R`. Typically
    these approximations are called 'factor' matrices. The general
    approach is iterative. During each iteration, one of the factor
    matrices is held constant, while the other is solved for using least
    squares. The newly-solved factor matrix is then held constant while
    solving for the other factor matrix.

    This is a blocked implementation of the ALS factorization algorithm
    that groups the two sets of factors (referred to as "users" and
    "products") into blocks and reduces communication by only sending
    one copy of each user vector to each product block on each
    iteration, and only for the product blocks that need that user's
    feature vector. This is achieved by pre-computing some information
    about the ratings matrix to determine the "out-links" of each user
    (which blocks of products it will contribute to) and "in-link"
    information for each product (which of the feature vectors it
    receives from each user block it will depend on). This allows us to
    send only an array of feature vectors between each user block and
    product block, and have the product block find the users' ratings
    and update the products based on these messages.

    For implicit preference data, the algorithm used is based on
    `"Collaborative Filtering for Implicit Feedback Datasets",
    <https://doi.org/10.1109/ICDM.2008.22>`_, adapted for the blocked
    approach used here.

    Essentially instead of finding the low-rank approximations to the
    rating matrix `R`, this finds the approximations for a preference
    matrix `P` where the elements of `P` are 1 if r > 0 and 0 if r <= 0.
    The ratings then act as 'confidence' values related to strength of
    indicated user preferences rather than explicit ratings given to
    items.

    .. note:: the input rating dataframe to the ALS implementation should be deterministic.
        Nondeterministic data can cause failure during fitting ALS model.
        For example, an order-sensitive operation like sampling after a repartition makes
        dataframe output nondeterministic, like `df.repartition(2).sample(False, 0.5, 1618)`.
        Checkpointing sampled dataframe or adding a sort before sampling can help make the
        dataframe deterministic.

    >>> df = spark.createDataFrame(
    ...     [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
    ...     ["user", "item", "rating"])
    >>> als = ALS(rank=10, seed=0)
    >>> als.setMaxIter(5)
    ALS...
    >>> als.getMaxIter()
    5
    >>> als.setRegParam(0.1)
    ALS...
    >>> als.getRegParam()
    0.1
    >>> als.clear(als.regParam)
    >>> model = als.fit(df)
    >>> model.getBlockSize()
    4096
    >>> model.getUserCol()
    'user'
    >>> model.setUserCol("user")
    ALSModel...
    >>> model.getItemCol()
    'item'
    >>> model.setPredictionCol("newPrediction")
    ALS...
    >>> model.rank
    10
    >>> model.userFactors.orderBy("id").collect()
    [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
    >>> test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
    >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
    >>> predictions[0]
    Row(user=0, item=2, newPrediction=0.6929101347923279)
    >>> predictions[1]
    Row(user=1, item=0, newPrediction=3.47356915473938)
    >>> predictions[2]
    Row(user=2, item=0, newPrediction=-0.8991986513137817)
    >>> user_recs = model.recommendForAllUsers(3)
    >>> user_recs.where(user_recs.user == 0)\
        .select("recommendations.item", "recommendations.rating").collect()
    [Row(item=[0, 1, 2], rating=[3.910..., 1.997..., 0.692...])]
    >>> item_recs = model.recommendForAllItems(3)
    >>> item_recs.where(item_recs.item == 2)\
        .select("recommendations.user", "recommendations.rating").collect()
    [Row(user=[2, 1, 0], rating=[4.892..., 3.991..., 0.692...])]
    >>> user_subset = df.where(df.user == 2)
    >>> user_subset_recs = model.recommendForUserSubset(user_subset, 3)
    >>> user_subset_recs.select("recommendations.item", "recommendations.rating").first()
    Row(item=[2, 1, 0], rating=[4.892..., 1.076..., -0.899...])
    >>> item_subset = df.where(df.item == 0)
    >>> item_subset_recs = model.recommendForItemSubset(item_subset, 3)
    >>> item_subset_recs.select("recommendations.user", "recommendations.rating").first()
    Row(user=[0, 1, 2], rating=[3.910..., 3.473..., -0.899...])
    >>> als_path = temp_path + "/als"
    >>> als.save(als_path)
    >>> als2 = ALS.load(als_path)
    >>> als2.getMaxIter()
    5
    >>> model_path = temp_path + "/als_model"
    >>> model.save(model_path)
    >>> model2 = ALSModel.load(model_path)
    >>> model.rank == model2.rank
    True
    >>> sorted(model.userFactors.collect()) == sorted(model2.userFactors.collect())
    True
    >>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect())
    True

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
                 implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
                 ratingCol="rating", nonnegative=False, checkpointInterval=10,
                 intermediateStorageLevel="MEMORY_AND_DISK",
                 finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096):
        """
        __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
                 implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \
                 ratingCol="rating", nonnegative=False, checkpointInterval=10, \
                 intermediateStorageLevel="MEMORY_AND_DISK", \
                 finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096)
        """
        super(ALS, self).__init__()
        # Construct the JVM-side ALS estimator that backs this Python wrapper.
        self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
        # Defaults must mirror the keyword defaults of __init__/setParams above.
        self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
                         implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item",
                         ratingCol="rating", nonnegative=False, checkpointInterval=10,
                         intermediateStorageLevel="MEMORY_AND_DISK",
                         finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan",
                         blockSize=4096)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
                  implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
                  ratingCol="rating", nonnegative=False, checkpointInterval=10,
                  intermediateStorageLevel="MEMORY_AND_DISK",
                  finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096):
        """
        setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
                  implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \
                  ratingCol="rating", nonnegative=False, checkpointInterval=10, \
                  intermediateStorageLevel="MEMORY_AND_DISK", \
                  finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096)
        Sets params for ALS.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        """Wraps the fitted Java model returned by fit() in an :py:class:`ALSModel`."""
        return ALSModel(java_model)

    @since("1.4.0")
    def setRank(self, value):
        """
        Sets the value of :py:attr:`rank`.
        """
        return self._set(rank=value)

    @since("1.4.0")
    def setNumUserBlocks(self, value):
        """
        Sets the value of :py:attr:`numUserBlocks`.
        """
        return self._set(numUserBlocks=value)

    @since("1.4.0")
    def setNumItemBlocks(self, value):
        """
        Sets the value of :py:attr:`numItemBlocks`.
        """
        return self._set(numItemBlocks=value)

    @since("1.4.0")
    def setNumBlocks(self, value):
        """
        Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value.
        """
        self._set(numUserBlocks=value)
        return self._set(numItemBlocks=value)

    @since("1.4.0")
    def setImplicitPrefs(self, value):
        """
        Sets the value of :py:attr:`implicitPrefs`.
        """
        return self._set(implicitPrefs=value)

    @since("1.4.0")
    def setAlpha(self, value):
        """
        Sets the value of :py:attr:`alpha`.
        """
        return self._set(alpha=value)

    @since("1.4.0")
    def setUserCol(self, value):
        """
        Sets the value of :py:attr:`userCol`.
        """
        return self._set(userCol=value)

    @since("1.4.0")
    def setItemCol(self, value):
        """
        Sets the value of :py:attr:`itemCol`.
        """
        return self._set(itemCol=value)

    @since("1.4.0")
    def setRatingCol(self, value):
        """
        Sets the value of :py:attr:`ratingCol`.
        """
        return self._set(ratingCol=value)

    @since("1.4.0")
    def setNonnegative(self, value):
        """
        Sets the value of :py:attr:`nonnegative`.
        """
        return self._set(nonnegative=value)

    @since("2.0.0")
    def setIntermediateStorageLevel(self, value):
        """
        Sets the value of :py:attr:`intermediateStorageLevel`.
        """
        return self._set(intermediateStorageLevel=value)

    @since("2.0.0")
    def setFinalStorageLevel(self, value):
        """
        Sets the value of :py:attr:`finalStorageLevel`.
        """
        return self._set(finalStorageLevel=value)

    @since("2.2.0")
    def setColdStartStrategy(self, value):
        """
        Sets the value of :py:attr:`coldStartStrategy`.
        """
        return self._set(coldStartStrategy=value)

    def setMaxIter(self, value):
        """
        Sets the value of :py:attr:`maxIter`.
        """
        return self._set(maxIter=value)

    def setRegParam(self, value):
        """
        Sets the value of :py:attr:`regParam`.
        """
        return self._set(regParam=value)

    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self._set(predictionCol=value)

    def setCheckpointInterval(self, value):
        """
        Sets the value of :py:attr:`checkpointInterval`.
        """
        return self._set(checkpointInterval=value)

    def setSeed(self, value):
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)

    @since("3.0.0")
    def setBlockSize(self, value):
        """
        Sets the value of :py:attr:`blockSize`.
        """
        return self._set(blockSize=value)
0455
0456
class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable):
    """
    Model fitted by ALS.

    Wraps the fitted JVM model; the setters, properties, and recommender
    methods below all delegate to it via ``_set``/``_call_java``.

    .. versionadded:: 1.4.0
    """

    @since("3.0.0")
    def setUserCol(self, value):
        """
        Sets the value of :py:attr:`userCol`.
        """
        return self._set(userCol=value)

    @since("3.0.0")
    def setItemCol(self, value):
        """
        Sets the value of :py:attr:`itemCol`.
        """
        return self._set(itemCol=value)

    @since("3.0.0")
    def setColdStartStrategy(self, value):
        """
        Sets the value of :py:attr:`coldStartStrategy`.
        """
        return self._set(coldStartStrategy=value)

    @since("3.0.0")
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self._set(predictionCol=value)

    @since("3.0.0")
    def setBlockSize(self, value):
        """
        Sets the value of :py:attr:`blockSize`.
        """
        return self._set(blockSize=value)

    @property
    @since("1.4.0")
    def rank(self):
        """rank of the matrix factorization model"""
        return self._call_java("rank")

    @property
    @since("1.4.0")
    def userFactors(self):
        """
        a DataFrame that stores user factors in two columns: `id` and
        `features`
        """
        return self._call_java("userFactors")

    @property
    @since("1.4.0")
    def itemFactors(self):
        """
        a DataFrame that stores item factors in two columns: `id` and
        `features`
        """
        return self._call_java("itemFactors")

    @since("2.2.0")
    def recommendForAllUsers(self, numItems):
        """
        Returns top `numItems` items recommended for each user, for all users.

        :param numItems: max number of recommendations for each user
        :return: a DataFrame of (userCol, recommendations), where recommendations are
                 stored as an array of (itemCol, rating) Rows.
        """
        return self._call_java("recommendForAllUsers", numItems)

    @since("2.2.0")
    def recommendForAllItems(self, numUsers):
        """
        Returns top `numUsers` users recommended for each item, for all items.

        :param numUsers: max number of recommendations for each item
        :return: a DataFrame of (itemCol, recommendations), where recommendations are
                 stored as an array of (userCol, rating) Rows.
        """
        return self._call_java("recommendForAllItems", numUsers)

    @since("2.3.0")
    def recommendForUserSubset(self, dataset, numItems):
        """
        Returns top `numItems` items recommended for each user id in the input data set. Note that
        if there are duplicate ids in the input dataset, only one set of recommendations per unique
        id will be returned.

        :param dataset: a Dataset containing a column of user ids. The column name must match
                        `userCol`.
        :param numItems: max number of recommendations for each user
        :return: a DataFrame of (userCol, recommendations), where recommendations are
                 stored as an array of (itemCol, rating) Rows.
        """
        return self._call_java("recommendForUserSubset", dataset, numItems)

    @since("2.3.0")
    def recommendForItemSubset(self, dataset, numUsers):
        """
        Returns top `numUsers` users recommended for each item id in the input data set. Note that
        if there are duplicate ids in the input dataset, only one set of recommendations per unique
        id will be returned.

        :param dataset: a Dataset containing a column of item ids. The column name must match
                        `itemCol`.
        :param numUsers: max number of recommendations for each item
        :return: a DataFrame of (itemCol, recommendations), where recommendations are
                 stored as an array of (userCol, rating) Rows.
        """
        return self._call_java("recommendForItemSubset", dataset, numUsers)
0574
0575
if __name__ == "__main__":
    # Self-test: run this module's doctests against a throwaway local Spark
    # session, then clean up the scratch directory used by the save/load
    # examples.
    import doctest
    import tempfile
    from shutil import rmtree

    import pyspark.ml.recommendation
    from pyspark.sql import SparkSession

    test_globals = pyspark.ml.recommendation.__dict__.copy()
    session = SparkSession.builder\
        .master("local[2]")\
        .appName("ml.recommendation tests")\
        .getOrCreate()
    test_globals['sc'] = session.sparkContext
    test_globals['spark'] = session
    scratch_dir = tempfile.mkdtemp()
    test_globals['temp_path'] = scratch_dir
    try:
        (failure_count, test_count) = doctest.testmod(
            globs=test_globals, optionflags=doctest.ELLIPSIS)
        session.stop()
    finally:
        # Best-effort cleanup; the directory may already be gone.
        try:
            rmtree(scratch_dir)
        except OSError:
            pass
    if failure_count:
        sys.exit(-1)