0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.classification;
0019
0020 import java.util.Arrays;
0021 import java.util.List;
0022
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.api.java.JavaRDD;
0028 import org.apache.spark.mllib.linalg.Vector;
0029 import org.apache.spark.mllib.linalg.Vectors;
0030 import org.apache.spark.mllib.regression.LabeledPoint;
0031
0032
0033 public class JavaNaiveBayesSuite extends SharedSparkSession {
0034
0035 private static final List<LabeledPoint> POINTS = Arrays.asList(
0036 new LabeledPoint(0, Vectors.dense(1.0, 0.0, 0.0)),
0037 new LabeledPoint(0, Vectors.dense(2.0, 0.0, 0.0)),
0038 new LabeledPoint(1, Vectors.dense(0.0, 1.0, 0.0)),
0039 new LabeledPoint(1, Vectors.dense(0.0, 2.0, 0.0)),
0040 new LabeledPoint(2, Vectors.dense(0.0, 0.0, 1.0)),
0041 new LabeledPoint(2, Vectors.dense(0.0, 0.0, 2.0))
0042 );
0043
0044 private static int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
0045 int correct = 0;
0046 for (LabeledPoint p : points) {
0047 if (model.predict(p.features()) == p.label()) {
0048 correct += 1;
0049 }
0050 }
0051 return correct;
0052 }
0053
0054 @Test
0055 public void runUsingConstructor() {
0056 JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
0057
0058 NaiveBayes nb = new NaiveBayes().setLambda(1.0);
0059 NaiveBayesModel model = nb.run(testRDD.rdd());
0060
0061 int numAccurate = validatePrediction(POINTS, model);
0062 Assert.assertEquals(POINTS.size(), numAccurate);
0063 }
0064
0065 @Test
0066 public void runUsingStaticMethods() {
0067 JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
0068
0069 NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
0070 int numAccurate1 = validatePrediction(POINTS, model1);
0071 Assert.assertEquals(POINTS.size(), numAccurate1);
0072
0073 NaiveBayesModel model2 = NaiveBayes.train(testRDD.rdd(), 0.5);
0074 int numAccurate2 = validatePrediction(POINTS, model2);
0075 Assert.assertEquals(POINTS.size(), numAccurate2);
0076 }
0077
0078 @Test
0079 public void testPredictJavaRDD() {
0080 JavaRDD<LabeledPoint> examples = jsc.parallelize(POINTS, 2).cache();
0081 NaiveBayesModel model = NaiveBayes.train(examples.rdd());
0082 JavaRDD<Vector> vectors = examples.map(LabeledPoint::features);
0083 JavaRDD<Double> predictions = model.predict(vectors);
0084
0085 predictions.first();
0086 }
0087
0088 @Test
0089 public void testModelTypeSetters() {
0090 NaiveBayes nb = new NaiveBayes()
0091 .setModelType("bernoulli")
0092 .setModelType("multinomial");
0093 }
0094 }