0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.mllib;
0019
0020 import org.apache.spark.SparkConf;
0021 import org.apache.spark.api.java.JavaRDD;
0022 import org.apache.spark.api.java.JavaSparkContext;
0023 import org.apache.spark.api.java.function.Function;
0024
0025 import org.apache.spark.mllib.recommendation.ALS;
0026 import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
0027 import org.apache.spark.mllib.recommendation.Rating;
0028
0029 import java.util.Arrays;
0030 import java.util.regex.Pattern;
0031
0032 import scala.Tuple2;
0033
0034
0035
0036
0037 public final class JavaALS {
0038
0039 static class ParseRating implements Function<String, Rating> {
0040 private static final Pattern COMMA = Pattern.compile(",");
0041
0042 @Override
0043 public Rating call(String line) {
0044 String[] tok = COMMA.split(line);
0045 int x = Integer.parseInt(tok[0]);
0046 int y = Integer.parseInt(tok[1]);
0047 double rating = Double.parseDouble(tok[2]);
0048 return new Rating(x, y, rating);
0049 }
0050 }
0051
0052 static class FeaturesToString implements Function<Tuple2<Object, double[]>, String> {
0053 @Override
0054 public String call(Tuple2<Object, double[]> element) {
0055 return element._1() + "," + Arrays.toString(element._2());
0056 }
0057 }
0058
0059 public static void main(String[] args) {
0060
0061 if (args.length < 4) {
0062 System.err.println(
0063 "Usage: JavaALS <ratings_file> <rank> <iterations> <output_dir> [<blocks>]");
0064 System.exit(1);
0065 }
0066 SparkConf sparkConf = new SparkConf().setAppName("JavaALS");
0067 int rank = Integer.parseInt(args[1]);
0068 int iterations = Integer.parseInt(args[2]);
0069 String outputDir = args[3];
0070 int blocks = -1;
0071 if (args.length == 5) {
0072 blocks = Integer.parseInt(args[4]);
0073 }
0074
0075 JavaSparkContext sc = new JavaSparkContext(sparkConf);
0076 JavaRDD<String> lines = sc.textFile(args[0]);
0077
0078 JavaRDD<Rating> ratings = lines.map(new ParseRating());
0079
0080 MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, iterations, 0.01, blocks);
0081
0082 model.userFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile(
0083 outputDir + "/userFeatures");
0084 model.productFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile(
0085 outputDir + "/productFeatures");
0086 System.out.println("Final user/product features written to " + outputDir);
0087
0088 sc.stop();
0089 }
0090 }