package org.apache.spark.examples.ml;

import java.io.Serializable;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class JavaALSExample {
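  // Java bean for one rating record; Spark reflects on this class's getters
  // to infer the DataFrame schema in createDataFrame below.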
  public static class Rating implements Serializable {
    private int userId;
    private int movieId;
    private float rating;
    private long timestamp;

    public Rating() {}

    public Rating(int userId, int movieId, float rating, long timestamp) {
      this.userId = userId;
      this.movieId = movieId;
      this.rating = rating;
      this.timestamp = timestamp;
    }

    public int getUserId() {
      return userId;
    }

    public int getMovieId() {
      return movieId;
    }

    public float getRating() {
      return rating;
    }

    public long getTimestamp() {
      return timestamp;
    }

    // Parses one "userId::movieId::rating::timestamp" line of the input file
    public static Rating parseRating(String str) {
      String[] fields = str.split("::");
      if (fields.length != 4) {
        throw new IllegalArgumentException("Each line must contain 4 fields");
      }
      int userId = Integer.parseInt(fields[0]);
      int movieId = Integer.parseInt(fields[1]);
      float rating = Float.parseFloat(fields[2]);
      long timestamp = Long.parseLong(fields[3]);
      return new Rating(userId, movieId, rating, timestamp);
    }
  }

  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaALSExample")
      .getOrCreate();

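    // Parse each line of the sample MovieLens file into a Rating bean, build
    // a DataFrame, and randomly split it 80/20 into training and test sets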
    JavaRDD<Rating> ratingsRDD = spark
      .read().textFile("data/mllib/als/sample_movielens_ratings.txt").javaRDD()
      .map(Rating::parseRating);
    Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);
    Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
    Dataset<Row> training = splits[0];
    Dataset<Row> test = splits[1];

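    // Build the recommendation model using ALS on the training data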
    ALS als = new ALS()
      .setMaxIter(5)
      .setRegParam(0.01)
      .setUserCol("userId")
      .setItemCol("movieId")
      .setRatingCol("rating");
    ALSModel model = als.fit(training);

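    // Evaluate the model by computing the RMSE on the test data. The 'drop'
    // cold-start strategy discards rows with NaN predictions (users or movies
    // unseen during training) so the metric stays well-defined.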
    model.setColdStartStrategy("drop");
    Dataset<Row> predictions = model.transform(test);

    RegressionEvaluator evaluator = new RegressionEvaluator()
      .setMetricName("rmse")
      .setLabelCol("rating")
      .setPredictionCol("prediction");
    double rmse = evaluator.evaluate(predictions);
    System.out.println("Root-mean-square error = " + rmse);

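    // Generate top 10 movie recommendations for each user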
    Dataset<Row> userRecs = model.recommendForAllUsers(10);
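    // Generate top 10 user recommendations for each movie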
    Dataset<Row> movieRecs = model.recommendForAllItems(10);

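    // Generate top 10 movie recommendations for a specified set of users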
    Dataset<Row> users = ratings.select(als.getUserCol()).distinct().limit(3);
    Dataset<Row> userSubsetRecs = model.recommendForUserSubset(users, 10);
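    // Generate top 10 user recommendations for a specified set of movies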
    Dataset<Row> movies = ratings.select(als.getItemCol()).distinct().limit(3);
    Dataset<Row> movieSubSetRecs = model.recommendForItemSubset(movies, 10);

    userRecs.show();
    movieRecs.show();
    userSubsetRecs.show();
    movieSubSetRecs.show();

    spark.stop();
  }
}