#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys

from pyspark import since
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
from pyspark.sql import SQLContext
from pyspark.sql.types import ArrayType, StructField, StructType, DoubleType

__all__ = ['BinaryClassificationMetrics', 'RegressionMetrics',
           'MulticlassMetrics', 'RankingMetrics', 'MultilabelMetrics']


class BinaryClassificationMetrics(JavaModelWrapper):
    """
    Evaluator for binary classification.

    :param scoreAndLabels: an RDD of score, label and optional weight.

    >>> scoreAndLabels = sc.parallelize([
    ...     (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
    >>> metrics = BinaryClassificationMetrics(scoreAndLabels)
    >>> metrics.areaUnderROC
    0.70...
    >>> metrics.areaUnderPR
    0.83...
    >>> metrics.unpersist()
    >>> scoreAndLabelsWithOptWeight = sc.parallelize([
    ...     (0.1, 0.0, 1.0), (0.1, 1.0, 0.4), (0.4, 0.0, 0.2), (0.6, 0.0, 0.6), (0.6, 1.0, 0.9),
    ...     (0.6, 1.0, 0.5), (0.8, 1.0, 0.7)], 2)
    >>> metrics = BinaryClassificationMetrics(scoreAndLabelsWithOptWeight)
    >>> metrics.areaUnderROC
    0.79...
    >>> metrics.areaUnderPR
    0.88...

    .. versionadded:: 1.4.0
    """

    def __init__(self, scoreAndLabels):
        sc = scoreAndLabels.ctx
        sql_ctx = SQLContext.getOrCreate(sc)
        numCol = len(scoreAndLabels.first())
        schema = StructType([
            StructField("score", DoubleType(), nullable=False),
            StructField("label", DoubleType(), nullable=False)])
        if numCol == 3:
            schema.add("weight", DoubleType(), False)
        df = sql_ctx.createDataFrame(scoreAndLabels, schema=schema)
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
        java_model = java_class(df._jdf)
        super(BinaryClassificationMetrics, self).__init__(java_model)

    @property
    @since('1.4.0')
    def areaUnderROC(self):
        """
        Computes the area under the receiver operating characteristic
        (ROC) curve.
        """
        return self.call("areaUnderROC")

    @property
    @since('1.4.0')
    def areaUnderPR(self):
        """
        Computes the area under the precision-recall curve.
        """
        return self.call("areaUnderPR")

    @since('1.4.0')
    def unpersist(self):
        """
        Unpersists intermediate RDDs used in the computation.
        """
        self.call("unpersist")

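
# A minimal pure-Python sketch (illustration only, not part of the MLlib API;
# the helper name is ours and the real computation runs in the JVM). For
# unweighted data, areaUnderROC is equivalent to the rank statistic: the
# probability that a uniformly chosen positive outscores a uniformly chosen
# negative, counting ties as one half. It reproduces the first doctest above.
def _sketch_auc_roc(score_and_labels):
    """
    >>> _sketch_auc_roc([(0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0),
    ...                  (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)])
    0.70...
    """
    pos = [score for score, label in score_and_labels if label == 1.0]
    neg = [score for score, label in score_and_labels if label != 1.0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / float(len(pos) * len(neg))
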

class RegressionMetrics(JavaModelWrapper):
    """
    Evaluator for regression.

    :param predictionAndObservations: an RDD of prediction, observation and optional weight.

    >>> predictionAndObservations = sc.parallelize([
    ...     (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
    >>> metrics = RegressionMetrics(predictionAndObservations)
    >>> metrics.explainedVariance
    8.859...
    >>> metrics.meanAbsoluteError
    0.5...
    >>> metrics.meanSquaredError
    0.37...
    >>> metrics.rootMeanSquaredError
    0.61...
    >>> metrics.r2
    0.94...
    >>> predictionAndObservationsWithOptWeight = sc.parallelize([
    ...     (2.5, 3.0, 0.5), (0.0, -0.5, 1.0), (2.0, 2.0, 0.3), (8.0, 7.0, 0.9)])
    >>> metrics = RegressionMetrics(predictionAndObservationsWithOptWeight)
    >>> metrics.rootMeanSquaredError
    0.68...

    .. versionadded:: 1.4.0
    """

    def __init__(self, predictionAndObservations):
        sc = predictionAndObservations.ctx
        sql_ctx = SQLContext.getOrCreate(sc)
        numCol = len(predictionAndObservations.first())
        schema = StructType([
            StructField("prediction", DoubleType(), nullable=False),
            StructField("observation", DoubleType(), nullable=False)])
        if numCol == 3:
            schema.add("weight", DoubleType(), False)
        df = sql_ctx.createDataFrame(predictionAndObservations, schema=schema)
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.RegressionMetrics
        java_model = java_class(df._jdf)
        super(RegressionMetrics, self).__init__(java_model)

    @property
    @since('1.4.0')
    def explainedVariance(self):
        r"""
        Returns the variance explained by regression:
        explainedVariance = :math:`\frac{\sum_i (\hat{y_i} - \bar{y})^2}{n}`
        (consistent with the 8.859... value in the doctest above; note this
        is not normalized like an :math:`R^2` score).
        """
        return self.call("explainedVariance")

    @property
    @since('1.4.0')
    def meanAbsoluteError(self):
        """
        Returns the mean absolute error, which is a risk function corresponding to the
        expected value of the absolute error loss or l1-norm loss.
        """
        return self.call("meanAbsoluteError")

    @property
    @since('1.4.0')
    def meanSquaredError(self):
        """
        Returns the mean squared error, which is a risk function corresponding to the
        expected value of the squared error loss or quadratic loss.
        """
        return self.call("meanSquaredError")

    @property
    @since('1.4.0')
    def rootMeanSquaredError(self):
        """
        Returns the root mean squared error, which is defined as the square root of
        the mean squared error.
        """
        return self.call("rootMeanSquaredError")

    @property
    @since('1.4.0')
    def r2(self):
        """
        Returns R^2, the coefficient of determination.
        """
        return self.call("r2")

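
# Minimal pure-Python sketches (illustration only, not part of the MLlib API)
# of the unweighted regression metrics above. The helper names are ours; the
# real computation runs in the JVM. Both reproduce the first doctest of
# RegressionMetrics.
def _sketch_regression_errors(pred_and_obs):
    """
    Returns (MAE, MSE, RMSE) over (prediction, observation) pairs.

    >>> _sketch_regression_errors([(2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
    (0.5, 0.375, 0.61...)
    """
    n = float(len(pred_and_obs))
    residuals = [p - o for p, o in pred_and_obs]
    mae = sum(abs(r) for r in residuals) / n
    mse = sum(r * r for r in residuals) / n
    return mae, mse, mse ** 0.5


def _sketch_explained_variance(pred_and_obs):
    """
    Variance explained by regression: sum((pred_i - mean(obs))^2) / n.

    >>> _sketch_explained_variance([(2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
    8.859...
    """
    n = float(len(pred_and_obs))
    mean_obs = sum(o for _, o in pred_and_obs) / n
    return sum((p - mean_obs) ** 2 for p, _ in pred_and_obs) / n
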

class MulticlassMetrics(JavaModelWrapper):
    """
    Evaluator for multiclass classification.

    :param predictionAndLabels: an RDD of prediction, label, optional weight
        and optional probability.

    >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
    ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
    >>> metrics = MulticlassMetrics(predictionAndLabels)
    >>> metrics.confusionMatrix().toArray()
    array([[ 2.,  1.,  1.],
           [ 1.,  3.,  0.],
           [ 0.,  0.,  1.]])
    >>> metrics.falsePositiveRate(0.0)
    0.2...
    >>> metrics.precision(1.0)
    0.75...
    >>> metrics.recall(2.0)
    1.0...
    >>> metrics.fMeasure(0.0, 2.0)
    0.52...
    >>> metrics.accuracy
    0.66...
    >>> metrics.weightedFalsePositiveRate
    0.19...
    >>> metrics.weightedPrecision
    0.68...
    >>> metrics.weightedRecall
    0.66...
    >>> metrics.weightedFMeasure()
    0.66...
    >>> metrics.weightedFMeasure(2.0)
    0.65...
    >>> predAndLabelsWithOptWeight = sc.parallelize([(0.0, 0.0, 1.0), (0.0, 1.0, 1.0),
    ...     (0.0, 0.0, 1.0), (1.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0), (1.0, 1.0, 1.0),
    ...     (2.0, 2.0, 1.0), (2.0, 0.0, 1.0)])
    >>> metrics = MulticlassMetrics(predAndLabelsWithOptWeight)
    >>> metrics.confusionMatrix().toArray()
    array([[ 2.,  1.,  1.],
           [ 1.,  3.,  0.],
           [ 0.,  0.,  1.]])
    >>> metrics.falsePositiveRate(0.0)
    0.2...
    >>> metrics.precision(1.0)
    0.75...
    >>> metrics.recall(2.0)
    1.0...
    >>> metrics.fMeasure(0.0, 2.0)
    0.52...
    >>> metrics.accuracy
    0.66...
    >>> metrics.weightedFalsePositiveRate
    0.19...
    >>> metrics.weightedPrecision
    0.68...
    >>> metrics.weightedRecall
    0.66...
    >>> metrics.weightedFMeasure()
    0.66...
    >>> metrics.weightedFMeasure(2.0)
    0.65...
    >>> predictionAndLabelsWithProbabilities = sc.parallelize([
    ...     (1.0, 1.0, 1.0, [0.1, 0.8, 0.1]), (0.0, 2.0, 1.0, [0.9, 0.05, 0.05]),
    ...     (0.0, 0.0, 1.0, [0.8, 0.2, 0.0]), (1.0, 1.0, 1.0, [0.3, 0.65, 0.05])])
    >>> metrics = MulticlassMetrics(predictionAndLabelsWithProbabilities)
    >>> metrics.logLoss()
    0.9682...

    .. versionadded:: 1.4.0
    """

    def __init__(self, predictionAndLabels):
        sc = predictionAndLabels.ctx
        sql_ctx = SQLContext.getOrCreate(sc)
        numCol = len(predictionAndLabels.first())
        schema = StructType([
            StructField("prediction", DoubleType(), nullable=False),
            StructField("label", DoubleType(), nullable=False)])
        if numCol >= 3:
            schema.add("weight", DoubleType(), False)
        if numCol == 4:
            schema.add("probability", ArrayType(DoubleType(), False), False)
        df = sql_ctx.createDataFrame(predictionAndLabels, schema)
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
        java_model = java_class(df._jdf)
        super(MulticlassMetrics, self).__init__(java_model)

    @since('1.4.0')
    def confusionMatrix(self):
        """
        Returns the confusion matrix: predicted classes are in columns,
        ordered by class label ascending, as in "labels".
        """
        return self.call("confusionMatrix")

    @since('1.4.0')
    def truePositiveRate(self, label):
        """
        Returns true positive rate for a given label (category).
        """
        return self.call("truePositiveRate", label)

    @since('1.4.0')
    def falsePositiveRate(self, label):
        """
        Returns false positive rate for a given label (category).
        """
        return self.call("falsePositiveRate", label)

    @since('1.4.0')
    def precision(self, label):
        """
        Returns precision for a given label (category).
        """
        return self.call("precision", float(label))

    @since('1.4.0')
    def recall(self, label):
        """
        Returns recall for a given label (category).
        """
        return self.call("recall", float(label))

    @since('1.4.0')
    def fMeasure(self, label, beta=None):
        """
        Returns f-measure for a given label (category). If beta is not
        specified, the F1 score (beta = 1.0) is returned.
        """
        if beta is None:
            return self.call("fMeasure", label)
        else:
            return self.call("fMeasure", label, beta)

    @property
    @since('2.0.0')
    def accuracy(self):
        """
        Returns accuracy (the fraction of correctly classified instances
        out of the total number of instances).
        """
        return self.call("accuracy")

    @property
    @since('1.4.0')
    def weightedTruePositiveRate(self):
        """
        Returns weighted true positive rate.
        (equivalent to weighted recall)
        """
        return self.call("weightedTruePositiveRate")

    @property
    @since('1.4.0')
    def weightedFalsePositiveRate(self):
        """
        Returns weighted false positive rate.
        """
        return self.call("weightedFalsePositiveRate")

    @property
    @since('1.4.0')
    def weightedRecall(self):
        """
        Returns weighted averaged recall.
        (equivalent to weighted true positive rate)
        """
        return self.call("weightedRecall")

    @property
    @since('1.4.0')
    def weightedPrecision(self):
        """
        Returns weighted averaged precision.
        """
        return self.call("weightedPrecision")

    @since('1.4.0')
    def weightedFMeasure(self, beta=None):
        """
        Returns weighted averaged f-measure. If beta is not specified,
        the weighted F1 score (beta = 1.0) is returned.
        """
        if beta is None:
            return self.call("weightedFMeasure")
        else:
            return self.call("weightedFMeasure", beta)

    @since('3.0.0')
    def logLoss(self, eps=1e-15):
        """
        Returns weighted logLoss. `eps` guards against taking the
        logarithm of zero probabilities.
        """
        return self.call("logLoss", eps)

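
# Minimal pure-Python sketches (illustration only, not part of the MLlib API)
# of two of the multiclass metrics above; the helper names are ours and the
# real computation runs in the JVM. Both reproduce doctest values from
# MulticlassMetrics.
def _sketch_precision_recall(pred_and_labels, label):
    """
    Per-label precision (true positives over all predictions of the label)
    and recall (true positives over all true occurrences of the label),
    for unweighted (prediction, label) pairs.

    >>> pairs = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
    ...          (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]
    >>> _sketch_precision_recall(pairs, 1.0)
    (0.75, 0.75)
    """
    tp = sum(1 for p, y in pred_and_labels if p == label and y == label)
    predicted = sum(1 for p, _ in pred_and_labels if p == label)
    actual = sum(1 for _, y in pred_and_labels if y == label)
    return tp / float(predicted), tp / float(actual)


def _sketch_log_loss(pred_label_weight_prob, eps=1e-15):
    """
    Weighted mean of -log(probability assigned to the true label), with the
    probability clipped to [eps, 1 - eps] to avoid log(0).

    >>> _sketch_log_loss([(1.0, 1.0, 1.0, [0.1, 0.8, 0.1]),
    ...                   (0.0, 2.0, 1.0, [0.9, 0.05, 0.05]),
    ...                   (0.0, 0.0, 1.0, [0.8, 0.2, 0.0]),
    ...                   (1.0, 1.0, 1.0, [0.3, 0.65, 0.05])])
    0.9682...
    """
    import math
    num = den = 0.0
    for _, label, weight, prob in pred_label_weight_prob:
        p = min(max(prob[int(label)], eps), 1.0 - eps)
        num += -math.log(p) * weight
        den += weight
    return num / den
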

class RankingMetrics(JavaModelWrapper):
    """
    Evaluator for ranking algorithms.

    :param predictionAndLabels: an RDD of (predicted ranking,
        ground truth set) pairs.

    >>> predictionAndLabels = sc.parallelize([
    ...     ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
    ...     ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
    ...     ([1, 2, 3, 4, 5], [])])
    >>> metrics = RankingMetrics(predictionAndLabels)
    >>> metrics.precisionAt(1)
    0.33...
    >>> metrics.precisionAt(5)
    0.26...
    >>> metrics.precisionAt(15)
    0.17...
    >>> metrics.meanAveragePrecision
    0.35...
    >>> metrics.meanAveragePrecisionAt(1)
    0.3333333333333333...
    >>> metrics.meanAveragePrecisionAt(2)
    0.25...
    >>> metrics.ndcgAt(3)
    0.33...
    >>> metrics.ndcgAt(10)
    0.48...
    >>> metrics.recallAt(1)
    0.06...
    >>> metrics.recallAt(5)
    0.35...
    >>> metrics.recallAt(15)
    0.66...

    .. versionadded:: 1.4.0
    """

    def __init__(self, predictionAndLabels):
        sc = predictionAndLabels.ctx
        sql_ctx = SQLContext.getOrCreate(sc)
        df = sql_ctx.createDataFrame(predictionAndLabels,
                                     schema=sql_ctx._inferSchema(predictionAndLabels))
        java_model = callMLlibFunc("newRankingMetrics", df._jdf)
        super(RankingMetrics, self).__init__(java_model)

    @since('1.4.0')
    def precisionAt(self, k):
        """
        Compute the average precision of all the queries, truncated at ranking position k.

        If for a query, the ranking algorithm returns n (n < k) results, the precision value
        will be computed as #(relevant items retrieved) / k. This formula also applies when
        the size of the ground truth set is less than k.

        If a query has an empty ground truth set, zero will be used as precision together
        with a log warning.
        """
        return self.call("precisionAt", int(k))

    @property
    @since('1.4.0')
    def meanAveragePrecision(self):
        """
        Returns the mean average precision (MAP) of all the queries.
        If a query has an empty ground truth set, the average precision will be zero and
        a log warning is generated.
        """
        return self.call("meanAveragePrecision")

    @since('3.0.0')
    def meanAveragePrecisionAt(self, k):
        """
        Returns the mean average precision (MAP) at the first k ranking positions
        of all the queries.
        If a query has an empty ground truth set, the average precision will be zero and
        a log warning is generated.
        """
        return self.call("meanAveragePrecisionAt", int(k))

    @since('1.4.0')
    def ndcgAt(self, k):
        r"""
        Compute the average NDCG value of all the queries, truncated at ranking
        position k. The discounted cumulative gain at position k is computed as
        :math:`\sum_{i=1}^{k} (2^{rel_i} - 1) / \log(i + 1)`, where :math:`rel_i`
        is the relevance of the item at position i, and the NDCG is obtained by
        dividing it by the DCG of the ground truth set. In the current
        implementation, the relevance value is binary.
        If a query has an empty ground truth set, zero will be used as NDCG
        together with a log warning.
        """
        return self.call("ndcgAt", int(k))

    @since('3.0.0')
    def recallAt(self, k):
        """
        Compute the average recall of all the queries, truncated at ranking position k.

        If for a query, the ranking algorithm returns n results, the recall value
        will be computed as #(relevant items retrieved) / #(ground truth set).
        This formula also applies when the size of the ground truth set is less than k.

        If a query has an empty ground truth set, zero will be used as recall together
        with a log warning.
        """
        return self.call("recallAt", int(k))

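
# Minimal pure-Python sketches (illustration only, not part of the MLlib API)
# of two of the ranking metrics above, assuming binary relevance; the helper
# names are ours and the real computation runs in the JVM. Both reproduce
# doctest values from RankingMetrics.
def _sketch_precision_at_k(pred_and_labels, k):
    """
    Average over queries of #(relevant items in the first k predictions) / k,
    counting queries with an empty ground truth set as zero.

    >>> queries = [([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
    ...            ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
    ...            ([1, 2, 3, 4, 5], [])]
    >>> _sketch_precision_at_k(queries, 1)
    0.33...
    >>> _sketch_precision_at_k(queries, 5)
    0.26...
    """
    scores = []
    for pred, truth in pred_and_labels:
        truth = set(truth)
        hits = sum(1 for item in pred[:k] if item in truth)
        scores.append(hits / float(k) if truth else 0.0)
    return sum(scores) / float(len(scores))


def _sketch_ndcg_at_k(pred_and_labels, k):
    """
    Average NDCG at k with binary relevance: the DCG of the first k
    predictions divided by the ideal DCG of the ground truth set.

    >>> queries = [([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
    ...            ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
    ...            ([1, 2, 3, 4, 5], [])]
    >>> _sketch_ndcg_at_k(queries, 3)
    0.33...
    """
    import math
    scores = []
    for pred, truth in pred_and_labels:
        truth = set(truth)
        if not truth:
            scores.append(0.0)
            continue
        dcg = sum(1.0 / math.log(i + 2)
                  for i, item in enumerate(pred[:k]) if item in truth)
        ideal = sum(1.0 / math.log(i + 2) for i in range(min(len(truth), k)))
        scores.append(dcg / ideal)
    return sum(scores) / float(len(scores))
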

class MultilabelMetrics(JavaModelWrapper):
    """
    Evaluator for multilabel classification.

    :param predictionAndLabels: an RDD of (predictions, labels) pairs,
        both are non-null Arrays, each with
        unique elements.

    >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
    ...     ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
    ...     ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])
    >>> metrics = MultilabelMetrics(predictionAndLabels)
    >>> metrics.precision(0.0)
    1.0
    >>> metrics.recall(1.0)
    0.66...
    >>> metrics.f1Measure(2.0)
    0.5
    >>> metrics.precision()
    0.66...
    >>> metrics.recall()
    0.64...
    >>> metrics.f1Measure()
    0.63...
    >>> metrics.microPrecision
    0.72...
    >>> metrics.microRecall
    0.66...
    >>> metrics.microF1Measure
    0.69...
    >>> metrics.hammingLoss
    0.33...
    >>> metrics.subsetAccuracy
    0.28...
    >>> metrics.accuracy
    0.54...

    .. versionadded:: 1.4.0
    """

    def __init__(self, predictionAndLabels):
        sc = predictionAndLabels.ctx
        sql_ctx = SQLContext.getOrCreate(sc)
        df = sql_ctx.createDataFrame(predictionAndLabels,
                                     schema=sql_ctx._inferSchema(predictionAndLabels))
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics
        java_model = java_class(df._jdf)
        super(MultilabelMetrics, self).__init__(java_model)

    @since('1.4.0')
    def precision(self, label=None):
        """
        Returns overall precision, or precision for a given label (category)
        if specified.
        """
        if label is None:
            return self.call("precision")
        else:
            return self.call("precision", float(label))

    @since('1.4.0')
    def recall(self, label=None):
        """
        Returns overall recall, or recall for a given label (category)
        if specified.
        """
        if label is None:
            return self.call("recall")
        else:
            return self.call("recall", float(label))

    @since('1.4.0')
    def f1Measure(self, label=None):
        """
        Returns overall f1Measure, or f1Measure for a given label (category)
        if specified.
        """
        if label is None:
            return self.call("f1Measure")
        else:
            return self.call("f1Measure", float(label))

    @property
    @since('1.4.0')
    def microPrecision(self):
        """
        Returns micro-averaged label-based precision.
        (equals micro-averaged document-based precision)
        """
        return self.call("microPrecision")

    @property
    @since('1.4.0')
    def microRecall(self):
        """
        Returns micro-averaged label-based recall.
        (equals micro-averaged document-based recall)
        """
        return self.call("microRecall")

    @property
    @since('1.4.0')
    def microF1Measure(self):
        """
        Returns micro-averaged label-based f1-measure.
        (equals micro-averaged document-based f1-measure)
        """
        return self.call("microF1Measure")

    @property
    @since('1.4.0')
    def hammingLoss(self):
        """
        Returns Hamming loss.
        """
        return self.call("hammingLoss")

    @property
    @since('1.4.0')
    def subsetAccuracy(self):
        """
        Returns subset accuracy: the fraction of examples whose predicted
        label set matches the true label set exactly.
        """
        return self.call("subsetAccuracy")

    @property
    @since('1.4.0')
    def accuracy(self):
        """
        Returns accuracy.
        """
        return self.call("accuracy")

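
# A minimal pure-Python sketch (illustration only, not part of the MLlib API)
# of micro-averaged precision for multilabel data: true positives pooled
# across all documents over all predicted labels. The helper name is ours;
# the real computation runs in the JVM. It reproduces the microPrecision
# doctest above.
def _sketch_micro_precision(pred_and_labels):
    """
    >>> pairs = [([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
    ...          ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
    ...          ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]
    >>> _sketch_micro_precision(pairs)
    0.72...
    """
    tp = sum(len(set(p) & set(y)) for p, y in pred_and_labels)
    predicted = sum(len(p) for p, _ in pred_and_labels)
    return tp / float(predicted)
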

def _test():
    import doctest
    import numpy
    from pyspark.sql import SparkSession
    import pyspark.mllib.evaluation
    try:
        # Numpy 1.14+ changed its default string format, which would break
        # the array doctests above; fall back to the 1.13 legacy format.
        numpy.set_printoptions(legacy='1.13')
    except TypeError:
        pass
    globs = pyspark.mllib.evaluation.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("mllib.evaluation tests")\
        .getOrCreate()
    globs['sc'] = spark.sparkContext
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()