Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.mllib.clustering;
0019 
0020 import java.util.Arrays;
0021 import java.util.List;
0022 
0023 import static org.junit.Assert.assertEquals;
0024 
0025 import org.junit.Test;
0026 
0027 import org.apache.spark.SharedSparkSession;
0028 import org.apache.spark.api.java.JavaRDD;
0029 import org.apache.spark.mllib.linalg.Vector;
0030 import org.apache.spark.mllib.linalg.Vectors;
0031 
0032 public class JavaKMeansSuite extends SharedSparkSession {
0033 
0034   @Test
0035   public void runKMeansUsingStaticMethods() {
0036     List<Vector> points = Arrays.asList(
0037       Vectors.dense(1.0, 2.0, 6.0),
0038       Vectors.dense(1.0, 3.0, 0.0),
0039       Vectors.dense(1.0, 4.0, 6.0)
0040     );
0041 
0042     Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
0043 
0044     JavaRDD<Vector> data = jsc.parallelize(points, 2);
0045     KMeansModel model = KMeans.train(data.rdd(), 1, 1, KMeans.K_MEANS_PARALLEL());
0046     assertEquals(1, model.clusterCenters().length);
0047     assertEquals(expectedCenter, model.clusterCenters()[0]);
0048 
0049     model = KMeans.train(data.rdd(), 1, 1, KMeans.RANDOM());
0050     assertEquals(expectedCenter, model.clusterCenters()[0]);
0051   }
0052 
0053   @Test
0054   public void runKMeansUsingConstructor() {
0055     List<Vector> points = Arrays.asList(
0056       Vectors.dense(1.0, 2.0, 6.0),
0057       Vectors.dense(1.0, 3.0, 0.0),
0058       Vectors.dense(1.0, 4.0, 6.0)
0059     );
0060 
0061     Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
0062 
0063     JavaRDD<Vector> data = jsc.parallelize(points, 2);
0064     KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
0065     assertEquals(1, model.clusterCenters().length);
0066     assertEquals(expectedCenter, model.clusterCenters()[0]);
0067 
0068     model = new KMeans()
0069       .setK(1)
0070       .setMaxIterations(1)
0071       .setInitializationMode(KMeans.RANDOM())
0072       .run(data.rdd());
0073     assertEquals(expectedCenter, model.clusterCenters()[0]);
0074   }
0075 
0076   @Test
0077   public void testPredictJavaRDD() {
0078     List<Vector> points = Arrays.asList(
0079       Vectors.dense(1.0, 2.0, 6.0),
0080       Vectors.dense(1.0, 3.0, 0.0),
0081       Vectors.dense(1.0, 4.0, 6.0)
0082     );
0083     JavaRDD<Vector> data = jsc.parallelize(points, 2);
0084     KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
0085     JavaRDD<Integer> predictions = model.predict(data);
0086     // Should be able to get the first prediction.
0087     predictions.first();
0088   }
0089 }