0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.clustering;
0019
0020 import java.util.Arrays;
0021
0022 import org.junit.Assert;
0023 import org.junit.Test;
0024
0025 import org.apache.spark.SharedSparkSession;
0026 import org.apache.spark.api.java.JavaRDD;
0027 import org.apache.spark.mllib.linalg.Vector;
0028 import org.apache.spark.mllib.linalg.Vectors;
0029
0030 public class JavaBisectingKMeansSuite extends SharedSparkSession {
0031
0032 @Test
0033 public void twoDimensionalData() {
0034 JavaRDD<Vector> points = jsc.parallelize(Arrays.asList(
0035 Vectors.dense(4, -1),
0036 Vectors.dense(4, 1),
0037 Vectors.sparse(2, new int[]{0}, new double[]{1.0})
0038 ), 2);
0039
0040 BisectingKMeans bkm = new BisectingKMeans()
0041 .setK(4)
0042 .setMaxIterations(2)
0043 .setSeed(1L);
0044 BisectingKMeansModel model = bkm.run(points);
0045 Assert.assertEquals(3, model.k());
0046 Assert.assertArrayEquals(new double[]{3.0, 0.0}, model.root().center().toArray(), 1e-12);
0047 for (ClusteringTreeNode child : model.root().children()) {
0048 double[] center = child.center().toArray();
0049 if (center[0] > 2) {
0050 Assert.assertEquals(2, child.size());
0051 Assert.assertArrayEquals(new double[]{4.0, 0.0}, center, 1e-12);
0052 } else {
0053 Assert.assertEquals(1, child.size());
0054 Assert.assertArrayEquals(new double[]{1.0, 0.0}, center, 1e-12);
0055 }
0056 }
0057 }
0058 }