0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.ml;
0019
0020 import org.apache.spark.ml.Pipeline;
0021 import org.apache.spark.ml.PipelineModel;
0022 import org.apache.spark.ml.PipelineStage;
0023 import org.apache.spark.ml.classification.DecisionTreeClassifier;
0024 import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
0025 import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
0026 import org.apache.spark.ml.feature.*;
0027 import org.apache.spark.sql.Dataset;
0028 import org.apache.spark.sql.Row;
0029 import org.apache.spark.sql.SparkSession;
0030
0031
0032 public class JavaDecisionTreeClassificationExample {
0033 public static void main(String[] args) {
0034 SparkSession spark = SparkSession
0035 .builder()
0036 .appName("JavaDecisionTreeClassificationExample")
0037 .getOrCreate();
0038
0039
0040
0041 Dataset<Row> data = spark
0042 .read()
0043 .format("libsvm")
0044 .load("data/mllib/sample_libsvm_data.txt");
0045
0046
0047
0048 StringIndexerModel labelIndexer = new StringIndexer()
0049 .setInputCol("label")
0050 .setOutputCol("indexedLabel")
0051 .fit(data);
0052
0053
0054 VectorIndexerModel featureIndexer = new VectorIndexer()
0055 .setInputCol("features")
0056 .setOutputCol("indexedFeatures")
0057 .setMaxCategories(4)
0058 .fit(data);
0059
0060
0061 Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
0062 Dataset<Row> trainingData = splits[0];
0063 Dataset<Row> testData = splits[1];
0064
0065
0066 DecisionTreeClassifier dt = new DecisionTreeClassifier()
0067 .setLabelCol("indexedLabel")
0068 .setFeaturesCol("indexedFeatures");
0069
0070
0071 IndexToString labelConverter = new IndexToString()
0072 .setInputCol("prediction")
0073 .setOutputCol("predictedLabel")
0074 .setLabels(labelIndexer.labelsArray()[0]);
0075
0076
0077 Pipeline pipeline = new Pipeline()
0078 .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter});
0079
0080
0081 PipelineModel model = pipeline.fit(trainingData);
0082
0083
0084 Dataset<Row> predictions = model.transform(testData);
0085
0086
0087 predictions.select("predictedLabel", "label", "features").show(5);
0088
0089
0090 MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
0091 .setLabelCol("indexedLabel")
0092 .setPredictionCol("prediction")
0093 .setMetricName("accuracy");
0094 double accuracy = evaluator.evaluate(predictions);
0095 System.out.println("Test Error = " + (1.0 - accuracy));
0096
0097 DecisionTreeClassificationModel treeModel =
0098 (DecisionTreeClassificationModel) (model.stages()[2]);
0099 System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());
0100
0101
0102 spark.stop();
0103 }
0104 }