Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 """
0019 Correlations using MLlib.
0020 """
0021 from __future__ import print_function
0022 
0023 import sys
0024 
0025 from pyspark import SparkContext
0026 from pyspark.mllib.regression import LabeledPoint
0027 from pyspark.mllib.stat import Statistics
0028 from pyspark.mllib.util import MLUtils
0029 
0030 
0031 if __name__ == "__main__":
0032     if len(sys.argv) not in [1, 2]:
0033         print("Usage: correlations (<file>)", file=sys.stderr)
0034         sys.exit(-1)
0035     sc = SparkContext(appName="PythonCorrelations")
0036     if len(sys.argv) == 2:
0037         filepath = sys.argv[1]
0038     else:
0039         filepath = 'data/mllib/sample_linear_regression_data.txt'
0040     corrType = 'pearson'
0041 
0042     points = MLUtils.loadLibSVMFile(sc, filepath)\
0043         .map(lambda lp: LabeledPoint(lp.label, lp.features.toArray()))
0044 
0045     print()
0046     print('Summary of data file: ' + filepath)
0047     print('%d data points' % points.count())
0048 
0049     # Statistics (correlations)
0050     print()
0051     print('Correlation (%s) between label and each feature' % corrType)
0052     print('Feature\tCorrelation')
0053     numFeatures = points.take(1)[0].features.size
0054     labelRDD = points.map(lambda lp: lp.label)
0055     for i in range(numFeatures):
0056         featureRDD = points.map(lambda lp: lp.features[i])
0057         corr = Statistics.corr(labelRDD, featureRDD, corrType)
0058         print('%d\t%g' % (i, corr))
0059     print()
0060 
0061     sc.stop()