0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.ml.regression;
0019
0020 import java.util.HashMap;
0021 import java.util.Map;
0022
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.api.java.JavaRDD;
0028 import org.apache.spark.ml.classification.LogisticRegressionSuite;
0029 import org.apache.spark.ml.feature.LabeledPoint;
0030 import org.apache.spark.ml.linalg.Vector;
0031 import org.apache.spark.ml.tree.impl.TreeTests;
0032 import org.apache.spark.sql.Dataset;
0033 import org.apache.spark.sql.Row;
0034
0035
0036 public class JavaRandomForestRegressorSuite extends SharedSparkSession {
0037
0038 @Test
0039 public void runDT() {
0040 int nPoints = 20;
0041 double A = 2.0;
0042 double B = -1.5;
0043
0044 JavaRDD<LabeledPoint> data = jsc.parallelize(
0045 LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
0046 Map<Integer, Integer> categoricalFeatures = new HashMap<>();
0047 Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
0048
0049
0050 RandomForestRegressor rf = new RandomForestRegressor()
0051 .setMaxDepth(2)
0052 .setMaxBins(10)
0053 .setMinInstancesPerNode(5)
0054 .setMinInfoGain(0.0)
0055 .setMaxMemoryInMB(256)
0056 .setCacheNodeIds(false)
0057 .setCheckpointInterval(10)
0058 .setSubsamplingRate(1.0)
0059 .setSeed(1234)
0060 .setNumTrees(3)
0061 .setMaxDepth(2);
0062 for (String impurity : RandomForestRegressor.supportedImpurities()) {
0063 rf.setImpurity(impurity);
0064 }
0065 for (String featureSubsetStrategy : RandomForestRegressor.supportedFeatureSubsetStrategies()) {
0066 rf.setFeatureSubsetStrategy(featureSubsetStrategy);
0067 }
0068 String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
0069 for (String strategy : realStrategies) {
0070 rf.setFeatureSubsetStrategy(strategy);
0071 }
0072 String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
0073 for (String strategy : integerStrategies) {
0074 rf.setFeatureSubsetStrategy(strategy);
0075 }
0076 String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
0077 for (String strategy : invalidStrategies) {
0078 try {
0079 rf.setFeatureSubsetStrategy(strategy);
0080 Assert.fail("Expected exception to be thrown for invalid strategies");
0081 } catch (Exception e) {
0082 Assert.assertTrue(e instanceof IllegalArgumentException);
0083 }
0084 }
0085
0086 RandomForestRegressionModel model = rf.fit(dataFrame);
0087
0088 model.transform(dataFrame);
0089 model.totalNumNodes();
0090 model.toDebugString();
0091 model.trees();
0092 model.treeWeights();
0093 Vector importances = model.featureImportances();
0094
0095
0096
0097
0098
0099
0100
0101
0102
0103
0104
0105
0106
0107 }
0108 }