0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.mllib;
0019
0020 import org.apache.spark.SparkConf;
0021 import org.apache.spark.SparkContext;
0022
0023
0024 import scala.Tuple2;
0025
0026 import org.apache.spark.api.java.JavaRDD;
0027 import org.apache.spark.mllib.classification.SVMModel;
0028 import org.apache.spark.mllib.classification.SVMWithSGD;
0029 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
0030 import org.apache.spark.mllib.regression.LabeledPoint;
0031 import org.apache.spark.mllib.util.MLUtils;
0032
0033
0034
0035
0036
0037 public class JavaSVMWithSGDExample {
0038 public static void main(String[] args) {
0039 SparkConf conf = new SparkConf().setAppName("JavaSVMWithSGDExample");
0040 SparkContext sc = new SparkContext(conf);
0041
0042 String path = "data/mllib/sample_libsvm_data.txt";
0043 JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
0044
0045
0046 JavaRDD<LabeledPoint> training = data.sample(false, 0.6, 11L);
0047 training.cache();
0048 JavaRDD<LabeledPoint> test = data.subtract(training);
0049
0050
0051 int numIterations = 100;
0052 SVMModel model = SVMWithSGD.train(training.rdd(), numIterations);
0053
0054
0055 model.clearThreshold();
0056
0057
0058 JavaRDD<Tuple2<Object, Object>> scoreAndLabels = test.map(p ->
0059 new Tuple2<>(model.predict(p.features()), p.label()));
0060
0061
0062 BinaryClassificationMetrics metrics =
0063 new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels));
0064 double auROC = metrics.areaUnderROC();
0065
0066 System.out.println("Area under ROC = " + auROC);
0067
0068
0069 model.save(sc, "target/tmp/javaSVMWithSGDModel");
0070 SVMModel sameModel = SVMModel.load(sc, "target/tmp/javaSVMWithSGDModel");
0071
0072
0073 sc.stop();
0074 }
0075 }