0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.ml;
0019
0020
0021 import org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary;
0022 import org.apache.spark.ml.classification.LogisticRegression;
0023 import org.apache.spark.ml.classification.LogisticRegressionModel;
0024 import org.apache.spark.sql.Dataset;
0025 import org.apache.spark.sql.Row;
0026 import org.apache.spark.sql.SparkSession;
0027 import org.apache.spark.sql.functions;
0028
0029
0030 public class JavaLogisticRegressionSummaryExample {
0031 public static void main(String[] args) {
0032 SparkSession spark = SparkSession
0033 .builder()
0034 .appName("JavaLogisticRegressionSummaryExample")
0035 .getOrCreate();
0036
0037
0038 Dataset<Row> training = spark.read().format("libsvm")
0039 .load("data/mllib/sample_libsvm_data.txt");
0040
0041 LogisticRegression lr = new LogisticRegression()
0042 .setMaxIter(10)
0043 .setRegParam(0.3)
0044 .setElasticNetParam(0.8);
0045
0046
0047 LogisticRegressionModel lrModel = lr.fit(training);
0048
0049
0050
0051
0052 BinaryLogisticRegressionTrainingSummary trainingSummary = lrModel.binarySummary();
0053
0054
0055 double[] objectiveHistory = trainingSummary.objectiveHistory();
0056 for (double lossPerIteration : objectiveHistory) {
0057 System.out.println(lossPerIteration);
0058 }
0059
0060
0061 Dataset<Row> roc = trainingSummary.roc();
0062 roc.show();
0063 roc.select("FPR").show();
0064 System.out.println(trainingSummary.areaUnderROC());
0065
0066
0067
0068 Dataset<Row> fMeasure = trainingSummary.fMeasureByThreshold();
0069 double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0);
0070 double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
0071 .select("threshold").head().getDouble(0);
0072 lrModel.setThreshold(bestThreshold);
0073
0074
0075 spark.stop();
0076 }
0077 }