Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.examples.ml;
0019 
0020 // $example on$
0021 import org.apache.spark.ml.classification.LogisticRegression;
0022 import org.apache.spark.ml.classification.LogisticRegressionModel;
0023 import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
0024 import org.apache.spark.sql.Dataset;
0025 import org.apache.spark.sql.Row;
0026 import org.apache.spark.sql.SparkSession;
0027 // $example off$
0028 
0029 public class JavaMulticlassLogisticRegressionWithElasticNetExample {
0030     public static void main(String[] args) {
0031         SparkSession spark = SparkSession
0032                 .builder()
0033                 .appName("JavaMulticlassLogisticRegressionWithElasticNetExample")
0034                 .getOrCreate();
0035 
0036         // $example on$
0037         // Load training data
0038         Dataset<Row> training = spark.read().format("libsvm")
0039                 .load("data/mllib/sample_multiclass_classification_data.txt");
0040 
0041         LogisticRegression lr = new LogisticRegression()
0042                 .setMaxIter(10)
0043                 .setRegParam(0.3)
0044                 .setElasticNetParam(0.8);
0045 
0046         // Fit the model
0047         LogisticRegressionModel lrModel = lr.fit(training);
0048 
0049         // Print the coefficients and intercept for multinomial logistic regression
0050         System.out.println("Coefficients: \n"
0051                 + lrModel.coefficientMatrix() + " \nIntercept: " + lrModel.interceptVector());
0052         LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
0053 
0054         // Obtain the loss per iteration.
0055         double[] objectiveHistory = trainingSummary.objectiveHistory();
0056         for (double lossPerIteration : objectiveHistory) {
0057             System.out.println(lossPerIteration);
0058         }
0059 
0060         // for multiclass, we can inspect metrics on a per-label basis
0061         System.out.println("False positive rate by label:");
0062         int i = 0;
0063         double[] fprLabel = trainingSummary.falsePositiveRateByLabel();
0064         for (double fpr : fprLabel) {
0065             System.out.println("label " + i + ": " + fpr);
0066             i++;
0067         }
0068 
0069         System.out.println("True positive rate by label:");
0070         i = 0;
0071         double[] tprLabel = trainingSummary.truePositiveRateByLabel();
0072         for (double tpr : tprLabel) {
0073             System.out.println("label " + i + ": " + tpr);
0074             i++;
0075         }
0076 
0077         System.out.println("Precision by label:");
0078         i = 0;
0079         double[] precLabel = trainingSummary.precisionByLabel();
0080         for (double prec : precLabel) {
0081             System.out.println("label " + i + ": " + prec);
0082             i++;
0083         }
0084 
0085         System.out.println("Recall by label:");
0086         i = 0;
0087         double[] recLabel = trainingSummary.recallByLabel();
0088         for (double rec : recLabel) {
0089             System.out.println("label " + i + ": " + rec);
0090             i++;
0091         }
0092 
0093         System.out.println("F-measure by label:");
0094         i = 0;
0095         double[] fLabel = trainingSummary.fMeasureByLabel();
0096         for (double f : fLabel) {
0097             System.out.println("label " + i + ": " + f);
0098             i++;
0099         }
0100 
0101         double accuracy = trainingSummary.accuracy();
0102         double falsePositiveRate = trainingSummary.weightedFalsePositiveRate();
0103         double truePositiveRate = trainingSummary.weightedTruePositiveRate();
0104         double fMeasure = trainingSummary.weightedFMeasure();
0105         double precision = trainingSummary.weightedPrecision();
0106         double recall = trainingSummary.weightedRecall();
0107         System.out.println("Accuracy: " + accuracy);
0108         System.out.println("FPR: " + falsePositiveRate);
0109         System.out.println("TPR: " + truePositiveRate);
0110         System.out.println("F-measure: " + fMeasure);
0111         System.out.println("Precision: " + precision);
0112         System.out.println("Recall: " + recall);
0113         // $example off$
0114 
0115         spark.stop();
0116     }
0117 }