Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.ml.classification;
0019 
0020 import java.util.HashMap;
0021 import java.util.Map;
0022 
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025 
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.api.java.JavaRDD;
0028 import org.apache.spark.ml.feature.LabeledPoint;
0029 import org.apache.spark.ml.linalg.Vector;
0030 import org.apache.spark.ml.tree.impl.TreeTests;
0031 import org.apache.spark.sql.Dataset;
0032 import org.apache.spark.sql.Row;
0033 
0034 public class JavaRandomForestClassifierSuite extends SharedSparkSession {
0035 
0036   @Test
0037   public void runDT() {
0038     int nPoints = 20;
0039     double A = 2.0;
0040     double B = -1.5;
0041 
0042     JavaRDD<LabeledPoint> data = jsc.parallelize(
0043       LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
0044     Map<Integer, Integer> categoricalFeatures = new HashMap<>();
0045     Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
0046 
0047     // This tests setters. Training with various options is tested in Scala.
0048     RandomForestClassifier rf = new RandomForestClassifier()
0049       .setMaxDepth(2)
0050       .setMaxBins(10)
0051       .setMinInstancesPerNode(5)
0052       .setMinInfoGain(0.0)
0053       .setMaxMemoryInMB(256)
0054       .setCacheNodeIds(false)
0055       .setCheckpointInterval(10)
0056       .setSubsamplingRate(1.0)
0057       .setSeed(1234)
0058       .setNumTrees(3)
0059       .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
0060     for (String impurity : RandomForestClassifier.supportedImpurities()) {
0061       rf.setImpurity(impurity);
0062     }
0063     for (String featureSubsetStrategy : RandomForestClassifier.supportedFeatureSubsetStrategies()) {
0064       rf.setFeatureSubsetStrategy(featureSubsetStrategy);
0065     }
0066     String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
0067     for (String strategy : realStrategies) {
0068       rf.setFeatureSubsetStrategy(strategy);
0069     }
0070     String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
0071     for (String strategy : integerStrategies) {
0072       rf.setFeatureSubsetStrategy(strategy);
0073     }
0074     String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
0075     for (String strategy : invalidStrategies) {
0076       try {
0077         rf.setFeatureSubsetStrategy(strategy);
0078         Assert.fail("Expected exception to be thrown for invalid strategies");
0079       } catch (Exception e) {
0080         Assert.assertTrue(e instanceof IllegalArgumentException);
0081       }
0082     }
0083 
0084     RandomForestClassificationModel model = rf.fit(dataFrame);
0085 
0086     model.transform(dataFrame);
0087     model.totalNumNodes();
0088     model.toDebugString();
0089     model.trees();
0090     model.treeWeights();
0091     Vector importances = model.featureImportances();
0092 
0093     /*
0094     // TODO: Add test once save/load are implemented.  SPARK-6725
0095     File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
0096     String path = tempDir.toURI().toString();
0097     try {
0098       model3.save(sc.sc(), path);
0099       RandomForestClassificationModel sameModel =
0100           RandomForestClassificationModel.load(sc.sc(), path);
0101       TreeTests.checkEqual(model3, sameModel);
0102     } finally {
0103       Utils.deleteRecursively(tempDir);
0104     }
0105     */
0106   }
0107 }