0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.ml.classification;
0019
0020 import java.util.Arrays;
0021 import java.util.List;
0022
0023 import org.junit.Test;
0024 import static org.junit.Assert.assertEquals;
0025
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.ml.linalg.VectorUDT;
0028 import org.apache.spark.ml.linalg.Vectors;
0029 import org.apache.spark.sql.Dataset;
0030 import org.apache.spark.sql.Row;
0031 import org.apache.spark.sql.RowFactory;
0032 import org.apache.spark.sql.types.DataTypes;
0033 import org.apache.spark.sql.types.Metadata;
0034 import org.apache.spark.sql.types.StructField;
0035 import org.apache.spark.sql.types.StructType;
0036
0037 public class JavaNaiveBayesSuite extends SharedSparkSession {
0038
0039 public void validatePrediction(Dataset<Row> predictionAndLabels) {
0040 for (Row r : predictionAndLabels.collectAsList()) {
0041 double prediction = r.getAs(0);
0042 double label = r.getAs(1);
0043 assertEquals(label, prediction, 1E-5);
0044 }
0045 }
0046
0047 @Test
0048 public void naiveBayesDefaultParams() {
0049 NaiveBayes nb = new NaiveBayes();
0050 assertEquals("label", nb.getLabelCol());
0051 assertEquals("features", nb.getFeaturesCol());
0052 assertEquals("prediction", nb.getPredictionCol());
0053 assertEquals(1.0, nb.getSmoothing(), 1E-5);
0054 assertEquals("multinomial", nb.getModelType());
0055 }
0056
0057 @Test
0058 public void testNaiveBayes() {
0059 List<Row> data = Arrays.asList(
0060 RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)),
0061 RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)),
0062 RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)),
0063 RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)),
0064 RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)),
0065 RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0)));
0066
0067 StructType schema = new StructType(new StructField[]{
0068 new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
0069 new StructField("features", new VectorUDT(), false, Metadata.empty())
0070 });
0071
0072 Dataset<Row> dataset = spark.createDataFrame(data, schema);
0073 NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
0074 NaiveBayesModel model = nb.fit(dataset);
0075
0076 Dataset<Row> predictionAndLabels = model.transform(dataset).select("prediction", "label");
0077 validatePrediction(predictionAndLabels);
0078 }
0079 }