Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.examples.mllib;
0019 
0020 import org.apache.spark.SparkConf;
0021 import org.apache.spark.SparkContext;
0022 
0023 // $example on$
0024 import scala.Tuple2;
0025 
0026 import org.apache.spark.api.java.JavaRDD;
0027 import org.apache.spark.mllib.classification.SVMModel;
0028 import org.apache.spark.mllib.classification.SVMWithSGD;
0029 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
0030 import org.apache.spark.mllib.regression.LabeledPoint;
0031 import org.apache.spark.mllib.util.MLUtils;
0032 // $example off$
0033 
0034 /**
0035  * Example for SVMWithSGD.
0036  */
0037 public class JavaSVMWithSGDExample {
0038   public static void main(String[] args) {
0039     SparkConf conf = new SparkConf().setAppName("JavaSVMWithSGDExample");
0040     SparkContext sc = new SparkContext(conf);
0041     // $example on$
0042     String path = "data/mllib/sample_libsvm_data.txt";
0043     JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
0044 
0045     // Split initial RDD into two... [60% training data, 40% testing data].
0046     JavaRDD<LabeledPoint> training = data.sample(false, 0.6, 11L);
0047     training.cache();
0048     JavaRDD<LabeledPoint> test = data.subtract(training);
0049 
0050     // Run training algorithm to build the model.
0051     int numIterations = 100;
0052     SVMModel model = SVMWithSGD.train(training.rdd(), numIterations);
0053 
0054     // Clear the default threshold.
0055     model.clearThreshold();
0056 
0057     // Compute raw scores on the test set.
0058     JavaRDD<Tuple2<Object, Object>> scoreAndLabels = test.map(p ->
0059       new Tuple2<>(model.predict(p.features()), p.label()));
0060 
0061     // Get evaluation metrics.
0062     BinaryClassificationMetrics metrics =
0063       new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels));
0064     double auROC = metrics.areaUnderROC();
0065 
0066     System.out.println("Area under ROC = " + auROC);
0067 
0068     // Save and load model
0069     model.save(sc, "target/tmp/javaSVMWithSGDModel");
0070     SVMModel sameModel = SVMModel.load(sc, "target/tmp/javaSVMWithSGDModel");
0071     // $example off$
0072 
0073     sc.stop();
0074   }
0075 }