0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.mllib;
0019
0020
0021 import scala.Tuple2;
0022
0023 import org.apache.spark.api.java.*;
0024 import org.apache.spark.mllib.classification.LogisticRegressionModel;
0025 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
0026 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
0027 import org.apache.spark.mllib.regression.LabeledPoint;
0028 import org.apache.spark.mllib.util.MLUtils;
0029
0030 import org.apache.spark.SparkConf;
0031 import org.apache.spark.SparkContext;
0032
0033 public class JavaBinaryClassificationMetricsExample {
0034 public static void main(String[] args) {
0035 SparkConf conf = new SparkConf().setAppName("Java Binary Classification Metrics Example");
0036 SparkContext sc = new SparkContext(conf);
0037
0038 String path = "data/mllib/sample_binary_classification_data.txt";
0039 JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
0040
0041
0042 JavaRDD<LabeledPoint>[] splits =
0043 data.randomSplit(new double[]{0.6, 0.4}, 11L);
0044 JavaRDD<LabeledPoint> training = splits[0].cache();
0045 JavaRDD<LabeledPoint> test = splits[1];
0046
0047
0048 LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
0049 .setNumClasses(2)
0050 .run(training.rdd());
0051
0052
0053 model.clearThreshold();
0054
0055
0056 JavaPairRDD<Object, Object> predictionAndLabels = test.mapToPair(p ->
0057 new Tuple2<>(model.predict(p.features()), p.label()));
0058
0059
0060 BinaryClassificationMetrics metrics =
0061 new BinaryClassificationMetrics(predictionAndLabels.rdd());
0062
0063
0064 JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD();
0065 System.out.println("Precision by threshold: " + precision.collect());
0066
0067
0068 JavaRDD<?> recall = metrics.recallByThreshold().toJavaRDD();
0069 System.out.println("Recall by threshold: " + recall.collect());
0070
0071
0072 JavaRDD<?> f1Score = metrics.fMeasureByThreshold().toJavaRDD();
0073 System.out.println("F1 Score by threshold: " + f1Score.collect());
0074
0075 JavaRDD<?> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD();
0076 System.out.println("F2 Score by threshold: " + f2Score.collect());
0077
0078
0079 JavaRDD<?> prc = metrics.pr().toJavaRDD();
0080 System.out.println("Precision-recall curve: " + prc.collect());
0081
0082
0083 JavaRDD<Double> thresholds = precision.map(t -> Double.parseDouble(t._1().toString()));
0084
0085
0086 JavaRDD<?> roc = metrics.roc().toJavaRDD();
0087 System.out.println("ROC curve: " + roc.collect());
0088
0089
0090 System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR());
0091
0092
0093 System.out.println("Area under ROC = " + metrics.areaUnderROC());
0094
0095
0096 model.save(sc, "target/tmp/LogisticRegressionModel");
0097 LogisticRegressionModel.load(sc, "target/tmp/LogisticRegressionModel");
0098
0099
0100 sc.stop();
0101 }
0102 }