Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.ml.regression;
0019 
0020 import java.util.HashMap;
0021 import java.util.Map;
0022 
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025 
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.api.java.JavaRDD;
0028 import org.apache.spark.ml.classification.LogisticRegressionSuite;
0029 import org.apache.spark.ml.feature.LabeledPoint;
0030 import org.apache.spark.ml.linalg.Vector;
0031 import org.apache.spark.ml.tree.impl.TreeTests;
0032 import org.apache.spark.sql.Dataset;
0033 import org.apache.spark.sql.Row;
0034 
0035 
0036 public class JavaRandomForestRegressorSuite extends SharedSparkSession {
0037 
0038   @Test
0039   public void runDT() {
0040     int nPoints = 20;
0041     double A = 2.0;
0042     double B = -1.5;
0043 
0044     JavaRDD<LabeledPoint> data = jsc.parallelize(
0045       LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
0046     Map<Integer, Integer> categoricalFeatures = new HashMap<>();
0047     Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
0048 
0049     // This tests setters. Training with various options is tested in Scala.
0050     RandomForestRegressor rf = new RandomForestRegressor()
0051       .setMaxDepth(2)
0052       .setMaxBins(10)
0053       .setMinInstancesPerNode(5)
0054       .setMinInfoGain(0.0)
0055       .setMaxMemoryInMB(256)
0056       .setCacheNodeIds(false)
0057       .setCheckpointInterval(10)
0058       .setSubsamplingRate(1.0)
0059       .setSeed(1234)
0060       .setNumTrees(3)
0061       .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
0062     for (String impurity : RandomForestRegressor.supportedImpurities()) {
0063       rf.setImpurity(impurity);
0064     }
0065     for (String featureSubsetStrategy : RandomForestRegressor.supportedFeatureSubsetStrategies()) {
0066       rf.setFeatureSubsetStrategy(featureSubsetStrategy);
0067     }
0068     String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
0069     for (String strategy : realStrategies) {
0070       rf.setFeatureSubsetStrategy(strategy);
0071     }
0072     String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
0073     for (String strategy : integerStrategies) {
0074       rf.setFeatureSubsetStrategy(strategy);
0075     }
0076     String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
0077     for (String strategy : invalidStrategies) {
0078       try {
0079         rf.setFeatureSubsetStrategy(strategy);
0080         Assert.fail("Expected exception to be thrown for invalid strategies");
0081       } catch (Exception e) {
0082         Assert.assertTrue(e instanceof IllegalArgumentException);
0083       }
0084     }
0085 
0086     RandomForestRegressionModel model = rf.fit(dataFrame);
0087 
0088     model.transform(dataFrame);
0089     model.totalNumNodes();
0090     model.toDebugString();
0091     model.trees();
0092     model.treeWeights();
0093     Vector importances = model.featureImportances();
0094 
0095     /*
0096     // TODO: Add test once save/load are implemented.   SPARK-6725
0097     File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
0098     String path = tempDir.toURI().toString();
0099     try {
0100       model2.save(sc.sc(), path);
0101       RandomForestRegressionModel sameModel = RandomForestRegressionModel.load(sc.sc(), path);
0102       TreeTests.checkEqual(model2, sameModel);
0103     } finally {
0104       Utils.deleteRecursively(tempDir);
0105     }
0106     */
0107   }
0108 }