Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.ml.classification;
0019 
0020 import java.util.HashMap;
0021 import java.util.Map;
0022 
0023 import org.junit.Test;
0024 
0025 import org.apache.spark.SharedSparkSession;
0026 import org.apache.spark.api.java.JavaRDD;
0027 import org.apache.spark.ml.feature.LabeledPoint;
0028 import org.apache.spark.ml.tree.impl.TreeTests;
0029 import org.apache.spark.sql.Dataset;
0030 import org.apache.spark.sql.Row;
0031 
0032 public class JavaDecisionTreeClassifierSuite extends SharedSparkSession {
0033 
0034   @Test
0035   public void runDT() {
0036     int nPoints = 20;
0037     double A = 2.0;
0038     double B = -1.5;
0039 
0040     JavaRDD<LabeledPoint> data = jsc.parallelize(
0041       LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
0042     Map<Integer, Integer> categoricalFeatures = new HashMap<>();
0043     Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
0044 
0045     // This tests setters. Training with various options is tested in Scala.
0046     DecisionTreeClassifier dt = new DecisionTreeClassifier()
0047       .setMaxDepth(2)
0048       .setMaxBins(10)
0049       .setMinInstancesPerNode(5)
0050       .setMinInfoGain(0.0)
0051       .setMaxMemoryInMB(256)
0052       .setCacheNodeIds(false)
0053       .setCheckpointInterval(10)
0054       .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
0055     for (String impurity : DecisionTreeClassifier.supportedImpurities()) {
0056       dt.setImpurity(impurity);
0057     }
0058     DecisionTreeClassificationModel model = dt.fit(dataFrame);
0059 
0060     model.transform(dataFrame);
0061     model.numNodes();
0062     model.depth();
0063     model.toDebugString();
0064 
0065     /*
0066     // TODO: Add test once save/load are implemented.  SPARK-6725
0067     File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
0068     String path = tempDir.toURI().toString();
0069     try {
0070       model3.save(sc.sc(), path);
0071       DecisionTreeClassificationModel sameModel =
0072         DecisionTreeClassificationModel.load(sc.sc(), path);
0073       TreeTests.checkEqual(model3, sameModel);
0074     } finally {
0075       Utils.deleteRecursively(tempDir);
0076     }
0077     */
0078   }
0079 }