Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 from __future__ import print_function
0019 
0020 # $example on$
0021 from pyspark.ml.classification import MultilayerPerceptronClassifier
0022 from pyspark.ml.evaluation import MulticlassClassificationEvaluator
0023 # $example off$
0024 from pyspark.sql import SparkSession
0025 
if __name__ == "__main__":
    # Start (or reuse) the Spark session that drives the whole example.
    spark = (
        SparkSession.builder
        .appName("multilayer_perceptron_classification_example")
        .getOrCreate()
    )

    # $example on$
    # Read the multiclass sample dataset, which is stored in LIBSVM format.
    data = (
        spark.read.format("libsvm")
        .load("data/mllib/sample_multiclass_classification_data.txt")
    )

    # Use 60% of the rows for training and hold out 40% for evaluation.
    # A fixed seed is passed to the split.
    train, test = data.randomSplit([0.6, 0.4], seed=1234)

    # Network topology, listed from input to output:
    #   4 input neurons (features) -> hidden layers of 5 and 4 neurons
    #   -> 3 output neurons (classes)
    layers = [4, 5, 4, 3]

    # Configure the classifier: at most 100 iterations, the topology above,
    # a block size of 128, and a fixed seed.
    trainer = MultilayerPerceptronClassifier(
        maxIter=100, layers=layers, blockSize=128, seed=1234)

    # Fit the model on the training split.
    model = trainer.fit(train)

    # Score the held-out split and report the fraction of correct predictions.
    scored = model.transform(test)
    prediction_and_labels = scored.select("prediction", "label")
    evaluator = MulticlassClassificationEvaluator(metricName="accuracy")
    accuracy = evaluator.evaluate(prediction_and_labels)
    print("Test set accuracy = " + str(accuracy))
    # $example off$

    spark.stop()