0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 """
0019 A Gaussian Mixture Model clustering program using MLlib.
0020 """
0021 from __future__ import print_function
0022
0023 import sys
# On Python 3 there is no separate `long` type; alias it so the argparse
# `type=long` below works on both major versions. Compare the numeric
# version tuple rather than the `sys.version` string, which would order
# versions lexicographically.
if sys.version_info[0] >= 3:
    long = int
0026
0027 import random
0028 import argparse
0029 import numpy as np
0030
0031 from pyspark import SparkConf, SparkContext
0032 from pyspark.mllib.clustering import GaussianMixture
0033
0034
0035 def parseVector(line):
0036 return np.array([float(x) for x in line.split(' ')])
0037
0038
if __name__ == "__main__":
    # Command-line parameters:
    #   inputFile      -- input file path; one data point per line, with
    #                     coordinates separated by whitespace
    #   k              -- number of mixture components
    #   convergenceTol -- convergence threshold (default 1e-3)
    #   maxIterations  -- number of EM iterations to perform (default 100)
    #   seed           -- random seed (default: a random 19-bit integer)
    parser = argparse.ArgumentParser()
    parser.add_argument('inputFile', help='Input File')
    parser.add_argument('k', type=int, help='Number of clusters')
    parser.add_argument('--convergenceTol', default=1e-3, type=float, help='convergence threshold')
    parser.add_argument('--maxIterations', default=100, type=int, help='Number of iterations')
    parser.add_argument('--seed', default=random.getrandbits(19),
                        type=long, help='Random seed')
    args = parser.parse_args()
    # Reject an invalid component count before paying for a SparkContext.
    if args.k < 1:
        parser.error("k must be a positive integer")

    conf = SparkConf().setAppName("GMM")
    sc = SparkContext(conf=conf)
    try:
        lines = sc.textFile(args.inputFile)
        # Drop blank lines (e.g. a trailing newline at end of file), which
        # cannot be parsed into a data point. Cache the parsed points because
        # training and both prediction passes below each run over this RDD.
        data = lines.filter(lambda line: line.strip()).map(parseVector).cache()
        model = GaussianMixture.train(data, args.k,
                                      convergenceTol=args.convergenceTol,
                                      maxIterations=args.maxIterations,
                                      seed=args.seed)
        # Pass separate arguments to print(); wrapping them in an extra pair of
        # parentheses would print a tuple repr instead of the message.
        for weight, gaussian in zip(model.weights, model.gaussians):
            print("weight = ", weight, "mu = ", gaussian.mu,
                  "sigma = ", gaussian.sigma.toArray())
        print("\n")
        print("The membership value of each vector to all mixture components (first 100): ",
              model.predictSoft(data).take(100))
        print("\n")
        print("Cluster labels (first 100): ", model.predict(data).take(100))
    finally:
        # Always release the SparkContext, even if training or prediction fails.
        sc.stop()