Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.mllib.recommendation;
0019 
0020 import java.util.ArrayList;
0021 import java.util.List;
0022 
0023 import scala.Tuple2;
0024 import scala.Tuple3;
0025 
0026 import org.junit.Assert;
0027 import org.junit.Test;
0028 
0029 import org.apache.spark.SharedSparkSession;
0030 import org.apache.spark.api.java.JavaPairRDD;
0031 import org.apache.spark.api.java.JavaRDD;
0032 
0033 public class JavaALSSuite extends SharedSparkSession {
0034 
0035   private void validatePrediction(
0036     MatrixFactorizationModel model,
0037     int users,
0038     int products,
0039     double[] trueRatings,
0040     double matchThreshold,
0041     boolean implicitPrefs,
0042     double[] truePrefs) {
0043     List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList<>(users * products);
0044     for (int u = 0; u < users; ++u) {
0045       for (int p = 0; p < products; ++p) {
0046         localUsersProducts.add(new Tuple2<>(u, p));
0047       }
0048     }
0049     JavaPairRDD<Integer, Integer> usersProducts = jsc.parallelizePairs(localUsersProducts);
0050     List<Rating> predictedRatings = model.predict(usersProducts).collect();
0051     Assert.assertEquals(users * products, predictedRatings.size());
0052     if (!implicitPrefs) {
0053       for (Rating r : predictedRatings) {
0054         double prediction = r.rating();
0055         double correct = trueRatings[r.product() * users + r.user()];
0056         Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
0057           prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
0058       }
0059     } else {
0060       // For implicit prefs we use the confidence-weighted RMSE to test
0061       // (ref Mahout's implicit ALS tests)
0062       double sqErr = 0.0;
0063       double denom = 0.0;
0064       for (Rating r : predictedRatings) {
0065         double prediction = r.rating();
0066         double truePref = truePrefs[r.product() * users + r.user()];
0067         double confidence = 1.0 +
0068           /* alpha = 1.0 * ... */ Math.abs(trueRatings[r.product() * users + r.user()]);
0069         double err = confidence * (truePref - prediction) * (truePref - prediction);
0070         sqErr += err;
0071         denom += confidence;
0072       }
0073       double rmse = Math.sqrt(sqErr / denom);
0074       Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
0075         rmse, matchThreshold), rmse < matchThreshold);
0076     }
0077   }
0078 
0079   @Test
0080   public void runALSUsingStaticMethods() {
0081     int features = 1;
0082     int iterations = 15;
0083     int users = 50;
0084     int products = 100;
0085     Tuple3<List<Rating>, double[], double[]> testData =
0086       ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
0087 
0088     JavaRDD<Rating> data = jsc.parallelize(testData._1());
0089     MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
0090     validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
0091   }
0092 
0093   @Test
0094   public void runALSUsingConstructor() {
0095     int features = 2;
0096     int iterations = 15;
0097     int users = 100;
0098     int products = 200;
0099     Tuple3<List<Rating>, double[], double[]> testData =
0100       ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
0101 
0102     JavaRDD<Rating> data = jsc.parallelize(testData._1());
0103 
0104     MatrixFactorizationModel model = new ALS().setRank(features)
0105       .setIterations(iterations)
0106       .run(data);
0107     validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
0108   }
0109 
0110   @Test
0111   public void runImplicitALSUsingStaticMethods() {
0112     int features = 1;
0113     int iterations = 15;
0114     int users = 80;
0115     int products = 160;
0116     Tuple3<List<Rating>, double[], double[]> testData =
0117       ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
0118 
0119     JavaRDD<Rating> data = jsc.parallelize(testData._1());
0120     MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
0121     validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
0122   }
0123 
0124   @Test
0125   public void runImplicitALSUsingConstructor() {
0126     int features = 2;
0127     int iterations = 15;
0128     int users = 100;
0129     int products = 200;
0130     Tuple3<List<Rating>, double[], double[]> testData =
0131       ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
0132 
0133     JavaRDD<Rating> data = jsc.parallelize(testData._1());
0134 
0135     MatrixFactorizationModel model = new ALS().setRank(features)
0136       .setIterations(iterations)
0137       .setImplicitPrefs(true)
0138       .run(data.rdd());
0139     validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
0140   }
0141 
0142   @Test
0143   public void runImplicitALSWithNegativeWeight() {
0144     int features = 2;
0145     int iterations = 15;
0146     int users = 80;
0147     int products = 160;
0148     Tuple3<List<Rating>, double[], double[]> testData =
0149       ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);
0150 
0151     JavaRDD<Rating> data = jsc.parallelize(testData._1());
0152     MatrixFactorizationModel model = new ALS().setRank(features)
0153       .setIterations(iterations)
0154       .setImplicitPrefs(true)
0155       .setSeed(8675309L)
0156       .run(data.rdd());
0157     validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
0158   }
0159 
0160   @Test
0161   public void runRecommend() {
0162     int features = 5;
0163     int iterations = 10;
0164     int users = 200;
0165     int products = 50;
0166     List<Rating> testData = ALSSuite.generateRatingsAsJava(
0167       users, products, features, 0.7, true, false)._1();
0168     JavaRDD<Rating> data = jsc.parallelize(testData);
0169     MatrixFactorizationModel model = new ALS().setRank(features)
0170       .setIterations(iterations)
0171       .setImplicitPrefs(true)
0172       .setSeed(8675309L)
0173       .run(data.rdd());
0174     validateRecommendations(model.recommendProducts(1, 10), 10);
0175     validateRecommendations(model.recommendUsers(1, 20), 20);
0176   }
0177 
0178   private static void validateRecommendations(Rating[] recommendations, int howMany) {
0179     Assert.assertEquals(howMany, recommendations.length);
0180     for (int i = 1; i < recommendations.length; i++) {
0181       Assert.assertTrue(recommendations[i - 1].rating() >= recommendations[i].rating());
0182     }
0183     Assert.assertTrue(recommendations[0].rating() > 0.7);
0184   }
0185 
0186 }