Back to home page

OSCL-LXR

 
 

    


#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
0018 """
0019 An example of how to use DataFrame for ML. Run with::
0020     bin/spark-submit examples/src/main/python/ml/dataframe_example.py <input_path>
0021 """
0022 from __future__ import print_function
0023 
0024 import os
0025 import sys
0026 import tempfile
0027 import shutil
0028 
0029 from pyspark.sql import SparkSession
0030 from pyspark.mllib.stat import Statistics
0031 from pyspark.mllib.util import MLUtils
0032 
0033 if __name__ == "__main__":
0034     if len(sys.argv) > 2:
0035         print("Usage: dataframe_example.py <libsvm file>", file=sys.stderr)
0036         sys.exit(-1)
0037     elif len(sys.argv) == 2:
0038         input_path = sys.argv[1]
0039     else:
0040         input_path = "data/mllib/sample_libsvm_data.txt"
0041 
0042     spark = SparkSession \
0043         .builder \
0044         .appName("DataFrameExample") \
0045         .getOrCreate()
0046 
0047     # Load an input file
0048     print("Loading LIBSVM file with UDT from " + input_path + ".")
0049     df = spark.read.format("libsvm").load(input_path).cache()
0050     print("Schema from LIBSVM:")
0051     df.printSchema()
0052     print("Loaded training data as a DataFrame with " +
0053           str(df.count()) + " records.")
0054 
0055     # Show statistical summary of labels.
0056     labelSummary = df.describe("label")
0057     labelSummary.show()
0058 
0059     # Convert features column to an RDD of vectors.
0060     features = MLUtils.convertVectorColumnsFromML(df, "features") \
0061         .select("features").rdd.map(lambda r: r.features)
0062     summary = Statistics.colStats(features)
0063     print("Selected features column with average values:\n" +
0064           str(summary.mean()))
0065 
0066     # Save the records in a parquet file.
0067     tempdir = tempfile.NamedTemporaryFile(delete=False).name
0068     os.unlink(tempdir)
0069     print("Saving to " + tempdir + " as Parquet file.")
0070     df.write.parquet(tempdir)
0071 
0072     # Load the records back.
0073     print("Loading Parquet file with UDT from " + tempdir)
0074     newDF = spark.read.parquet(tempdir)
0075     print("Schema from Parquet:")
0076     newDF.printSchema()
0077     shutil.rmtree(tempdir)
0078 
0079     spark.stop()