"""
This is an example implementation of ALS for learning how to use Spark. Please refer to
pyspark.ml.recommendation.ALS for more conventional use.

This example requires numpy (http://www.numpy.org/)
"""
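# The script below factors a rating matrix R (M movies x U users) into two
# low-rank factor matrices, ms (M x F) and us (U x F), by alternating least
# squares: hold one factor matrix fixed, solve a regularized least-squares
# problem for each row of the other, and repeat, printing RMSE per iteration.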
from __future__ import print_function

import sys

import numpy as np
from numpy.random import rand
from numpy import matrix
from pyspark.sql import SparkSession

LAMBDA = 0.01
np.random.seed(42)


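# Root-mean-square error of the reconstruction ms * us.T against the full
# rating matrix R. M and U are module-level globals assigned in __main__.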
def rmse(R, ms, us):
    diff = R - ms * us.T
    return np.sqrt(np.sum(np.power(diff, 2)) / (M * U))


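# Solve one row's regularized least-squares problem in closed form: with the
# other factor matrix `mat` (uu rows, ff features) held fixed, the new factor
# column x for index i satisfies the normal equations
#   (mat.T * mat + LAMBDA * uu * I) * x = mat.T * ratings[i, :].T
# i.e. a ridge regression of row i's ratings on the fixed factors.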
def update(i, mat, ratings):
    uu = mat.shape[0]
    ff = mat.shape[1]

    XtX = mat.T * mat
    Xty = mat.T * ratings[i, :].T

    for j in range(ff):
        XtX[j, j] += LAMBDA * uu

    return np.linalg.solve(XtX, Xty)


if __name__ == "__main__":

    """
    Usage: als [M] [U] [F] [iterations] [partitions]
    """

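    # For example (the invocation path is illustrative, not fixed):
    #   spark-submit als.py 100 500 10 5 2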
    print("""WARN: This is a naive implementation of ALS and is given as an
      example. Please use pyspark.ml.recommendation.ALS for more
      conventional use.""", file=sys.stderr)

    spark = SparkSession\
        .builder\
        .appName("PythonALS")\
        .getOrCreate()

    sc = spark.sparkContext

    M = int(sys.argv[1]) if len(sys.argv) > 1 else 100
    U = int(sys.argv[2]) if len(sys.argv) > 2 else 500
    F = int(sys.argv[3]) if len(sys.argv) > 3 else 10
    ITERATIONS = int(sys.argv[4]) if len(sys.argv) > 4 else 5
    partitions = int(sys.argv[5]) if len(sys.argv) > 5 else 2

    print("Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" %
          (M, U, F, ITERATIONS, partitions))

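    # R is a synthetic M x U rating matrix built as a product of random
    # factors, so it has rank at most F and an exact rank-F fit exists;
    # ms and us are the random initial factor matrices.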
    R = matrix(rand(M, F)) * matrix(rand(U, F).T)
    ms = matrix(rand(M, F))
    us = matrix(rand(U, F))

    Rb = sc.broadcast(R)
    msb = sc.broadcast(ms)
    usb = sc.broadcast(us)

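    # Each iteration alternates two half-steps: recompute every movie row of
    # ms with the user factors fixed, then every user row of us with the
    # movie factors fixed. collect() returns a list of F x 1 column matrices,
    # hence the [:, :, 0] reshape, and each refreshed factor matrix is
    # rebroadcast so executors see the update in the next half-step.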
    for i in range(ITERATIONS):
        ms = sc.parallelize(range(M), partitions) \
               .map(lambda x: update(x, usb.value, Rb.value)) \
               .collect()
        ms = matrix(np.array(ms)[:, :, 0])
        msb = sc.broadcast(ms)

        us = sc.parallelize(range(U), partitions) \
               .map(lambda x: update(x, msb.value, Rb.value.T)) \
               .collect()
        us = matrix(np.array(us)[:, :, 0])
        usb = sc.broadcast(us)

        error = rmse(R, ms, us)
        print("Iteration %d:" % i)
        print("\nRMSE: %5.4f\n" % error)

    spark.stop()