0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.ml;
0019
0020
0021 import org.apache.spark.ml.clustering.KMeansModel;
0022 import org.apache.spark.ml.clustering.KMeans;
0023 import org.apache.spark.ml.evaluation.ClusteringEvaluator;
0024 import org.apache.spark.ml.linalg.Vector;
0025 import org.apache.spark.sql.Dataset;
0026 import org.apache.spark.sql.Row;
0027
0028 import org.apache.spark.sql.SparkSession;
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038 public class JavaKMeansExample {
0039
0040 public static void main(String[] args) {
0041
0042 SparkSession spark = SparkSession
0043 .builder()
0044 .appName("JavaKMeansExample")
0045 .getOrCreate();
0046
0047
0048
0049 Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");
0050
0051
0052 KMeans kmeans = new KMeans().setK(2).setSeed(1L);
0053 KMeansModel model = kmeans.fit(dataset);
0054
0055
0056 Dataset<Row> predictions = model.transform(dataset);
0057
0058
0059 ClusteringEvaluator evaluator = new ClusteringEvaluator();
0060
0061 double silhouette = evaluator.evaluate(predictions);
0062 System.out.println("Silhouette with squared euclidean distance = " + silhouette);
0063
0064
0065 Vector[] centers = model.clusterCenters();
0066 System.out.println("Cluster Centers: ");
0067 for (Vector center: centers) {
0068 System.out.println(center);
0069 }
0070
0071
0072 spark.stop();
0073 }
0074 }