0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 """
0019 The K-means algorithm written from scratch against PySpark. In practice,
0020 one may prefer to use the KMeans algorithm in ML, as shown in
0021 examples/src/main/python/ml/kmeans_example.py.
0022
0023 This example requires NumPy (http://www.numpy.org/).
0024 """
0025 from __future__ import print_function
0026
0027 import sys
0028
0029 import numpy as np
0030 from pyspark.sql import SparkSession
0031
0032
def parseVector(line):
    """Parse a whitespace-separated line of numbers into a float NumPy array.

    Any run of whitespace (spaces, tabs) separates values, and leading or
    trailing whitespace is ignored. Lines such as ``"1  2 3 "`` therefore
    parse cleanly instead of failing on an empty token.

    :param line: one text record, e.g. ``"0.0 1.5 2.0"``.
    :return: 1-D ``numpy.ndarray`` of floats, one entry per token.
    :raises ValueError: if a token is not a valid float literal.
    """
    # split() with no separator collapses whitespace runs; split(' ') would
    # produce '' tokens for doubled/trailing spaces and float('') raises.
    return np.array([float(x) for x in line.split()])
0035
0036
def closestPoint(p, centers):
    """Return the index of the center nearest to point ``p``.

    Nearness is measured by squared Euclidean distance. Ties go to the
    lowest index, because only a strictly smaller distance replaces the
    current best. If ``centers`` is empty, 0 is returned.

    :param p: the point, as an array compatible with each center.
    :param centers: sequence of center arrays.
    :return: index into ``centers`` of the closest center.
    """
    bestIndex = 0
    closest = float("+inf")
    for i, center in enumerate(centers):
        # Squared distance preserves ordering, so no sqrt is needed.
        tempDist = np.sum((p - center) ** 2)
        if tempDist < closest:
            closest = tempDist
            bestIndex = i
    return bestIndex
0046
0047
if __name__ == "__main__":

    if len(sys.argv) != 4:
        print("Usage: kmeans <file> <k> <convergeDist>", file=sys.stderr)
        sys.exit(-1)

    print("""WARN: This is a naive implementation of KMeans Clustering and is given
as an example! Please refer to examples/src/main/python/ml/kmeans_example.py for an
example on how to use ML's KMeans implementation.""", file=sys.stderr)

    K = int(sys.argv[2])
    convergeDist = float(sys.argv[3])

    # K < 1 would leave kPoints empty, and the update step below would then
    # index it out of range. The total squared center movement (tempDist)
    # is never negative, so a negative threshold would loop forever.
    if K < 1:
        print("k must be a positive integer", file=sys.stderr)
        sys.exit(-1)
    if convergeDist < 0:
        print("convergeDist must be non-negative", file=sys.stderr)
        sys.exit(-1)

    spark = SparkSession\
        .builder\
        .appName("PythonKMeans")\
        .getOrCreate()

    # Each text row becomes one NumPy vector; cache because every
    # iteration rescans the data.
    lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
    data = lines.map(parseVector).cache()

    # Initial centers: K points sampled without replacement (seed 1).
    kPoints = data.takeSample(False, K, 1)
    # Start above any finite threshold so the loop body always runs at
    # least once. A fixed 1.0 would skip it entirely when convergeDist >= 1.
    tempDist = float("inf")

    while tempDist > convergeDist:
        # Tag each point with its nearest center and a count of 1.
        closest = data.map(
            lambda p: (closestPoint(p, kPoints), (p, 1)))
        # Per center: (sum of assigned points, number of assigned points).
        pointStats = closest.reduceByKey(
            lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1]))
        # New center = mean of its assigned points. Centers with no assigned
        # points do not appear here and keep their previous value.
        newPoints = pointStats.map(
            lambda st: (st[0], st[1][0] / st[1][1])).collect()

        # Total squared movement of the centers in this iteration.
        tempDist = sum(np.sum((kPoints[iK] - p) ** 2) for (iK, p) in newPoints)

        for (iK, p) in newPoints:
            kPoints[iK] = p

    print("Final centers: " + str(kPoints))

    spark.stop()