0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.evaluation;
0019
0020 import java.io.IOException;
0021 import java.util.Arrays;
0022 import java.util.List;
0023
0024 import scala.Tuple2;
0025 import scala.Tuple2$;
0026
0027 import org.junit.Assert;
0028 import org.junit.Test;
0029
0030 import org.apache.spark.SharedSparkSession;
0031 import org.apache.spark.api.java.JavaRDD;
0032
0033 public class JavaRankingMetricsSuite extends SharedSparkSession {
0034 private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels;
0035
0036 @Override
0037 public void setUp() throws IOException {
0038 super.setUp();
0039 predictionAndLabels = jsc.parallelize(Arrays.asList(
0040 Tuple2$.MODULE$.apply(
0041 Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)),
0042 Tuple2$.MODULE$.apply(
0043 Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
0044 Tuple2$.MODULE$.apply(
0045 Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
0046 }
0047
0048 @Test
0049 public void rankingMetrics() {
0050 @SuppressWarnings("unchecked")
0051 RankingMetrics<?> metrics = RankingMetrics.of(predictionAndLabels);
0052 Assert.assertEquals(0.355026, metrics.meanAveragePrecision(), 1e-5);
0053 Assert.assertEquals(0.75 / 3.0, metrics.precisionAt(4), 1e-5);
0054 }
0055 }