0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.ml;
0019
0020
0021 import org.apache.spark.ml.Pipeline;
0022 import org.apache.spark.ml.PipelineModel;
0023 import org.apache.spark.ml.PipelineStage;
0024 import org.apache.spark.ml.evaluation.RegressionEvaluator;
0025 import org.apache.spark.ml.feature.VectorIndexer;
0026 import org.apache.spark.ml.feature.VectorIndexerModel;
0027 import org.apache.spark.ml.regression.RandomForestRegressionModel;
0028 import org.apache.spark.ml.regression.RandomForestRegressor;
0029 import org.apache.spark.sql.Dataset;
0030 import org.apache.spark.sql.Row;
0031 import org.apache.spark.sql.SparkSession;
0032
0033
0034 public class JavaRandomForestRegressorExample {
0035 public static void main(String[] args) {
0036 SparkSession spark = SparkSession
0037 .builder()
0038 .appName("JavaRandomForestRegressorExample")
0039 .getOrCreate();
0040
0041
0042
0043 Dataset<Row> data = spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
0044
0045
0046
0047 VectorIndexerModel featureIndexer = new VectorIndexer()
0048 .setInputCol("features")
0049 .setOutputCol("indexedFeatures")
0050 .setMaxCategories(4)
0051 .fit(data);
0052
0053
0054 Dataset<Row>[] splits = data.randomSplit(new double[] {0.7, 0.3});
0055 Dataset<Row> trainingData = splits[0];
0056 Dataset<Row> testData = splits[1];
0057
0058
0059 RandomForestRegressor rf = new RandomForestRegressor()
0060 .setLabelCol("label")
0061 .setFeaturesCol("indexedFeatures");
0062
0063
0064 Pipeline pipeline = new Pipeline()
0065 .setStages(new PipelineStage[] {featureIndexer, rf});
0066
0067
0068 PipelineModel model = pipeline.fit(trainingData);
0069
0070
0071 Dataset<Row> predictions = model.transform(testData);
0072
0073
0074 predictions.select("prediction", "label", "features").show(5);
0075
0076
0077 RegressionEvaluator evaluator = new RegressionEvaluator()
0078 .setLabelCol("label")
0079 .setPredictionCol("prediction")
0080 .setMetricName("rmse");
0081 double rmse = evaluator.evaluate(predictions);
0082 System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);
0083
0084 RandomForestRegressionModel rfModel = (RandomForestRegressionModel)(model.stages()[1]);
0085 System.out.println("Learned regression forest model:\n" + rfModel.toDebugString());
0086
0087
0088 spark.stop();
0089 }
0090 }