0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 """
0019 An example demonstrating Logistic Regression Summary.
0020 Run with:
0021 bin/spark-submit examples/src/main/python/ml/logistic_regression_summary_example.py
0022 """
0023 from __future__ import print_function
0024
0025
0026 from pyspark.ml.classification import LogisticRegression
0027
0028 from pyspark.sql import SparkSession
0029
if __name__ == "__main__":
    # Script entry point: fit a logistic regression on the sample libsvm data,
    # print training-summary metrics, and derive the threshold with the best
    # F-Measure.
    spark = SparkSession \
        .builder \
        .appName("LogisticRegressionSummary") \
        .getOrCreate()

    # Everything that uses the session runs inside try/finally so the session
    # is stopped even when loading, fitting, or a summary query raises.
    try:
        # Load the training data (libsvm format).
        training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

        lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)

        # Fit the model on the training data.
        lrModel = lr.fit(training)

        # Summary object exposed by the fitted model.
        trainingSummary = lrModel.summary

        # Print each entry of the objective history, one per line.
        objectiveHistory = trainingSummary.objectiveHistory
        print("objectiveHistory:")
        for objective in objectiveHistory:
            print(objective)

        # Display the ROC data and print the area under the ROC curve.
        trainingSummary.roc.show()
        print("areaUnderROC: " + str(trainingSummary.areaUnderROC))

        # Find the maximum F-Measure, then the threshold of a row that
        # attains it.
        fMeasure = trainingSummary.fMeasureByThreshold
        maxFMeasure = fMeasure.groupBy().max('F-Measure').select('max(F-Measure)').head()
        bestThreshold = fMeasure.where(fMeasure['F-Measure'] == maxFMeasure['max(F-Measure)']) \
            .select('threshold').head()['threshold']
        # NOTE(review): this sets the threshold on the estimator `lr`, not on
        # the already-fitted `lrModel`; presumably it only takes effect for
        # models produced by a later `lr.fit(...)` -- confirm this is intended.
        lr.setThreshold(bestThreshold)
    finally:
        spark.stop()