0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 """
0019 FMClassifier Example.
0020 """
0021 from __future__ import print_function
0022
0023
0024 from pyspark.ml import Pipeline
0025 from pyspark.ml.classification import FMClassifier
0026 from pyspark.ml.feature import MinMaxScaler, StringIndexer
0027 from pyspark.ml.evaluation import MulticlassClassificationEvaluator
0028
0029 from pyspark.sql import SparkSession
0030
if __name__ == "__main__":
    # Script entry point: fit a factorization-machine classifier pipeline on
    # the sample libsvm data set, print the accuracy measured on a held-out
    # split, and print the parameters of the fitted FM model.
    spark = SparkSession \
        .builder \
        .appName("FMClassifierExample") \
        .getOrCreate()

    # All work after session creation runs under try/finally so the session
    # is stopped even if loading, fitting, or evaluation raises.
    try:
        # Read the input file through the "libsvm" data source.
        data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

        # Fit the label indexer on the full data set, before the split
        # (presumably so that every label value receives an index even if it
        # only appears in one of the two splits).
        labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)

        # Fit the min-max feature scaler, likewise on the full data set.
        featureScaler = MinMaxScaler(inputCol="features", outputCol="scaledFeatures").fit(data)

        # Random 70% / 30% split; no seed is passed to randomSplit here.
        trainingData, testData = data.randomSplit([0.7, 0.3])

        # FM classifier reading the indexed labels and the scaled features.
        fm = FMClassifier(labelCol="indexedLabel", featuresCol="scaledFeatures", stepSize=0.001)

        # Chain the two fitted transformers and the classifier.
        pipeline = Pipeline(stages=[labelIndexer, featureScaler, fm])

        # Fit the pipeline on the training split only.
        model = pipeline.fit(trainingData)

        # Apply the fitted pipeline to the test split.
        predictions = model.transform(testData)

        # Display the first five rows of the selected columns.
        predictions.select("prediction", "indexedLabel", "features").show(5)

        # Compare "prediction" against "indexedLabel" using the accuracy metric.
        evaluator = MulticlassClassificationEvaluator(
            labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
        accuracy = evaluator.evaluate(predictions)
        print("Test set accuracy = %g" % accuracy)

        # The classifier was the third pipeline stage, so its fitted model
        # sits at index 2 of the fitted pipeline's stages.
        fmModel = model.stages[2]
        print("Factors: " + str(fmModel.factors))
        print("Linear: " + str(fmModel.linear))
        print("Intercept: " + str(fmModel.intercept))
    finally:
        spark.stop()