0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 from __future__ import print_function
0019
0020
0021 from pyspark.ml.classification import LogisticRegression
0022
0023 from pyspark.sql import SparkSession
0024
def _print_metric_by_label(title, values):
    """Print a heading followed by one ``label <i>: <value>`` line per entry.

    :param title: heading line printed before the values.
    :param values: iterable of per-label metric values; each entry is printed
        with its position in the iterable as the label number.
    """
    print(title)
    for i, value in enumerate(values):
        print("label %d: %s" % (i, value))


if __name__ == "__main__":
    # Script entry point: fit a LogisticRegression estimator on the sample
    # multiclass data set and print the fitted parameters together with the
    # metrics exposed by the model's training summary.
    spark = SparkSession \
        .builder \
        .appName("MulticlassLogisticRegressionWithElasticNet") \
        .getOrCreate()

    # Everything after session creation runs under try/finally so that the
    # session is stopped even when loading, fitting or reporting raises.
    try:
        # Load the training data (libsvm format).
        training = spark \
            .read \
            .format("libsvm") \
            .load("data/mllib/sample_multiclass_classification_data.txt")

        lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)

        # Fit the model.
        lrModel = lr.fit(training)

        # Print the fitted coefficient matrix and intercept vector.
        print("Coefficients: \n" + str(lrModel.coefficientMatrix))
        print("Intercept: " + str(lrModel.interceptVector))

        trainingSummary = lrModel.summary

        # Print every entry of the summary's objective history, one per line.
        objectiveHistory = trainingSummary.objectiveHistory
        print("objectiveHistory:")
        for objective in objectiveHistory:
            print(objective)

        # Per-label metrics; the printed label number is the position of the
        # value in the sequence returned by the summary.
        _print_metric_by_label("False positive rate by label:",
                               trainingSummary.falsePositiveRateByLabel)
        _print_metric_by_label("True positive rate by label:",
                               trainingSummary.truePositiveRateByLabel)
        _print_metric_by_label("Precision by label:",
                               trainingSummary.precisionByLabel)
        _print_metric_by_label("Recall by label:",
                               trainingSummary.recallByLabel)
        # fMeasureByLabel is a method (unlike the attributes above).
        _print_metric_by_label("F-measure by label:",
                               trainingSummary.fMeasureByLabel())

        # Overall accuracy and weighted metrics.
        accuracy = trainingSummary.accuracy
        falsePositiveRate = trainingSummary.weightedFalsePositiveRate
        truePositiveRate = trainingSummary.weightedTruePositiveRate
        fMeasure = trainingSummary.weightedFMeasure()
        precision = trainingSummary.weightedPrecision
        recall = trainingSummary.weightedRecall
        print("Accuracy: %s\nFPR: %s\nTPR: %s\nF-measure: %s\nPrecision: %s\nRecall: %s"
              % (accuracy, falsePositiveRate, truePositiveRate, fMeasure, precision, recall))
    finally:
        spark.stop()