0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.ml;
0019
0020
0021 import org.apache.spark.ml.classification.NaiveBayes;
0022 import org.apache.spark.ml.classification.NaiveBayesModel;
0023 import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
0024 import org.apache.spark.sql.Dataset;
0025 import org.apache.spark.sql.Row;
0026 import org.apache.spark.sql.SparkSession;
0027
0028
0029
0030
0031
0032 public class JavaNaiveBayesExample {
0033
0034 public static void main(String[] args) {
0035 SparkSession spark = SparkSession
0036 .builder()
0037 .appName("JavaNaiveBayesExample")
0038 .getOrCreate();
0039
0040
0041
0042 Dataset<Row> dataFrame =
0043 spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
0044
0045 Dataset<Row>[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
0046 Dataset<Row> train = splits[0];
0047 Dataset<Row> test = splits[1];
0048
0049
0050 NaiveBayes nb = new NaiveBayes();
0051
0052
0053 NaiveBayesModel model = nb.fit(train);
0054
0055
0056 Dataset<Row> predictions = model.transform(test);
0057 predictions.show();
0058
0059
0060 MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
0061 .setLabelCol("label")
0062 .setPredictionCol("prediction")
0063 .setMetricName("accuracy");
0064 double accuracy = evaluator.evaluate(predictions);
0065 System.out.println("Test set accuracy = " + accuracy);
0066
0067
0068 spark.stop();
0069 }
0070 }