0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.ml;
0019
0020
0021 import org.apache.spark.ml.classification.LogisticRegression;
0022 import org.apache.spark.ml.classification.OneVsRest;
0023 import org.apache.spark.ml.classification.OneVsRestModel;
0024 import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
0025 import org.apache.spark.sql.Dataset;
0026 import org.apache.spark.sql.Row;
0027
0028 import org.apache.spark.sql.SparkSession;
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039 public class JavaOneVsRestExample {
0040 public static void main(String[] args) {
0041 SparkSession spark = SparkSession
0042 .builder()
0043 .appName("JavaOneVsRestExample")
0044 .getOrCreate();
0045
0046
0047
0048 Dataset<Row> inputData = spark.read().format("libsvm")
0049 .load("data/mllib/sample_multiclass_classification_data.txt");
0050
0051
0052 Dataset<Row>[] tmp = inputData.randomSplit(new double[]{0.8, 0.2});
0053 Dataset<Row> train = tmp[0];
0054 Dataset<Row> test = tmp[1];
0055
0056
0057 LogisticRegression classifier = new LogisticRegression()
0058 .setMaxIter(10)
0059 .setTol(1E-6)
0060 .setFitIntercept(true);
0061
0062
0063 OneVsRest ovr = new OneVsRest().setClassifier(classifier);
0064
0065
0066 OneVsRestModel ovrModel = ovr.fit(train);
0067
0068
0069 Dataset<Row> predictions = ovrModel.transform(test)
0070 .select("prediction", "label");
0071
0072
0073 MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
0074 .setMetricName("accuracy");
0075
0076
0077 double accuracy = evaluator.evaluate(predictions);
0078 System.out.println("Test Error = " + (1 - accuracy));
0079
0080
0081 spark.stop();
0082 }
0083
0084 }