Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.ml.regression;
0019 
0020 import java.util.HashMap;
0021 import java.util.Map;
0022 
0023 import org.junit.Test;
0024 
0025 import org.apache.spark.SharedSparkSession;
0026 import org.apache.spark.api.java.JavaRDD;
0027 import org.apache.spark.ml.classification.LogisticRegressionSuite;
0028 import org.apache.spark.ml.feature.LabeledPoint;
0029 import org.apache.spark.ml.tree.impl.TreeTests;
0030 import org.apache.spark.sql.Dataset;
0031 import org.apache.spark.sql.Row;
0032 
0033 
0034 public class JavaGBTRegressorSuite extends SharedSparkSession {
0035 
0036   @Test
0037   public void runDT() {
0038     int nPoints = 20;
0039     double A = 2.0;
0040     double B = -1.5;
0041 
0042     JavaRDD<LabeledPoint> data = jsc.parallelize(
0043       LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
0044     Map<Integer, Integer> categoricalFeatures = new HashMap<>();
0045     Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
0046 
0047     GBTRegressor rf = new GBTRegressor()
0048       .setMaxDepth(2)
0049       .setMaxBins(10)
0050       .setMinInstancesPerNode(5)
0051       .setMinInfoGain(0.0)
0052       .setMaxMemoryInMB(256)
0053       .setCacheNodeIds(false)
0054       .setCheckpointInterval(10)
0055       .setSubsamplingRate(1.0)
0056       .setSeed(1234)
0057       .setMaxIter(3)
0058       .setStepSize(0.1)
0059       .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
0060     for (String lossType : GBTRegressor.supportedLossTypes()) {
0061       rf.setLossType(lossType);
0062     }
0063     GBTRegressionModel model = rf.fit(dataFrame);
0064 
0065     model.transform(dataFrame);
0066     model.totalNumNodes();
0067     model.toDebugString();
0068     model.trees();
0069     model.treeWeights();
0070 
0071     /*
0072     // TODO: Add test once save/load are implemented.  SPARK-6725
0073     File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
0074     String path = tempDir.toURI().toString();
0075     try {
0076       model2.save(sc.sc(), path);
0077       GBTRegressionModel sameModel = GBTRegressionModel.load(sc.sc(), path);
0078       TreeTests.checkEqual(model2, sameModel);
0079     } finally {
0080       Utils.deleteRecursively(tempDir);
0081     }
0082     */
0083   }
0084 }