Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.mllib.tree;
0019 
0020 import java.util.HashMap;
0021 import java.util.List;
0022 
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025 
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.api.java.JavaRDD;
0028 import org.apache.spark.mllib.regression.LabeledPoint;
0029 import org.apache.spark.mllib.tree.configuration.Algo;
0030 import org.apache.spark.mllib.tree.configuration.Strategy;
0031 import org.apache.spark.mllib.tree.impurity.Gini;
0032 import org.apache.spark.mllib.tree.model.DecisionTreeModel;
0033 
0034 public class JavaDecisionTreeSuite extends SharedSparkSession {
0035 
0036   private static int validatePrediction(
0037       List<LabeledPoint> validationData, DecisionTreeModel model) {
0038     int numCorrect = 0;
0039     for (LabeledPoint point : validationData) {
0040       Double prediction = model.predict(point.features());
0041       if (prediction == point.label()) {
0042         numCorrect++;
0043       }
0044     }
0045     return numCorrect;
0046   }
0047 
0048   @Test
0049   public void runDTUsingConstructor() {
0050     List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
0051     JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
0052     HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
0053     categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
0054 
0055     int maxDepth = 4;
0056     int numClasses = 2;
0057     int maxBins = 100;
0058     Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
0059       maxBins, categoricalFeaturesInfo);
0060 
0061     DecisionTree learner = new DecisionTree(strategy);
0062     DecisionTreeModel model = learner.run(rdd.rdd());
0063 
0064     int numCorrect = validatePrediction(arr, model);
0065     Assert.assertEquals(numCorrect, rdd.count());
0066   }
0067 
0068   @Test
0069   public void runDTUsingStaticMethods() {
0070     List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
0071     JavaRDD<LabeledPoint> rdd = jsc.parallelize(arr);
0072     HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
0073     categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
0074 
0075     int maxDepth = 4;
0076     int numClasses = 2;
0077     int maxBins = 100;
0078     Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
0079       maxBins, categoricalFeaturesInfo);
0080 
0081     DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
0082 
0083     // java compatibility test
0084     JavaRDD<Double> predictions = model.predict(rdd.map(LabeledPoint::features));
0085 
0086     int numCorrect = validatePrediction(arr, model);
0087     Assert.assertEquals(numCorrect, rdd.count());
0088   }
0089 
0090 }