Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.mllib.classification;
0019 
0020 import java.util.Arrays;
0021 import java.util.List;
0022 
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025 
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.api.java.JavaRDD;
0028 import org.apache.spark.mllib.linalg.Vector;
0029 import org.apache.spark.mllib.linalg.Vectors;
0030 import org.apache.spark.mllib.regression.LabeledPoint;
0031 
0032 
0033 public class JavaNaiveBayesSuite extends SharedSparkSession {
0034 
0035   private static final List<LabeledPoint> POINTS = Arrays.asList(
0036     new LabeledPoint(0, Vectors.dense(1.0, 0.0, 0.0)),
0037     new LabeledPoint(0, Vectors.dense(2.0, 0.0, 0.0)),
0038     new LabeledPoint(1, Vectors.dense(0.0, 1.0, 0.0)),
0039     new LabeledPoint(1, Vectors.dense(0.0, 2.0, 0.0)),
0040     new LabeledPoint(2, Vectors.dense(0.0, 0.0, 1.0)),
0041     new LabeledPoint(2, Vectors.dense(0.0, 0.0, 2.0))
0042   );
0043 
0044   private static int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
0045     int correct = 0;
0046     for (LabeledPoint p : points) {
0047       if (model.predict(p.features()) == p.label()) {
0048         correct += 1;
0049       }
0050     }
0051     return correct;
0052   }
0053 
0054   @Test
0055   public void runUsingConstructor() {
0056     JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
0057 
0058     NaiveBayes nb = new NaiveBayes().setLambda(1.0);
0059     NaiveBayesModel model = nb.run(testRDD.rdd());
0060 
0061     int numAccurate = validatePrediction(POINTS, model);
0062     Assert.assertEquals(POINTS.size(), numAccurate);
0063   }
0064 
0065   @Test
0066   public void runUsingStaticMethods() {
0067     JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
0068 
0069     NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
0070     int numAccurate1 = validatePrediction(POINTS, model1);
0071     Assert.assertEquals(POINTS.size(), numAccurate1);
0072 
0073     NaiveBayesModel model2 = NaiveBayes.train(testRDD.rdd(), 0.5);
0074     int numAccurate2 = validatePrediction(POINTS, model2);
0075     Assert.assertEquals(POINTS.size(), numAccurate2);
0076   }
0077 
0078   @Test
0079   public void testPredictJavaRDD() {
0080     JavaRDD<LabeledPoint> examples = jsc.parallelize(POINTS, 2).cache();
0081     NaiveBayesModel model = NaiveBayes.train(examples.rdd());
0082     JavaRDD<Vector> vectors = examples.map(LabeledPoint::features);
0083     JavaRDD<Double> predictions = model.predict(vectors);
0084     // Should be able to get the first prediction.
0085     predictions.first();
0086   }
0087 
0088   @Test
0089   public void testModelTypeSetters() {
0090     NaiveBayes nb = new NaiveBayes()
0091       .setModelType("bernoulli")
0092       .setModelType("multinomial");
0093   }
0094 }