Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.examples.mllib;
0019 
0020 // $example on$
0021 import java.util.Arrays;
0022 
0023 import scala.Tuple2;
0024 
0025 import org.apache.spark.api.java.*;
0026 import org.apache.spark.mllib.classification.LogisticRegressionModel;
0027 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
0028 import org.apache.spark.mllib.linalg.Vector;
0029 import org.apache.spark.mllib.linalg.Vectors;
0030 import org.apache.spark.mllib.optimization.*;
0031 import org.apache.spark.mllib.regression.LabeledPoint;
0032 import org.apache.spark.mllib.util.MLUtils;
0033 import org.apache.spark.SparkConf;
0034 import org.apache.spark.SparkContext;
0035 // $example off$
0036 
0037 public class JavaLBFGSExample {
0038   public static void main(String[] args) {
0039     SparkConf conf = new SparkConf().setAppName("L-BFGS Example");
0040     SparkContext sc = new SparkContext(conf);
0041 
0042     // $example on$
0043     String path = "data/mllib/sample_libsvm_data.txt";
0044     JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
0045     int numFeatures = data.take(1).get(0).features().size();
0046 
0047     // Split initial RDD into two... [60% training data, 40% testing data].
0048     JavaRDD<LabeledPoint> trainingInit = data.sample(false, 0.6, 11L);
0049     JavaRDD<LabeledPoint> test = data.subtract(trainingInit);
0050 
0051     // Append 1 into the training data as intercept.
0052     JavaPairRDD<Object, Vector> training = data.mapToPair(p ->
0053       new Tuple2<>(p.label(), MLUtils.appendBias(p.features())));
0054     training.cache();
0055 
0056     // Run training algorithm to build the model.
0057     int numCorrections = 10;
0058     double convergenceTol = 1e-4;
0059     int maxNumIterations = 20;
0060     double regParam = 0.1;
0061     Vector initialWeightsWithIntercept = Vectors.dense(new double[numFeatures + 1]);
0062 
0063     Tuple2<Vector, double[]> result = LBFGS.runLBFGS(
0064       training.rdd(),
0065       new LogisticGradient(),
0066       new SquaredL2Updater(),
0067       numCorrections,
0068       convergenceTol,
0069       maxNumIterations,
0070       regParam,
0071       initialWeightsWithIntercept);
0072     Vector weightsWithIntercept = result._1();
0073     double[] loss = result._2();
0074 
0075     LogisticRegressionModel model = new LogisticRegressionModel(
0076       Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)),
0077       (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]);
0078 
0079     // Clear the default threshold.
0080     model.clearThreshold();
0081 
0082     // Compute raw scores on the test set.
0083     JavaPairRDD<Object, Object> scoreAndLabels = test.mapToPair(p ->
0084       new Tuple2<>(model.predict(p.features()), p.label()));
0085 
0086     // Get evaluation metrics.
0087     BinaryClassificationMetrics metrics =
0088       new BinaryClassificationMetrics(scoreAndLabels.rdd());
0089     double auROC = metrics.areaUnderROC();
0090 
0091     System.out.println("Loss of each step in training process");
0092     for (double l : loss) {
0093       System.out.println(l);
0094     }
0095     System.out.println("Area under ROC = " + auROC);
0096     // $example off$
0097 
0098     sc.stop();
0099   }
0100 }
0101