Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.examples.mllib;
0019 
0020 // $example on$
0021 import scala.Tuple2;
0022 
0023 import org.apache.spark.api.java.*;
0024 import org.apache.spark.mllib.recommendation.ALS;
0025 import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
0026 import org.apache.spark.mllib.recommendation.Rating;
0027 import org.apache.spark.SparkConf;
0028 // $example off$
0029 
0030 public class JavaRecommendationExample {
0031   public static void main(String[] args) {
0032     // $example on$
0033     SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example");
0034     JavaSparkContext jsc = new JavaSparkContext(conf);
0035 
0036     // Load and parse the data
0037     String path = "data/mllib/als/test.data";
0038     JavaRDD<String> data = jsc.textFile(path);
0039     JavaRDD<Rating> ratings = data.map(s -> {
0040       String[] sarray = s.split(",");
0041       return new Rating(Integer.parseInt(sarray[0]),
0042         Integer.parseInt(sarray[1]),
0043         Double.parseDouble(sarray[2]));
0044     });
0045 
0046     // Build the recommendation model using ALS
0047     int rank = 10;
0048     int numIterations = 10;
0049     MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01);
0050 
0051     // Evaluate the model on rating data
0052     JavaRDD<Tuple2<Object, Object>> userProducts =
0053       ratings.map(r -> new Tuple2<>(r.user(), r.product()));
0054     JavaPairRDD<Tuple2<Integer, Integer>, Double> predictions = JavaPairRDD.fromJavaRDD(
0055       model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD()
0056           .map(r -> new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating()))
0057     );
0058     JavaRDD<Tuple2<Double, Double>> ratesAndPreds = JavaPairRDD.fromJavaRDD(
0059         ratings.map(r -> new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating())))
0060       .join(predictions).values();
0061     double MSE = ratesAndPreds.mapToDouble(pair -> {
0062       double err = pair._1() - pair._2();
0063       return err * err;
0064     }).mean();
0065     System.out.println("Mean Squared Error = " + MSE);
0066 
0067     // Save and load model
0068     model.save(jsc.sc(), "target/tmp/myCollaborativeFilter");
0069     MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(jsc.sc(),
0070       "target/tmp/myCollaborativeFilter");
0071     // $example off$
0072 
0073     jsc.stop();
0074   }
0075 }