/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.classification;

import java.util.HashMap;
import java.util.Map;

import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

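/**
 * Java API test suite for {@link GBTClassifier}. It only checks that the Java-friendly
 * setters and model methods are callable; training behavior is covered by the Scala suites.
 */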
public class JavaGBTClassifierSuite extends SharedSparkSession {

  @Test
  public void runDT() {
    int nPoints = 20;
    double A = 2.0;
    double B = -1.5;

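    // Generate a small binary-labeled dataset and attach ML metadata marking two classes
    // and no categorical features.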
    JavaRDD<LabeledPoint> data = jsc.parallelize(
      LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
    Map<Integer, Integer> categoricalFeatures = new HashMap<>();
    Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);

    // This tests setters. Training with various options is tested in Scala.
    GBTClassifier gbt = new GBTClassifier()
      .setMaxDepth(2)
      .setMaxBins(10)
      .setMinInstancesPerNode(5)
      .setMinInfoGain(0.0)
      .setMaxMemoryInMB(256)
      .setCacheNodeIds(false)
      .setCheckpointInterval(10)
      .setSubsamplingRate(1.0)
      .setSeed(1234)
      .setMaxIter(3)
      .setStepSize(0.1);
    // Set each supported loss type in turn; the last value set is the one used by fit().
    for (String lossType : GBTClassifier.supportedLossTypes()) {
      gbt.setLossType(lossType);
    }
    GBTClassificationModel model = gbt.fit(dataFrame);

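    // Exercise the model's Java API; these calls only check that the methods can be
    // invoked from Java.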
    model.transform(dataFrame);
    model.totalNumNodes();
    model.toDebugString();
    model.trees();
    model.treeWeights();

    /*
    // TODO: Add test once save/load are implemented.  SPARK-6725
    File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
    String path = tempDir.toURI().toString();
    try {
      model.save(sc.sc(), path);
      GBTClassificationModel sameModel = GBTClassificationModel.load(sc.sc(), path);
      TreeTests.checkEqual(model, sameModel);
    } finally {
      Utils.deleteRecursively(tempDir);
    }
    */
  }
}