0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.clustering;
0019
0020 import java.util.Arrays;
0021 import java.util.List;
0022
0023 import static org.junit.Assert.assertEquals;
0024
0025 import org.junit.Test;
0026
0027 import org.apache.spark.SharedSparkSession;
0028 import org.apache.spark.api.java.JavaRDD;
0029 import org.apache.spark.mllib.linalg.Vector;
0030 import org.apache.spark.mllib.linalg.Vectors;
0031
0032 public class JavaKMeansSuite extends SharedSparkSession {
0033
0034 @Test
0035 public void runKMeansUsingStaticMethods() {
0036 List<Vector> points = Arrays.asList(
0037 Vectors.dense(1.0, 2.0, 6.0),
0038 Vectors.dense(1.0, 3.0, 0.0),
0039 Vectors.dense(1.0, 4.0, 6.0)
0040 );
0041
0042 Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
0043
0044 JavaRDD<Vector> data = jsc.parallelize(points, 2);
0045 KMeansModel model = KMeans.train(data.rdd(), 1, 1, KMeans.K_MEANS_PARALLEL());
0046 assertEquals(1, model.clusterCenters().length);
0047 assertEquals(expectedCenter, model.clusterCenters()[0]);
0048
0049 model = KMeans.train(data.rdd(), 1, 1, KMeans.RANDOM());
0050 assertEquals(expectedCenter, model.clusterCenters()[0]);
0051 }
0052
0053 @Test
0054 public void runKMeansUsingConstructor() {
0055 List<Vector> points = Arrays.asList(
0056 Vectors.dense(1.0, 2.0, 6.0),
0057 Vectors.dense(1.0, 3.0, 0.0),
0058 Vectors.dense(1.0, 4.0, 6.0)
0059 );
0060
0061 Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
0062
0063 JavaRDD<Vector> data = jsc.parallelize(points, 2);
0064 KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
0065 assertEquals(1, model.clusterCenters().length);
0066 assertEquals(expectedCenter, model.clusterCenters()[0]);
0067
0068 model = new KMeans()
0069 .setK(1)
0070 .setMaxIterations(1)
0071 .setInitializationMode(KMeans.RANDOM())
0072 .run(data.rdd());
0073 assertEquals(expectedCenter, model.clusterCenters()[0]);
0074 }
0075
0076 @Test
0077 public void testPredictJavaRDD() {
0078 List<Vector> points = Arrays.asList(
0079 Vectors.dense(1.0, 2.0, 6.0),
0080 Vectors.dense(1.0, 3.0, 0.0),
0081 Vectors.dense(1.0, 4.0, 6.0)
0082 );
0083 JavaRDD<Vector> data = jsc.parallelize(points, 2);
0084 KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
0085 JavaRDD<Integer> predictions = model.predict(data);
0086
0087 predictions.first();
0088 }
0089 }