0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.ml.clustering;
0019
0020 import java.io.IOException;
0021 import java.util.Arrays;
0022 import java.util.List;
0023
0024 import org.junit.Test;
0025 import static org.junit.Assert.assertEquals;
0026 import static org.junit.Assert.assertTrue;
0027
0028 import org.apache.spark.SharedSparkSession;
0029 import org.apache.spark.ml.linalg.Vector;
0030 import org.apache.spark.sql.Dataset;
0031 import org.apache.spark.sql.Row;
0032
0033 public class JavaKMeansSuite extends SharedSparkSession {
0034
0035 private transient int k = 5;
0036 private transient Dataset<Row> dataset;
0037
0038 @Override
0039 public void setUp() throws IOException {
0040 super.setUp();
0041 dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k);
0042 }
0043
0044 @Test
0045 public void fitAndTransform() {
0046 KMeans kmeans = new KMeans().setK(k).setSeed(1);
0047 KMeansModel model = kmeans.fit(dataset);
0048
0049 Vector[] centers = model.clusterCenters();
0050 assertEquals(k, centers.length);
0051
0052 Dataset<Row> transformed = model.transform(dataset);
0053 List<String> columns = Arrays.asList(transformed.columns());
0054 List<String> expectedColumns = Arrays.asList("features", "prediction");
0055 for (String column : expectedColumns) {
0056 assertTrue(columns.contains(column));
0057 }
0058 }
0059 }