0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.ml.regression;
0019
0020 import java.util.HashMap;
0021 import java.util.Map;
0022
0023 import org.junit.Test;
0024
0025 import org.apache.spark.SharedSparkSession;
0026 import org.apache.spark.api.java.JavaRDD;
0027 import org.apache.spark.ml.classification.LogisticRegressionSuite;
0028 import org.apache.spark.ml.feature.LabeledPoint;
0029 import org.apache.spark.ml.tree.impl.TreeTests;
0030 import org.apache.spark.sql.Dataset;
0031 import org.apache.spark.sql.Row;
0032
0033
0034 public class JavaGBTRegressorSuite extends SharedSparkSession {
0035
0036 @Test
0037 public void runDT() {
0038 int nPoints = 20;
0039 double A = 2.0;
0040 double B = -1.5;
0041
0042 JavaRDD<LabeledPoint> data = jsc.parallelize(
0043 LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
0044 Map<Integer, Integer> categoricalFeatures = new HashMap<>();
0045 Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
0046
0047 GBTRegressor rf = new GBTRegressor()
0048 .setMaxDepth(2)
0049 .setMaxBins(10)
0050 .setMinInstancesPerNode(5)
0051 .setMinInfoGain(0.0)
0052 .setMaxMemoryInMB(256)
0053 .setCacheNodeIds(false)
0054 .setCheckpointInterval(10)
0055 .setSubsamplingRate(1.0)
0056 .setSeed(1234)
0057 .setMaxIter(3)
0058 .setStepSize(0.1)
0059 .setMaxDepth(2);
0060 for (String lossType : GBTRegressor.supportedLossTypes()) {
0061 rf.setLossType(lossType);
0062 }
0063 GBTRegressionModel model = rf.fit(dataFrame);
0064
0065 model.transform(dataFrame);
0066 model.totalNumNodes();
0067 model.toDebugString();
0068 model.trees();
0069 model.treeWeights();
0070
0071
0072
0073
0074
0075
0076
0077
0078
0079
0080
0081
0082
0083 }
0084 }