Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.ml.regression;
0019 
import java.util.HashMap;
import java.util.Map;

import org.junit.Assert;
import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
0032 
0033 
0034 public class JavaDecisionTreeRegressorSuite extends SharedSparkSession {
0035 
0036   @Test
0037   public void runDT() {
0038     int nPoints = 20;
0039     double A = 2.0;
0040     double B = -1.5;
0041 
0042     JavaRDD<LabeledPoint> data = jsc.parallelize(
0043       LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
0044     Map<Integer, Integer> categoricalFeatures = new HashMap<>();
0045     Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
0046 
0047     // This tests setters. Training with various options is tested in Scala.
0048     DecisionTreeRegressor dt = new DecisionTreeRegressor()
0049       .setMaxDepth(2)
0050       .setMaxBins(10)
0051       .setMinInstancesPerNode(5)
0052       .setMinInfoGain(0.0)
0053       .setMaxMemoryInMB(256)
0054       .setCacheNodeIds(false)
0055       .setCheckpointInterval(10)
0056       .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
0057     for (String impurity : DecisionTreeRegressor.supportedImpurities()) {
0058       dt.setImpurity(impurity);
0059     }
0060     DecisionTreeRegressionModel model = dt.fit(dataFrame);
0061 
0062     model.transform(dataFrame);
0063     model.numNodes();
0064     model.depth();
0065     model.toDebugString();
0066 
0067     /*
0068     // TODO: Add test once save/load are implemented.   SPARK-6725
0069     File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
0070     String path = tempDir.toURI().toString();
0071     try {
0072       model2.save(sc.sc(), path);
0073       DecisionTreeRegressionModel sameModel = DecisionTreeRegressionModel.load(sc.sc(), path);
0074       TreeTests.checkEqual(model2, sameModel);
0075     } finally {
0076       Utils.deleteRecursively(tempDir);
0077     }
0078     */
0079   }
0080 }