0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.ml.classification;
0019
0020 import java.io.IOException;
0021 import java.util.List;
0022
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.api.java.JavaRDD;
0028 import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
0029 import org.apache.spark.ml.feature.LabeledPoint;
0030 import org.apache.spark.ml.linalg.Vector;
0031 import org.apache.spark.sql.Dataset;
0032 import org.apache.spark.sql.Row;
0033
0034 public class JavaLogisticRegressionSuite extends SharedSparkSession {
0035
0036 private transient Dataset<Row> dataset;
0037
0038 private transient JavaRDD<LabeledPoint> datasetRDD;
0039 private double eps = 1e-5;
0040
0041 @Override
0042 public void setUp() throws IOException {
0043 super.setUp();
0044 List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
0045 datasetRDD = jsc.parallelize(points, 2);
0046 dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
0047 dataset.createOrReplaceTempView("dataset");
0048 }
0049
0050 @Test
0051 public void logisticRegressionDefaultParams() {
0052 LogisticRegression lr = new LogisticRegression();
0053 Assert.assertEquals("label", lr.getLabelCol());
0054 LogisticRegressionModel model = lr.fit(dataset);
0055 model.transform(dataset).createOrReplaceTempView("prediction");
0056 Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction");
0057 predictions.collectAsList();
0058
0059 Assert.assertEquals(0.5, model.getThreshold(), eps);
0060 Assert.assertEquals("features", model.getFeaturesCol());
0061 Assert.assertEquals("prediction", model.getPredictionCol());
0062 Assert.assertEquals("probability", model.getProbabilityCol());
0063 }
0064
0065 @Test
0066 public void logisticRegressionWithSetters() {
0067
0068 LogisticRegression lr = new LogisticRegression()
0069 .setMaxIter(10)
0070 .setRegParam(1.0)
0071 .setThreshold(0.6)
0072 .setProbabilityCol("myProbability");
0073 LogisticRegressionModel model = lr.fit(dataset);
0074 LogisticRegression parent = (LogisticRegression) model.parent();
0075 Assert.assertEquals(10, parent.getMaxIter());
0076 Assert.assertEquals(1.0, parent.getRegParam(), eps);
0077 Assert.assertEquals(0.4, parent.getThresholds()[0], eps);
0078 Assert.assertEquals(0.6, parent.getThresholds()[1], eps);
0079 Assert.assertEquals(0.6, parent.getThreshold(), eps);
0080 Assert.assertEquals(0.6, model.getThreshold(), eps);
0081
0082
0083 model.setThreshold(1.0);
0084 model.transform(dataset).createOrReplaceTempView("predAllZero");
0085 Dataset<Row> predAllZero = spark.sql("SELECT prediction, myProbability FROM predAllZero");
0086 for (Row r : predAllZero.collectAsList()) {
0087 Assert.assertEquals(0.0, r.getDouble(0), eps);
0088 }
0089
0090 model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
0091 .createOrReplaceTempView("predNotAllZero");
0092 Dataset<Row> predNotAllZero = spark.sql("SELECT prediction, myProb FROM predNotAllZero");
0093 boolean foundNonZero = false;
0094 for (Row r : predNotAllZero.collectAsList()) {
0095 if (r.getDouble(0) != 0.0) foundNonZero = true;
0096 }
0097 Assert.assertTrue(foundNonZero);
0098
0099
0100 LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
0101 lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
0102 LogisticRegression parent2 = (LogisticRegression) model2.parent();
0103 Assert.assertEquals(5, parent2.getMaxIter());
0104 Assert.assertEquals(0.1, parent2.getRegParam(), eps);
0105 Assert.assertEquals(0.4, parent2.getThreshold(), eps);
0106 Assert.assertEquals(0.4, model2.getThreshold(), eps);
0107 Assert.assertEquals("theProb", model2.getProbabilityCol());
0108 }
0109
0110 @SuppressWarnings("unchecked")
0111 @Test
0112 public void logisticRegressionPredictorClassifierMethods() {
0113 LogisticRegression lr = new LogisticRegression();
0114 LogisticRegressionModel model = lr.fit(dataset);
0115 Assert.assertEquals(2, model.numClasses());
0116
0117 model.transform(dataset).createOrReplaceTempView("transformed");
0118 Dataset<Row> trans1 = spark.sql("SELECT rawPrediction, probability FROM transformed");
0119 for (Row row : trans1.collectAsList()) {
0120 Vector raw = (Vector) row.get(0);
0121 Vector prob = (Vector) row.get(1);
0122 Assert.assertEquals(2, raw.size());
0123 Assert.assertEquals(2, prob.size());
0124 double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
0125 Assert.assertEquals(0, Math.abs(prob.apply(1) - probFromRaw1), eps);
0126 Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
0127 }
0128
0129 Dataset<Row> trans2 = spark.sql("SELECT prediction, probability FROM transformed");
0130 for (Row row : trans2.collectAsList()) {
0131 double pred = row.getDouble(0);
0132 Vector prob = (Vector) row.get(1);
0133 double probOfPred = prob.apply((int) pred);
0134 for (int i = 0; i < prob.size(); ++i) {
0135 Assert.assertTrue(probOfPred >= prob.apply(i));
0136 }
0137 }
0138 }
0139
0140 @Test
0141 public void logisticRegressionTrainingSummary() {
0142 LogisticRegression lr = new LogisticRegression();
0143 LogisticRegressionModel model = lr.fit(dataset);
0144
0145 LogisticRegressionTrainingSummary summary = model.summary();
0146 Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length);
0147 }
0148 }