0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.ml;
0019
0020
0021 import org.apache.spark.sql.Dataset;
0022 import org.apache.spark.sql.Row;
0023 import org.apache.spark.sql.SparkSession;
0024 import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
0025 import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
0026 import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
0027
0028
0029
0030
0031
0032 public class JavaMultilayerPerceptronClassifierExample {
0033
0034 public static void main(String[] args) {
0035 SparkSession spark = SparkSession
0036 .builder()
0037 .appName("JavaMultilayerPerceptronClassifierExample")
0038 .getOrCreate();
0039
0040
0041
0042 String path = "data/mllib/sample_multiclass_classification_data.txt";
0043 Dataset<Row> dataFrame = spark.read().format("libsvm").load(path);
0044
0045
0046 Dataset<Row>[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
0047 Dataset<Row> train = splits[0];
0048 Dataset<Row> test = splits[1];
0049
0050
0051
0052
0053 int[] layers = new int[] {4, 5, 4, 3};
0054
0055
0056 MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier()
0057 .setLayers(layers)
0058 .setBlockSize(128)
0059 .setSeed(1234L)
0060 .setMaxIter(100);
0061
0062
0063 MultilayerPerceptronClassificationModel model = trainer.fit(train);
0064
0065
0066 Dataset<Row> result = model.transform(test);
0067 Dataset<Row> predictionAndLabels = result.select("prediction", "label");
0068 MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
0069 .setMetricName("accuracy");
0070
0071 System.out.println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels));
0072
0073
0074 spark.stop();
0075 }
0076 }