0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.ml.tuning;
0019
0020 import java.io.IOException;
0021 import java.util.List;
0022
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.ml.classification.LogisticRegression;
0028 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
0029 import org.apache.spark.ml.feature.LabeledPoint;
0030 import org.apache.spark.ml.param.ParamMap;
0031 import org.apache.spark.sql.Dataset;
0032 import org.apache.spark.sql.Row;
0033 import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
0034
0035
0036 public class JavaCrossValidatorSuite extends SharedSparkSession {
0037
0038 private transient Dataset<Row> dataset;
0039
0040 @Override
0041 public void setUp() throws IOException {
0042 super.setUp();
0043 List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
0044 dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
0045 }
0046
0047 @Test
0048 public void crossValidationWithLogisticRegression() {
0049 LogisticRegression lr = new LogisticRegression();
0050 ParamMap[] lrParamMaps = new ParamGridBuilder()
0051 .addGrid(lr.regParam(), new double[]{0.001, 1000.0})
0052 .addGrid(lr.maxIter(), new int[]{0, 10})
0053 .build();
0054 BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
0055 CrossValidator cv = new CrossValidator()
0056 .setEstimator(lr)
0057 .setEstimatorParamMaps(lrParamMaps)
0058 .setEvaluator(eval)
0059 .setNumFolds(3);
0060 CrossValidatorModel cvModel = cv.fit(dataset);
0061 LogisticRegression parent = (LogisticRegression) cvModel.bestModel().parent();
0062 Assert.assertEquals(0.001, parent.getRegParam(), 0.0);
0063 Assert.assertEquals(10, parent.getMaxIter());
0064 }
0065 }