0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.mllib;
0019
0020
0021 import scala.Tuple2;
0022
0023 import org.apache.spark.api.java.*;
0024 import org.apache.spark.mllib.recommendation.ALS;
0025 import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
0026 import org.apache.spark.mllib.recommendation.Rating;
0027 import org.apache.spark.SparkConf;
0028
0029
0030 public class JavaRecommendationExample {
0031 public static void main(String[] args) {
0032
0033 SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example");
0034 JavaSparkContext jsc = new JavaSparkContext(conf);
0035
0036
0037 String path = "data/mllib/als/test.data";
0038 JavaRDD<String> data = jsc.textFile(path);
0039 JavaRDD<Rating> ratings = data.map(s -> {
0040 String[] sarray = s.split(",");
0041 return new Rating(Integer.parseInt(sarray[0]),
0042 Integer.parseInt(sarray[1]),
0043 Double.parseDouble(sarray[2]));
0044 });
0045
0046
0047 int rank = 10;
0048 int numIterations = 10;
0049 MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01);
0050
0051
0052 JavaRDD<Tuple2<Object, Object>> userProducts =
0053 ratings.map(r -> new Tuple2<>(r.user(), r.product()));
0054 JavaPairRDD<Tuple2<Integer, Integer>, Double> predictions = JavaPairRDD.fromJavaRDD(
0055 model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD()
0056 .map(r -> new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating()))
0057 );
0058 JavaRDD<Tuple2<Double, Double>> ratesAndPreds = JavaPairRDD.fromJavaRDD(
0059 ratings.map(r -> new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating())))
0060 .join(predictions).values();
0061 double MSE = ratesAndPreds.mapToDouble(pair -> {
0062 double err = pair._1() - pair._2();
0063 return err * err;
0064 }).mean();
0065 System.out.println("Mean Squared Error = " + MSE);
0066
0067
0068 model.save(jsc.sc(), "target/tmp/myCollaborativeFilter");
0069 MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(jsc.sc(),
0070 "target/tmp/myCollaborativeFilter");
0071
0072
0073 jsc.stop();
0074 }
0075 }