Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.examples.mllib;
0019 
0020 // $example on$
0021 import java.util.*;
0022 
0023 import scala.Tuple2;
0024 
0025 import org.apache.spark.api.java.*;
0026 import org.apache.spark.mllib.evaluation.RegressionMetrics;
0027 import org.apache.spark.mllib.evaluation.RankingMetrics;
0028 import org.apache.spark.mllib.recommendation.ALS;
0029 import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
0030 import org.apache.spark.mllib.recommendation.Rating;
0031 // $example off$
0032 import org.apache.spark.SparkConf;
0033 
0034 public class JavaRankingMetricsExample {
0035   public static void main(String[] args) {
0036     SparkConf conf = new SparkConf().setAppName("Java Ranking Metrics Example");
0037     JavaSparkContext sc = new JavaSparkContext(conf);
0038     // $example on$
0039     String path = "data/mllib/sample_movielens_data.txt";
0040     JavaRDD<String> data = sc.textFile(path);
0041     JavaRDD<Rating> ratings = data.map(line -> {
0042         String[] parts = line.split("::");
0043         return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double
0044             .parseDouble(parts[2]) - 2.5);
0045       });
0046     ratings.cache();
0047 
0048     // Train an ALS model
0049     MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01);
0050 
0051     // Get top 10 recommendations for every user and scale ratings from 0 to 1
0052     JavaRDD<Tuple2<Object, Rating[]>> userRecs = model.recommendProductsForUsers(10).toJavaRDD();
0053     JavaRDD<Tuple2<Object, Rating[]>> userRecsScaled = userRecs.map(t -> {
0054         Rating[] scaledRatings = new Rating[t._2().length];
0055         for (int i = 0; i < scaledRatings.length; i++) {
0056           double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0);
0057           scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating);
0058         }
0059         return new Tuple2<>(t._1(), scaledRatings);
0060       });
0061     JavaPairRDD<Object, Rating[]> userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled);
0062 
0063     // Map ratings to 1 or 0, 1 indicating a movie that should be recommended
0064     JavaRDD<Rating> binarizedRatings = ratings.map(r -> {
0065         double binaryRating;
0066         if (r.rating() > 0.0) {
0067           binaryRating = 1.0;
0068         } else {
0069           binaryRating = 0.0;
0070         }
0071         return new Rating(r.user(), r.product(), binaryRating);
0072       });
0073 
0074     // Group ratings by common user
0075     JavaPairRDD<Object, Iterable<Rating>> userMovies = binarizedRatings.groupBy(Rating::user);
0076 
0077     // Get true relevant documents from all user ratings
0078     JavaPairRDD<Object, List<Integer>> userMoviesList = userMovies.mapValues(docs -> {
0079         List<Integer> products = new ArrayList<>();
0080         for (Rating r : docs) {
0081           if (r.rating() > 0.0) {
0082             products.add(r.product());
0083           }
0084         }
0085         return products;
0086       });
0087 
0088     // Extract the product id from each recommendation
0089     JavaPairRDD<Object, List<Integer>> userRecommendedList = userRecommended.mapValues(docs -> {
0090         List<Integer> products = new ArrayList<>();
0091         for (Rating r : docs) {
0092           products.add(r.product());
0093         }
0094         return products;
0095       });
0096     JavaRDD<Tuple2<List<Integer>, List<Integer>>> relevantDocs = userMoviesList.join(
0097       userRecommendedList).values();
0098 
0099     // Instantiate the metrics object
0100     RankingMetrics<Integer> metrics = RankingMetrics.of(relevantDocs);
0101 
0102     // Precision, NDCG and Recall at k
0103     Integer[] kVector = {1, 3, 5};
0104     for (Integer k : kVector) {
0105       System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k));
0106       System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k));
0107       System.out.format("Recall at %d = %f\n", k, metrics.recallAt(k));
0108     }
0109 
0110     // Mean average precision
0111     System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision());
0112 
0113     //Mean average precision at k
0114     System.out.format("Mean average precision at 2 = %f\n", metrics.meanAveragePrecisionAt(2));
0115 
0116     // Evaluate the model using numerical ratings and regression metrics
0117     JavaRDD<Tuple2<Object, Object>> userProducts =
0118         ratings.map(r -> new Tuple2<>(r.user(), r.product()));
0119 
0120     JavaPairRDD<Tuple2<Integer, Integer>, Object> predictions = JavaPairRDD.fromJavaRDD(
0121       model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map(r ->
0122         new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating())));
0123     JavaRDD<Tuple2<Object, Object>> ratesAndPreds =
0124       JavaPairRDD.fromJavaRDD(ratings.map(r ->
0125         new Tuple2<Tuple2<Integer, Integer>, Object>(
0126           new Tuple2<>(r.user(), r.product()),
0127           r.rating())
0128       )).join(predictions).values();
0129 
0130     // Create regression metrics object
0131     RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd());
0132 
0133     // Root mean squared error
0134     System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError());
0135 
0136     // R-squared
0137     System.out.format("R-squared = %f\n", regressionMetrics.r2());
0138     // $example off$
0139 
0140     sc.stop();
0141   }
0142 }