Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.ml.classification;
0019 
0020 import java.io.IOException;
0021 import java.util.List;
0022 
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025 
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.api.java.JavaRDD;
0028 import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
0029 import org.apache.spark.ml.feature.LabeledPoint;
0030 import org.apache.spark.ml.linalg.Vector;
0031 import org.apache.spark.sql.Dataset;
0032 import org.apache.spark.sql.Row;
0033 
0034 public class JavaLogisticRegressionSuite extends SharedSparkSession {
0035 
0036   private transient Dataset<Row> dataset;
0037 
0038   private transient JavaRDD<LabeledPoint> datasetRDD;
0039   private double eps = 1e-5;
0040 
0041   @Override
0042   public void setUp() throws IOException {
0043     super.setUp();
0044     List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
0045     datasetRDD = jsc.parallelize(points, 2);
0046     dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
0047     dataset.createOrReplaceTempView("dataset");
0048   }
0049 
0050   @Test
0051   public void logisticRegressionDefaultParams() {
0052     LogisticRegression lr = new LogisticRegression();
0053     Assert.assertEquals("label", lr.getLabelCol());
0054     LogisticRegressionModel model = lr.fit(dataset);
0055     model.transform(dataset).createOrReplaceTempView("prediction");
0056     Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction");
0057     predictions.collectAsList();
0058     // Check defaults
0059     Assert.assertEquals(0.5, model.getThreshold(), eps);
0060     Assert.assertEquals("features", model.getFeaturesCol());
0061     Assert.assertEquals("prediction", model.getPredictionCol());
0062     Assert.assertEquals("probability", model.getProbabilityCol());
0063   }
0064 
0065   @Test
0066   public void logisticRegressionWithSetters() {
0067     // Set params, train, and check as many params as we can.
0068     LogisticRegression lr = new LogisticRegression()
0069       .setMaxIter(10)
0070       .setRegParam(1.0)
0071       .setThreshold(0.6)
0072       .setProbabilityCol("myProbability");
0073     LogisticRegressionModel model = lr.fit(dataset);
0074     LogisticRegression parent = (LogisticRegression) model.parent();
0075     Assert.assertEquals(10, parent.getMaxIter());
0076     Assert.assertEquals(1.0, parent.getRegParam(), eps);
0077     Assert.assertEquals(0.4, parent.getThresholds()[0], eps);
0078     Assert.assertEquals(0.6, parent.getThresholds()[1], eps);
0079     Assert.assertEquals(0.6, parent.getThreshold(), eps);
0080     Assert.assertEquals(0.6, model.getThreshold(), eps);
0081 
0082     // Modify model params, and check that the params worked.
0083     model.setThreshold(1.0);
0084     model.transform(dataset).createOrReplaceTempView("predAllZero");
0085     Dataset<Row> predAllZero = spark.sql("SELECT prediction, myProbability FROM predAllZero");
0086     for (Row r : predAllZero.collectAsList()) {
0087       Assert.assertEquals(0.0, r.getDouble(0), eps);
0088     }
0089     // Call transform with params, and check that the params worked.
0090     model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
0091       .createOrReplaceTempView("predNotAllZero");
0092     Dataset<Row> predNotAllZero = spark.sql("SELECT prediction, myProb FROM predNotAllZero");
0093     boolean foundNonZero = false;
0094     for (Row r : predNotAllZero.collectAsList()) {
0095       if (r.getDouble(0) != 0.0) foundNonZero = true;
0096     }
0097     Assert.assertTrue(foundNonZero);
0098 
0099     // Call fit() with new params, and check as many params as we can.
0100     LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
0101       lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
0102     LogisticRegression parent2 = (LogisticRegression) model2.parent();
0103     Assert.assertEquals(5, parent2.getMaxIter());
0104     Assert.assertEquals(0.1, parent2.getRegParam(), eps);
0105     Assert.assertEquals(0.4, parent2.getThreshold(), eps);
0106     Assert.assertEquals(0.4, model2.getThreshold(), eps);
0107     Assert.assertEquals("theProb", model2.getProbabilityCol());
0108   }
0109 
0110   @SuppressWarnings("unchecked")
0111   @Test
0112   public void logisticRegressionPredictorClassifierMethods() {
0113     LogisticRegression lr = new LogisticRegression();
0114     LogisticRegressionModel model = lr.fit(dataset);
0115     Assert.assertEquals(2, model.numClasses());
0116 
0117     model.transform(dataset).createOrReplaceTempView("transformed");
0118     Dataset<Row> trans1 = spark.sql("SELECT rawPrediction, probability FROM transformed");
0119     for (Row row : trans1.collectAsList()) {
0120       Vector raw = (Vector) row.get(0);
0121       Vector prob = (Vector) row.get(1);
0122       Assert.assertEquals(2, raw.size());
0123       Assert.assertEquals(2, prob.size());
0124       double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
0125       Assert.assertEquals(0, Math.abs(prob.apply(1) - probFromRaw1), eps);
0126       Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
0127     }
0128 
0129     Dataset<Row> trans2 = spark.sql("SELECT prediction, probability FROM transformed");
0130     for (Row row : trans2.collectAsList()) {
0131       double pred = row.getDouble(0);
0132       Vector prob = (Vector) row.get(1);
0133       double probOfPred = prob.apply((int) pred);
0134       for (int i = 0; i < prob.size(); ++i) {
0135         Assert.assertTrue(probOfPred >= prob.apply(i));
0136       }
0137     }
0138   }
0139 
0140   @Test
0141   public void logisticRegressionTrainingSummary() {
0142     LogisticRegression lr = new LogisticRegression();
0143     LogisticRegressionModel model = lr.fit(dataset);
0144 
0145     LogisticRegressionTrainingSummary summary = model.summary();
0146     Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length);
0147   }
0148 }