#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# The $example on$/$example off$ markers delimit the snippet that the Spark
# documentation build extracts with its include_example mechanism
# $example on$
from pyspark.mllib.classification import LogisticRegressionWithLBFGS
from pyspark.mllib.util import MLUtils
from pyspark.mllib.evaluation import MulticlassMetrics
# $example off$

from pyspark import SparkContext

if __name__ == "__main__":
    sc = SparkContext(appName="MultiClassMetricsExample")

    # Several of the methods available in the Scala API are currently missing from PySpark
    # $example on$
    # Load training data in LIBSVM format
    data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt")
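    # Each LIBSVM line has the form "<label> <index>:<value> ...", with 1-based
    # indices; loadLibSVMFile returns an RDD of LabeledPoint, converting the
    # indices to 0-based. For example (illustrative line, not from the file):
    #   "1 1:0.5 3:0.25"  ->  label 1.0, sparse features {0: 0.5, 2: 0.25}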

    # Split data into training (60%) and test (40%)
    training, test = data.randomSplit([0.6, 0.4], seed=11)
    # Cache the training split: LBFGS makes multiple passes over it
    training.cache()

    # Run training algorithm to build the model
    model = LogisticRegressionWithLBFGS.train(training, numClasses=3)
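    # The fitted model can be persisted and reloaded (a sketch; the path is
    # hypothetical and LogisticRegressionModel must also be imported from
    # pyspark.mllib.classification):
    # model.save(sc, "target/tmp/pythonMultiClassLRModel")
    # sameModel = LogisticRegressionModel.load(sc, "target/tmp/pythonMultiClassLRModel")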

    # Make predictions on the test set and pair them with the true labels
    predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label))
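    # An equivalent pattern used in other Spark examples (a sketch):
    # model.predict also accepts an RDD of feature vectors, and zip pairs the
    # predictions back with the labels partition-by-partition:
    # predictions = model.predict(test.map(lambda lp: lp.features))
    # predictionAndLabels = predictions.map(float).zip(test.map(lambda lp: lp.label))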

    # Instantiate metrics object
    metrics = MulticlassMetrics(predictionAndLabels)
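    # The full confusion matrix is also exposed (rows are true labels, columns
    # are predicted labels, both in ascending label order), e.g.:
    # print(metrics.confusionMatrix().toArray())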

    # Summary statistics for one class (label 1.0)
    precision = metrics.precision(1.0)
    recall = metrics.recall(1.0)
    f1Score = metrics.fMeasure(1.0)
    print("Summary Stats")
    print("Precision = %s" % precision)
    print("Recall = %s" % recall)
    print("F1 Score = %s" % f1Score)

    # Statistics by class
    labels = data.map(lambda lp: lp.label).distinct().collect()
    for label in sorted(labels):
        print("Class %s precision = %s" % (label, metrics.precision(label)))
        print("Class %s recall = %s" % (label, metrics.recall(label)))
        print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0)))

    # Weighted stats
    print("Weighted recall = %s" % metrics.weightedRecall)
    print("Weighted precision = %s" % metrics.weightedPrecision)
    print("Weighted F(1) Score = %s" % metrics.weightedFMeasure())
    print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5))
    print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate)
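    # Overall accuracy is available as a property on Spark 2.0+ (an assumption
    # about the running version; older releases do not have it):
    # print("Accuracy = %s" % metrics.accuracy)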
    # $example off$
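
    # Not part of the extracted example, but good practice: release the
    # SparkContext once the job is done
    sc.stop()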