#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

from pyspark import since
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
from pyspark.sql import SQLContext
from pyspark.sql.types import ArrayType, StructField, StructType, DoubleType

__all__ = ['BinaryClassificationMetrics', 'RegressionMetrics',
           'MulticlassMetrics', 'RankingMetrics', 'MultilabelMetrics']


class BinaryClassificationMetrics(JavaModelWrapper):
    """
    Evaluator for binary classification.

    :param scoreAndLabels: an RDD of score, label and optional weight.

    >>> scoreAndLabels = sc.parallelize([
    ...     (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
    >>> metrics = BinaryClassificationMetrics(scoreAndLabels)
    >>> metrics.areaUnderROC
    0.70...
    >>> metrics.areaUnderPR
    0.83...
    >>> metrics.unpersist()
    >>> scoreAndLabelsWithOptWeight = sc.parallelize([
    ...     (0.1, 0.0, 1.0), (0.1, 1.0, 0.4), (0.4, 0.0, 0.2), (0.6, 0.0, 0.6), (0.6, 1.0, 0.9),
    ...     (0.6, 1.0, 0.5), (0.8, 1.0, 0.7)], 2)
    >>> metrics = BinaryClassificationMetrics(scoreAndLabelsWithOptWeight)
    >>> metrics.areaUnderROC
    0.79...
    >>> metrics.areaUnderPR
    0.88...

    .. versionadded:: 1.4.0
    """

    def __init__(self, scoreAndLabels):
        sc = scoreAndLabels.ctx
        sql_ctx = SQLContext.getOrCreate(sc)
        numCol = len(scoreAndLabels.first())
        schema = StructType([
            StructField("score", DoubleType(), nullable=False),
            StructField("label", DoubleType(), nullable=False)])
        if numCol == 3:
            schema.add("weight", DoubleType(), False)
        df = sql_ctx.createDataFrame(scoreAndLabels, schema=schema)
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
        java_model = java_class(df._jdf)
        super(BinaryClassificationMetrics, self).__init__(java_model)

    @property
    @since('1.4.0')
    def areaUnderROC(self):
        """
        Computes the area under the receiver operating characteristic
        (ROC) curve.
        """
        return self.call("areaUnderROC")

    @property
    @since('1.4.0')
    def areaUnderPR(self):
        """
        Computes the area under the precision-recall curve.
        """
        return self.call("areaUnderPR")

    @since('1.4.0')
    def unpersist(self):
        """
        Unpersists intermediate RDDs used in the computation.
        """
        self.call("unpersist")


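# The doctests above exercise the JVM-backed BinaryClassificationMetrics. As a hedged,
# illustrative sketch (not part of the PySpark API; the helper name below is invented
# for this listing), areaUnderROC for unweighted (score, label) pairs can be reproduced
# in pure Python via the rank interpretation of AUC: the probability that a randomly
# chosen positive is scored above a randomly chosen negative, with ties counting one
# half. For the unweighted doctest data this gives 8.5 / 12, roughly 0.708, consistent
# with the 0.70... shown above.
def _sketch_area_under_roc(score_and_labels):
    """Approximate AUC-ROC from local (score, label) pairs; labels are 0.0 or 1.0."""
    pos = [s for s, label in score_and_labels if label == 1.0]
    neg = [s for s, label in score_and_labels if label == 0.0]
    if not pos or not neg:
        return 0.0
    # Count positive/negative score pairs where the positive wins; ties count 0.5.
    wins = sum(1.0 if p > n else (0.5 if p == n else 0.0) for p in pos for n in neg)
    return wins / float(len(pos) * len(neg))

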
class RegressionMetrics(JavaModelWrapper):
    """
    Evaluator for regression.

    :param predictionAndObservations: an RDD of prediction, observation and optional weight.

    >>> predictionAndObservations = sc.parallelize([
    ...     (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
    >>> metrics = RegressionMetrics(predictionAndObservations)
    >>> metrics.explainedVariance
    8.859...
    >>> metrics.meanAbsoluteError
    0.5...
    >>> metrics.meanSquaredError
    0.37...
    >>> metrics.rootMeanSquaredError
    0.61...
    >>> metrics.r2
    0.94...
    >>> predictionAndObservationsWithOptWeight = sc.parallelize([
    ...     (2.5, 3.0, 0.5), (0.0, -0.5, 1.0), (2.0, 2.0, 0.3), (8.0, 7.0, 0.9)])
    >>> metrics = RegressionMetrics(predictionAndObservationsWithOptWeight)
    >>> metrics.rootMeanSquaredError
    0.68...

    .. versionadded:: 1.4.0
    """

    def __init__(self, predictionAndObservations):
        sc = predictionAndObservations.ctx
        sql_ctx = SQLContext.getOrCreate(sc)
        numCol = len(predictionAndObservations.first())
        schema = StructType([
            StructField("prediction", DoubleType(), nullable=False),
            StructField("observation", DoubleType(), nullable=False)])
        if numCol == 3:
            schema.add("weight", DoubleType(), False)
        df = sql_ctx.createDataFrame(predictionAndObservations, schema=schema)
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.RegressionMetrics
        java_model = java_class(df._jdf)
        super(RegressionMetrics, self).__init__(java_model)

    @property
    @since('1.4.0')
    def explainedVariance(self):
        r"""
        Returns the variance explained by the regression.
        explainedVariance = :math:`\frac{\sum_{i=1}^{n} (\hat{y_i} - \bar{y})^2}{n}`
        """
        return self.call("explainedVariance")

    @property
    @since('1.4.0')
    def meanAbsoluteError(self):
        """
        Returns the mean absolute error, which is a risk function corresponding to the
        expected value of the absolute error loss or l1-norm loss.
        """
        return self.call("meanAbsoluteError")

    @property
    @since('1.4.0')
    def meanSquaredError(self):
        """
        Returns the mean squared error, which is a risk function corresponding to the
        expected value of the squared error loss or quadratic loss.
        """
        return self.call("meanSquaredError")

    @property
    @since('1.4.0')
    def rootMeanSquaredError(self):
        """
        Returns the root mean squared error, which is defined as the square root of
        the mean squared error.
        """
        return self.call("rootMeanSquaredError")

    @property
    @since('1.4.0')
    def r2(self):
        """
        Returns R^2, the coefficient of determination.
        """
        return self.call("r2")


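# Hedged, pure-Python sketch of the unweighted regression metrics exercised by the
# doctests above (the helper name is invented for this listing and is not part of the
# PySpark API). The explained variance follows the behaviour suggested by the 8.859...
# doctest output: the mean squared deviation of the predictions from the mean
# observation, rather than a ratio bounded by one.
def _sketch_regression_metrics(prediction_and_observations):
    """Compute MAE, MSE, RMSE, R^2 and explained variance from local pairs."""
    import math
    preds = [p for p, _ in prediction_and_observations]
    obs = [o for _, o in prediction_and_observations]
    n = float(len(obs))
    residuals = [p - o for p, o in zip(preds, obs)]
    mean_obs = sum(obs) / n
    mse = sum(r * r for r in residuals) / n
    ss_tot = sum((o - mean_obs) ** 2 for o in obs)
    return {
        "meanAbsoluteError": sum(abs(r) for r in residuals) / n,
        "meanSquaredError": mse,
        "rootMeanSquaredError": math.sqrt(mse),
        "r2": 1.0 - sum(r * r for r in residuals) / ss_tot,
        "explainedVariance": sum((p - mean_obs) ** 2 for p in preds) / n,
    }

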
class MulticlassMetrics(JavaModelWrapper):
    """
    Evaluator for multiclass classification.

    :param predictionAndLabels: an RDD of prediction, label, optional weight
     and optional probability.

    >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
    ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
    >>> metrics = MulticlassMetrics(predictionAndLabels)
    >>> metrics.confusionMatrix().toArray()
    array([[ 2.,  1.,  1.],
           [ 1.,  3.,  0.],
           [ 0.,  0.,  1.]])
    >>> metrics.falsePositiveRate(0.0)
    0.2...
    >>> metrics.precision(1.0)
    0.75...
    >>> metrics.recall(2.0)
    1.0...
    >>> metrics.fMeasure(0.0, 2.0)
    0.52...
    >>> metrics.accuracy
    0.66...
    >>> metrics.weightedFalsePositiveRate
    0.19...
    >>> metrics.weightedPrecision
    0.68...
    >>> metrics.weightedRecall
    0.66...
    >>> metrics.weightedFMeasure()
    0.66...
    >>> metrics.weightedFMeasure(2.0)
    0.65...
    >>> predAndLabelsWithOptWeight = sc.parallelize([(0.0, 0.0, 1.0), (0.0, 1.0, 1.0),
    ...      (0.0, 0.0, 1.0), (1.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0),
    ...      (2.0, 2.0, 1.0), (2.0, 0.0, 1.0)])
    >>> metrics = MulticlassMetrics(predAndLabelsWithOptWeight)
    >>> metrics.confusionMatrix().toArray()
    array([[ 2.,  1.,  1.],
           [ 1.,  3.,  0.],
           [ 0.,  0.,  1.]])
    >>> metrics.falsePositiveRate(0.0)
    0.2...
    >>> metrics.precision(1.0)
    0.75...
    >>> metrics.recall(2.0)
    1.0...
    >>> metrics.fMeasure(0.0, 2.0)
    0.52...
    >>> metrics.accuracy
    0.66...
    >>> metrics.weightedFalsePositiveRate
    0.19...
    >>> metrics.weightedPrecision
    0.68...
    >>> metrics.weightedRecall
    0.66...
    >>> metrics.weightedFMeasure()
    0.66...
    >>> metrics.weightedFMeasure(2.0)
    0.65...
    >>> predictionAndLabelsWithProbabilities = sc.parallelize([
    ...      (1.0, 1.0, 1.0, [0.1, 0.8, 0.1]), (0.0, 2.0, 1.0, [0.9, 0.05, 0.05]),
    ...      (0.0, 0.0, 1.0, [0.8, 0.2, 0.0]), (1.0, 1.0, 1.0, [0.3, 0.65, 0.05])])
    >>> metrics = MulticlassMetrics(predictionAndLabelsWithProbabilities)
    >>> metrics.logLoss()
    0.9682...

    .. versionadded:: 1.4.0
    """

    def __init__(self, predictionAndLabels):
        sc = predictionAndLabels.ctx
        sql_ctx = SQLContext.getOrCreate(sc)
        numCol = len(predictionAndLabels.first())
        schema = StructType([
            StructField("prediction", DoubleType(), nullable=False),
            StructField("label", DoubleType(), nullable=False)])
        if numCol >= 3:
            schema.add("weight", DoubleType(), False)
        if numCol == 4:
            schema.add("probability", ArrayType(DoubleType(), False), False)
        df = sql_ctx.createDataFrame(predictionAndLabels, schema)
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
        java_model = java_class(df._jdf)
        super(MulticlassMetrics, self).__init__(java_model)

    @since('1.4.0')
    def confusionMatrix(self):
        """
        Returns confusion matrix: predicted classes are in columns,
        they are ordered by class label ascending, as in "labels".
        """
        return self.call("confusionMatrix")

    @since('1.4.0')
    def truePositiveRate(self, label):
        """
        Returns true positive rate for a given label (category).
        """
        return self.call("truePositiveRate", label)

    @since('1.4.0')
    def falsePositiveRate(self, label):
        """
        Returns false positive rate for a given label (category).
        """
        return self.call("falsePositiveRate", label)

    @since('1.4.0')
    def precision(self, label):
        """
        Returns precision for a given label (category).
        """
        return self.call("precision", float(label))

    @since('1.4.0')
    def recall(self, label):
        """
        Returns recall for a given label (category).
        """
        return self.call("recall", float(label))

    @since('1.4.0')
    def fMeasure(self, label, beta=None):
        """
        Returns f-measure for a given label (category). If beta is not given,
        the F1 measure is returned.
        """
        if beta is None:
            return self.call("fMeasure", label)
        else:
            return self.call("fMeasure", label, beta)

    @property
    @since('2.0.0')
    def accuracy(self):
        """
        Returns accuracy (the fraction of correctly classified instances
        out of the total number of instances).
        """
        return self.call("accuracy")

    @property
    @since('1.4.0')
    def weightedTruePositiveRate(self):
        """
        Returns weighted true positive rate
        (equivalent to the weighted recall).
        """
        return self.call("weightedTruePositiveRate")

    @property
    @since('1.4.0')
    def weightedFalsePositiveRate(self):
        """
        Returns weighted false positive rate.
        """
        return self.call("weightedFalsePositiveRate")

    @property
    @since('1.4.0')
    def weightedRecall(self):
        """
        Returns weighted averaged recall
        (equal to the weighted true positive rate).
        """
        return self.call("weightedRecall")

    @property
    @since('1.4.0')
    def weightedPrecision(self):
        """
        Returns weighted averaged precision.
        """
        return self.call("weightedPrecision")

    @since('1.4.0')
    def weightedFMeasure(self, beta=None):
        """
        Returns weighted averaged f-measure.
        """
        if beta is None:
            return self.call("weightedFMeasure")
        else:
            return self.call("weightedFMeasure", beta)

    @since('3.0.0')
    def logLoss(self, eps=1e-15):
        """
        Returns weighted logLoss.
        """
        return self.call("logLoss", eps)


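# Hedged sketch of the weighted multiclass log loss shown in the logLoss() doctest
# above (helper name invented for this listing, not part of the PySpark API). Each row
# is (prediction, label, weight, probability vector); the loss is the weighted mean of
# -log(probability assigned to the true label). Clamping the probability into
# [eps, 1 - eps] is an assumption made here to avoid log(0); for the doctest data it
# does not change the result (roughly 0.9682).
def _sketch_log_loss(rows, eps=1e-15):
    """Weighted log loss from local (prediction, label, weight, probability) rows."""
    import math
    loss = 0.0
    total_weight = 0.0
    for _, label, weight, probabilities in rows:
        p = min(max(probabilities[int(label)], eps), 1.0 - eps)
        loss -= weight * math.log(p)
        total_weight += weight
    return loss / total_weight

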
class RankingMetrics(JavaModelWrapper):
    """
    Evaluator for ranking algorithms.

    :param predictionAndLabels: an RDD of (predicted ranking,
                                ground truth set) pairs.

    >>> predictionAndLabels = sc.parallelize([
    ...     ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
    ...     ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
    ...     ([1, 2, 3, 4, 5], [])])
    >>> metrics = RankingMetrics(predictionAndLabels)
    >>> metrics.precisionAt(1)
    0.33...
    >>> metrics.precisionAt(5)
    0.26...
    >>> metrics.precisionAt(15)
    0.17...
    >>> metrics.meanAveragePrecision
    0.35...
    >>> metrics.meanAveragePrecisionAt(1)
    0.3333333333333333...
    >>> metrics.meanAveragePrecisionAt(2)
    0.25...
    >>> metrics.ndcgAt(3)
    0.33...
    >>> metrics.ndcgAt(10)
    0.48...
    >>> metrics.recallAt(1)
    0.06...
    >>> metrics.recallAt(5)
    0.35...
    >>> metrics.recallAt(15)
    0.66...

    .. versionadded:: 1.4.0
    """

    def __init__(self, predictionAndLabels):
        sc = predictionAndLabels.ctx
        sql_ctx = SQLContext.getOrCreate(sc)
        df = sql_ctx.createDataFrame(predictionAndLabels,
                                     schema=sql_ctx._inferSchema(predictionAndLabels))
        java_model = callMLlibFunc("newRankingMetrics", df._jdf)
        super(RankingMetrics, self).__init__(java_model)

    @since('1.4.0')
    def precisionAt(self, k):
        """
        Compute the average precision of all the queries, truncated at ranking position k.

        If for a query, the ranking algorithm returns n (n < k) results, the precision value
        will be computed as #(relevant items retrieved) / k. This formula also applies when
        the size of the ground truth set is less than k.

        If a query has an empty ground truth set, zero will be used as precision together
        with a log warning.
        """
        return self.call("precisionAt", int(k))

    @property
    @since('1.4.0')
    def meanAveragePrecision(self):
        """
        Returns the mean average precision (MAP) of all the queries.
        If a query has an empty ground truth set, the average precision will be zero and
        a log warning is generated.
        """
        return self.call("meanAveragePrecision")

    @since('3.0.0')
    def meanAveragePrecisionAt(self, k):
        """
        Returns the mean average precision (MAP) at the first k ranking positions of
        all the queries.
        If a query has an empty ground truth set, the average precision will be zero and
        a log warning is generated.
        """
        return self.call("meanAveragePrecisionAt", int(k))

    @since('1.4.0')
    def ndcgAt(self, k):
        """
        Compute the average NDCG value of all the queries, truncated at ranking position k.
        The discounted cumulative gain at position k is computed as:
        sum_{i=1}^{k} (2^{relevance of i-th item} - 1) / log(i + 1),
        and the NDCG is obtained by dividing by the DCG value computed on the ground
        truth set (the ideal DCG).
        In the current implementation, the relevance value is binary.
        If a query has an empty ground truth set, zero will be used as NDCG together with
        a log warning.
        """
        return self.call("ndcgAt", int(k))

    @since('3.0.0')
    def recallAt(self, k):
        """
        Compute the average recall of all the queries, truncated at ranking position k.

        If for a query, the ranking algorithm returns n results, the recall value
        will be computed as #(relevant items retrieved) / #(ground truth set).
        This formula also applies when the size of the ground truth set is less than k.

        If a query has an empty ground truth set, zero will be used as recall together
        with a log warning.
        """
        return self.call("recallAt", int(k))


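# Hedged sketch of precisionAt and ndcgAt with binary relevance (helper names invented
# for this listing, not part of the PySpark API). Queries with an empty ground truth
# set contribute zero, matching the docstrings above; precision@k always divides by k,
# and NDCG normalises the DCG of the predicted ranking by the DCG of an ideal ranking
# of the ground truth set (the log base cancels in that ratio).
def _sketch_precision_at(prediction_and_labels, k):
    """Mean precision@k over local (predicted ranking, ground truth set) pairs."""
    total = 0.0
    for predicted, ground_truth in prediction_and_labels:
        truth = set(ground_truth)
        hits = sum(1 for item in predicted[:k] if item in truth)
        total += hits / float(k)
    return total / len(prediction_and_labels)


def _sketch_ndcg_at(prediction_and_labels, k):
    """Mean NDCG@k with binary relevance over local pairs."""
    import math
    total = 0.0
    for predicted, ground_truth in prediction_and_labels:
        truth = set(ground_truth)
        if not truth:
            continue  # an empty ground truth set counts as zero
        dcg = sum(1.0 / math.log(i + 2, 2)
                  for i in range(min(k, len(predicted))) if predicted[i] in truth)
        ideal = sum(1.0 / math.log(i + 2, 2) for i in range(min(k, len(truth))))
        total += dcg / ideal
    return total / len(prediction_and_labels)

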
class MultilabelMetrics(JavaModelWrapper):
    """
    Evaluator for multilabel classification.

    :param predictionAndLabels: an RDD of (predictions, labels) pairs,
                                both are non-null Arrays, each with
                                unique elements.

    >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
    ...     ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
    ...     ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])
    >>> metrics = MultilabelMetrics(predictionAndLabels)
    >>> metrics.precision(0.0)
    1.0
    >>> metrics.recall(1.0)
    0.66...
    >>> metrics.f1Measure(2.0)
    0.5
    >>> metrics.precision()
    0.66...
    >>> metrics.recall()
    0.64...
    >>> metrics.f1Measure()
    0.63...
    >>> metrics.microPrecision
    0.72...
    >>> metrics.microRecall
    0.66...
    >>> metrics.microF1Measure
    0.69...
    >>> metrics.hammingLoss
    0.33...
    >>> metrics.subsetAccuracy
    0.28...
    >>> metrics.accuracy
    0.54...

    .. versionadded:: 1.4.0
    """

    def __init__(self, predictionAndLabels):
        sc = predictionAndLabels.ctx
        sql_ctx = SQLContext.getOrCreate(sc)
        df = sql_ctx.createDataFrame(predictionAndLabels,
                                     schema=sql_ctx._inferSchema(predictionAndLabels))
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics
        java_model = java_class(df._jdf)
        super(MultilabelMetrics, self).__init__(java_model)

    @since('1.4.0')
    def precision(self, label=None):
        """
        Returns precision or precision for a given label (category) if specified.
        """
        if label is None:
            return self.call("precision")
        else:
            return self.call("precision", float(label))

    @since('1.4.0')
    def recall(self, label=None):
        """
        Returns recall or recall for a given label (category) if specified.
        """
        if label is None:
            return self.call("recall")
        else:
            return self.call("recall", float(label))

    @since('1.4.0')
    def f1Measure(self, label=None):
        """
        Returns f1Measure or f1Measure for a given label (category) if specified.
        """
        if label is None:
            return self.call("f1Measure")
        else:
            return self.call("f1Measure", float(label))

    @property
    @since('1.4.0')
    def microPrecision(self):
        """
        Returns micro-averaged label-based precision.
        (equal to micro-averaged document-based precision)
        """
        return self.call("microPrecision")

    @property
    @since('1.4.0')
    def microRecall(self):
        """
        Returns micro-averaged label-based recall.
        (equal to micro-averaged document-based recall)
        """
        return self.call("microRecall")

    @property
    @since('1.4.0')
    def microF1Measure(self):
        """
        Returns micro-averaged label-based f1-measure.
        (equal to micro-averaged document-based f1-measure)
        """
        return self.call("microF1Measure")

    @property
    @since('1.4.0')
    def hammingLoss(self):
        """
        Returns Hamming-loss.
        """
        return self.call("hammingLoss")

    @property
    @since('1.4.0')
    def subsetAccuracy(self):
        """
        Returns subset accuracy.
        (for equal sets of labels)
        """
        return self.call("subsetAccuracy")

    @property
    @since('1.4.0')
    def accuracy(self):
        """
        Returns accuracy.
        """
        return self.call("accuracy")


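# Hedged sketch of the set-based multilabel summaries used in the doctests above
# (helper name invented for this listing, not part of the PySpark API). Hamming loss
# averages per-document symmetric differences over the number of distinct labels,
# subset accuracy counts exact set matches, and accuracy averages the per-document
# Jaccard overlap (intersection over union). Assumes each document has at least one
# predicted or true label.
def _sketch_multilabel_summary(prediction_and_labels):
    """Hamming loss, subset accuracy and accuracy from local (predictions, labels) pairs."""
    docs = [(set(p), set(l)) for p, l in prediction_and_labels]
    n = float(len(docs))
    num_labels = len(set().union(*[p | l for p, l in docs]))
    return {
        "hammingLoss": sum(len(p ^ l) for p, l in docs) / (n * num_labels),
        "subsetAccuracy": sum(1.0 for p, l in docs if p == l) / n,
        "accuracy": sum(len(p & l) / float(len(p | l)) for p, l in docs) / n,
    }

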
def _test():
    import doctest
    import numpy
    from pyspark.sql import SparkSession
    import pyspark.mllib.evaluation
    try:
        # NumPy 1.14+ changed its string format.
        numpy.set_printoptions(legacy='1.13')
    except TypeError:
        pass
    globs = pyspark.mllib.evaluation.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("mllib.evaluation tests")\
        .getOrCreate()
    globs['sc'] = spark.sparkContext
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()