Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.examples.mllib;
0019 
0020 import org.apache.spark.SparkConf;
0021 import org.apache.spark.api.java.JavaRDD;
0022 import org.apache.spark.api.java.JavaSparkContext;
0023 import org.apache.spark.api.java.function.Function;
0024 
0025 import org.apache.spark.mllib.recommendation.ALS;
0026 import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
0027 import org.apache.spark.mllib.recommendation.Rating;
0028 
0029 import java.util.Arrays;
0030 import java.util.regex.Pattern;
0031 
0032 import scala.Tuple2;
0033 
0034 /**
0035  * Example using MLlib ALS from Java.
0036  */
0037 public final class JavaALS {
0038 
0039   static class ParseRating implements Function<String, Rating> {
0040     private static final Pattern COMMA = Pattern.compile(",");
0041 
0042     @Override
0043     public Rating call(String line) {
0044       String[] tok = COMMA.split(line);
0045       int x = Integer.parseInt(tok[0]);
0046       int y = Integer.parseInt(tok[1]);
0047       double rating = Double.parseDouble(tok[2]);
0048       return new Rating(x, y, rating);
0049     }
0050   }
0051 
0052   static class FeaturesToString implements Function<Tuple2<Object, double[]>, String> {
0053     @Override
0054     public String call(Tuple2<Object, double[]> element) {
0055       return element._1() + "," + Arrays.toString(element._2());
0056     }
0057   }
0058 
0059   public static void main(String[] args) {
0060 
0061     if (args.length < 4) {
0062       System.err.println(
0063         "Usage: JavaALS <ratings_file> <rank> <iterations> <output_dir> [<blocks>]");
0064       System.exit(1);
0065     }
0066     SparkConf sparkConf = new SparkConf().setAppName("JavaALS");
0067     int rank = Integer.parseInt(args[1]);
0068     int iterations = Integer.parseInt(args[2]);
0069     String outputDir = args[3];
0070     int blocks = -1;
0071     if (args.length == 5) {
0072       blocks = Integer.parseInt(args[4]);
0073     }
0074 
0075     JavaSparkContext sc = new JavaSparkContext(sparkConf);
0076     JavaRDD<String> lines = sc.textFile(args[0]);
0077 
0078     JavaRDD<Rating> ratings = lines.map(new ParseRating());
0079 
0080     MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, iterations, 0.01, blocks);
0081 
0082     model.userFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile(
0083         outputDir + "/userFeatures");
0084     model.productFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile(
0085         outputDir + "/productFeatures");
0086     System.out.println("Final user/product features written to " + outputDir);
0087 
0088     sc.stop();
0089   }
0090 }