0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.mllib;
0019
0020
0021 import java.util.Arrays;
0022 import java.util.List;
0023
0024 import scala.Tuple2;
0025
0026 import org.apache.spark.api.java.*;
0027 import org.apache.spark.mllib.evaluation.MultilabelMetrics;
0028 import org.apache.spark.SparkConf;
0029
0030
0031 public class JavaMultiLabelClassificationMetricsExample {
0032 public static void main(String[] args) {
0033 SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics Example");
0034 JavaSparkContext sc = new JavaSparkContext(conf);
0035
0036 List<Tuple2<double[], double[]>> data = Arrays.asList(
0037 new Tuple2<>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}),
0038 new Tuple2<>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}),
0039 new Tuple2<>(new double[]{}, new double[]{0.0}),
0040 new Tuple2<>(new double[]{2.0}, new double[]{2.0}),
0041 new Tuple2<>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}),
0042 new Tuple2<>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}),
0043 new Tuple2<>(new double[]{1.0}, new double[]{1.0, 2.0})
0044 );
0045 JavaRDD<Tuple2<double[], double[]>> scoreAndLabels = sc.parallelize(data);
0046
0047
0048 MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd());
0049
0050
0051 System.out.format("Recall = %f\n", metrics.recall());
0052 System.out.format("Precision = %f\n", metrics.precision());
0053 System.out.format("F1 measure = %f\n", metrics.f1Measure());
0054 System.out.format("Accuracy = %f\n", metrics.accuracy());
0055
0056
0057 for (int i = 0; i < metrics.labels().length - 1; i++) {
0058 System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(
0059 metrics.labels()[i]));
0060 System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(
0061 metrics.labels()[i]));
0062 System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(
0063 metrics.labels()[i]));
0064 }
0065
0066
0067 System.out.format("Micro recall = %f\n", metrics.microRecall());
0068 System.out.format("Micro precision = %f\n", metrics.microPrecision());
0069 System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure());
0070
0071
0072 System.out.format("Hamming loss = %f\n", metrics.hammingLoss());
0073
0074
0075 System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy());
0076
0077
0078 sc.stop();
0079 }
0080 }