0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.mllib;
0019
0020
0021 import java.util.Arrays;
0022
0023 import scala.Tuple2;
0024
0025 import org.apache.spark.api.java.*;
0026 import org.apache.spark.mllib.classification.LogisticRegressionModel;
0027 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
0028 import org.apache.spark.mllib.linalg.Vector;
0029 import org.apache.spark.mllib.linalg.Vectors;
0030 import org.apache.spark.mllib.optimization.*;
0031 import org.apache.spark.mllib.regression.LabeledPoint;
0032 import org.apache.spark.mllib.util.MLUtils;
0033 import org.apache.spark.SparkConf;
0034 import org.apache.spark.SparkContext;
0035
0036
0037 public class JavaLBFGSExample {
0038 public static void main(String[] args) {
0039 SparkConf conf = new SparkConf().setAppName("L-BFGS Example");
0040 SparkContext sc = new SparkContext(conf);
0041
0042
0043 String path = "data/mllib/sample_libsvm_data.txt";
0044 JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
0045 int numFeatures = data.take(1).get(0).features().size();
0046
0047
0048 JavaRDD<LabeledPoint> trainingInit = data.sample(false, 0.6, 11L);
0049 JavaRDD<LabeledPoint> test = data.subtract(trainingInit);
0050
0051
0052 JavaPairRDD<Object, Vector> training = data.mapToPair(p ->
0053 new Tuple2<>(p.label(), MLUtils.appendBias(p.features())));
0054 training.cache();
0055
0056
0057 int numCorrections = 10;
0058 double convergenceTol = 1e-4;
0059 int maxNumIterations = 20;
0060 double regParam = 0.1;
0061 Vector initialWeightsWithIntercept = Vectors.dense(new double[numFeatures + 1]);
0062
0063 Tuple2<Vector, double[]> result = LBFGS.runLBFGS(
0064 training.rdd(),
0065 new LogisticGradient(),
0066 new SquaredL2Updater(),
0067 numCorrections,
0068 convergenceTol,
0069 maxNumIterations,
0070 regParam,
0071 initialWeightsWithIntercept);
0072 Vector weightsWithIntercept = result._1();
0073 double[] loss = result._2();
0074
0075 LogisticRegressionModel model = new LogisticRegressionModel(
0076 Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)),
0077 (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]);
0078
0079
0080 model.clearThreshold();
0081
0082
0083 JavaPairRDD<Object, Object> scoreAndLabels = test.mapToPair(p ->
0084 new Tuple2<>(model.predict(p.features()), p.label()));
0085
0086
0087 BinaryClassificationMetrics metrics =
0088 new BinaryClassificationMetrics(scoreAndLabels.rdd());
0089 double auROC = metrics.areaUnderROC();
0090
0091 System.out.println("Loss of each step in training process");
0092 for (double l : loss) {
0093 System.out.println(l);
0094 }
0095 System.out.println("Area under ROC = " + auROC);
0096
0097
0098 sc.stop();
0099 }
0100 }
0101