0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.tree;
0019
0020 import java.util.HashMap;
0021 import java.util.List;
0022
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.api.java.JavaRDD;
0028 import org.apache.spark.mllib.regression.LabeledPoint;
0029 import org.apache.spark.mllib.tree.configuration.Algo;
0030 import org.apache.spark.mllib.tree.configuration.Strategy;
0031 import org.apache.spark.mllib.tree.impurity.Gini;
0032 import org.apache.spark.mllib.tree.model.DecisionTreeModel;
0033
0034 public class JavaDecisionTreeSuite extends SharedSparkSession {
0035
0036 private static int validatePrediction(
0037 List<LabeledPoint> validationData, DecisionTreeModel model) {
0038 int numCorrect = 0;
0039 for (LabeledPoint point : validationData) {
0040 Double prediction = model.predict(point.features());
0041 if (prediction == point.label()) {
0042 numCorrect++;
0043 }
0044 }
0045 return numCorrect;
0046 }
0047
0048 @Test
0049 public void runDTUsingConstructor() {
0050 List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
0051 JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
0052 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
0053 categoricalFeaturesInfo.put(1, 2);
0054
0055 int maxDepth = 4;
0056 int numClasses = 2;
0057 int maxBins = 100;
0058 Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
0059 maxBins, categoricalFeaturesInfo);
0060
0061 DecisionTree learner = new DecisionTree(strategy);
0062 DecisionTreeModel model = learner.run(rdd.rdd());
0063
0064 int numCorrect = validatePrediction(arr, model);
0065 Assert.assertEquals(numCorrect, rdd.count());
0066 }
0067
0068 @Test
0069 public void runDTUsingStaticMethods() {
0070 List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
0071 JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
0072 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
0073 categoricalFeaturesInfo.put(1, 2);
0074
0075 int maxDepth = 4;
0076 int numClasses = 2;
0077 int maxBins = 100;
0078 Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
0079 maxBins, categoricalFeaturesInfo);
0080
0081 DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
0082
0083
0084 JavaRDD<Double> predictions = model.predict(rdd.map(LabeledPoint::features));
0085
0086 int numCorrect = validatePrediction(arr, model);
0087 Assert.assertEquals(numCorrect, rdd.count());
0088 }
0089
0090 }