0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.recommendation;
0019
0020 import java.util.ArrayList;
0021 import java.util.List;
0022
0023 import scala.Tuple2;
0024 import scala.Tuple3;
0025
0026 import org.junit.Assert;
0027 import org.junit.Test;
0028
0029 import org.apache.spark.SharedSparkSession;
0030 import org.apache.spark.api.java.JavaPairRDD;
0031 import org.apache.spark.api.java.JavaRDD;
0032
0033 public class JavaALSSuite extends SharedSparkSession {
0034
0035 private void validatePrediction(
0036 MatrixFactorizationModel model,
0037 int users,
0038 int products,
0039 double[] trueRatings,
0040 double matchThreshold,
0041 boolean implicitPrefs,
0042 double[] truePrefs) {
0043 List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList<>(users * products);
0044 for (int u = 0; u < users; ++u) {
0045 for (int p = 0; p < products; ++p) {
0046 localUsersProducts.add(new Tuple2<>(u, p));
0047 }
0048 }
0049 JavaPairRDD<Integer, Integer> usersProducts = jsc.parallelizePairs(localUsersProducts);
0050 List<Rating> predictedRatings = model.predict(usersProducts).collect();
0051 Assert.assertEquals(users * products, predictedRatings.size());
0052 if (!implicitPrefs) {
0053 for (Rating r : predictedRatings) {
0054 double prediction = r.rating();
0055 double correct = trueRatings[r.product() * users + r.user()];
0056 Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
0057 prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
0058 }
0059 } else {
0060
0061
0062 double sqErr = 0.0;
0063 double denom = 0.0;
0064 for (Rating r : predictedRatings) {
0065 double prediction = r.rating();
0066 double truePref = truePrefs[r.product() * users + r.user()];
0067 double confidence = 1.0 +
0068 Math.abs(trueRatings[r.product() * users + r.user()]);
0069 double err = confidence * (truePref - prediction) * (truePref - prediction);
0070 sqErr += err;
0071 denom += confidence;
0072 }
0073 double rmse = Math.sqrt(sqErr / denom);
0074 Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
0075 rmse, matchThreshold), rmse < matchThreshold);
0076 }
0077 }
0078
0079 @Test
0080 public void runALSUsingStaticMethods() {
0081 int features = 1;
0082 int iterations = 15;
0083 int users = 50;
0084 int products = 100;
0085 Tuple3<List<Rating>, double[], double[]> testData =
0086 ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
0087
0088 JavaRDD<Rating> data = jsc.parallelize(testData._1());
0089 MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
0090 validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
0091 }
0092
0093 @Test
0094 public void runALSUsingConstructor() {
0095 int features = 2;
0096 int iterations = 15;
0097 int users = 100;
0098 int products = 200;
0099 Tuple3<List<Rating>, double[], double[]> testData =
0100 ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
0101
0102 JavaRDD<Rating> data = jsc.parallelize(testData._1());
0103
0104 MatrixFactorizationModel model = new ALS().setRank(features)
0105 .setIterations(iterations)
0106 .run(data);
0107 validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
0108 }
0109
0110 @Test
0111 public void runImplicitALSUsingStaticMethods() {
0112 int features = 1;
0113 int iterations = 15;
0114 int users = 80;
0115 int products = 160;
0116 Tuple3<List<Rating>, double[], double[]> testData =
0117 ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
0118
0119 JavaRDD<Rating> data = jsc.parallelize(testData._1());
0120 MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
0121 validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
0122 }
0123
0124 @Test
0125 public void runImplicitALSUsingConstructor() {
0126 int features = 2;
0127 int iterations = 15;
0128 int users = 100;
0129 int products = 200;
0130 Tuple3<List<Rating>, double[], double[]> testData =
0131 ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
0132
0133 JavaRDD<Rating> data = jsc.parallelize(testData._1());
0134
0135 MatrixFactorizationModel model = new ALS().setRank(features)
0136 .setIterations(iterations)
0137 .setImplicitPrefs(true)
0138 .run(data.rdd());
0139 validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
0140 }
0141
0142 @Test
0143 public void runImplicitALSWithNegativeWeight() {
0144 int features = 2;
0145 int iterations = 15;
0146 int users = 80;
0147 int products = 160;
0148 Tuple3<List<Rating>, double[], double[]> testData =
0149 ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);
0150
0151 JavaRDD<Rating> data = jsc.parallelize(testData._1());
0152 MatrixFactorizationModel model = new ALS().setRank(features)
0153 .setIterations(iterations)
0154 .setImplicitPrefs(true)
0155 .setSeed(8675309L)
0156 .run(data.rdd());
0157 validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
0158 }
0159
0160 @Test
0161 public void runRecommend() {
0162 int features = 5;
0163 int iterations = 10;
0164 int users = 200;
0165 int products = 50;
0166 List<Rating> testData = ALSSuite.generateRatingsAsJava(
0167 users, products, features, 0.7, true, false)._1();
0168 JavaRDD<Rating> data = jsc.parallelize(testData);
0169 MatrixFactorizationModel model = new ALS().setRank(features)
0170 .setIterations(iterations)
0171 .setImplicitPrefs(true)
0172 .setSeed(8675309L)
0173 .run(data.rdd());
0174 validateRecommendations(model.recommendProducts(1, 10), 10);
0175 validateRecommendations(model.recommendUsers(1, 20), 20);
0176 }
0177
0178 private static void validateRecommendations(Rating[] recommendations, int howMany) {
0179 Assert.assertEquals(howMany, recommendations.length);
0180 for (int i = 1; i < recommendations.length; i++) {
0181 Assert.assertTrue(recommendations[i - 1].rating() >= recommendations[i].rating());
0182 }
0183 Assert.assertTrue(recommendations[0].rating() > 0.7);
0184 }
0185
0186 }