Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.examples.ml;
0019 
0020 // $example on$
0021 import org.apache.spark.ml.clustering.KMeansModel;
0022 import org.apache.spark.ml.clustering.KMeans;
0023 import org.apache.spark.ml.evaluation.ClusteringEvaluator;
0024 import org.apache.spark.ml.linalg.Vector;
0025 import org.apache.spark.sql.Dataset;
0026 import org.apache.spark.sql.Row;
0027 // $example off$
0028 import org.apache.spark.sql.SparkSession;
0029 
0030 
0031 /**
0032  * An example demonstrating k-means clustering.
0033  * Run with
0034  * <pre>
0035  * bin/run-example ml.JavaKMeansExample
0036  * </pre>
0037  */
0038 public class JavaKMeansExample {
0039 
0040   public static void main(String[] args) {
0041     // Create a SparkSession.
0042     SparkSession spark = SparkSession
0043       .builder()
0044       .appName("JavaKMeansExample")
0045       .getOrCreate();
0046 
0047     // $example on$
0048     // Loads data.
0049     Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");
0050 
0051     // Trains a k-means model.
0052     KMeans kmeans = new KMeans().setK(2).setSeed(1L);
0053     KMeansModel model = kmeans.fit(dataset);
0054 
0055     // Make predictions
0056     Dataset<Row> predictions = model.transform(dataset);
0057 
0058     // Evaluate clustering by computing Silhouette score
0059     ClusteringEvaluator evaluator = new ClusteringEvaluator();
0060 
0061     double silhouette = evaluator.evaluate(predictions);
0062     System.out.println("Silhouette with squared euclidean distance = " + silhouette);
0063 
0064     // Shows the result.
0065     Vector[] centers = model.clusterCenters();
0066     System.out.println("Cluster Centers: ");
0067     for (Vector center: centers) {
0068       System.out.println(center);
0069     }
0070     // $example off$
0071 
0072     spark.stop();
0073   }
0074 }