0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.ml.feature;
0019
0020 import java.util.Arrays;
0021 import java.util.List;
0022 import java.util.Map;
0023
0024 import org.junit.Assert;
0025 import org.junit.Test;
0026
0027 import org.apache.spark.SharedSparkSession;
0028 import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
0029 import org.apache.spark.ml.linalg.Vectors;
0030 import org.apache.spark.sql.Dataset;
0031 import org.apache.spark.sql.Row;
0032
0033
0034 public class JavaVectorIndexerSuite extends SharedSparkSession {
0035
0036 @Test
0037 public void vectorIndexerAPI() {
0038
0039 List<FeatureData> points = Arrays.asList(
0040 new FeatureData(Vectors.dense(0.0, -2.0)),
0041 new FeatureData(Vectors.dense(1.0, 3.0)),
0042 new FeatureData(Vectors.dense(1.0, 4.0))
0043 );
0044 Dataset<Row> data = spark.createDataFrame(jsc.parallelize(points, 2), FeatureData.class);
0045 VectorIndexer indexer = new VectorIndexer()
0046 .setInputCol("features")
0047 .setOutputCol("indexed")
0048 .setMaxCategories(2);
0049 VectorIndexerModel model = indexer.fit(data);
0050 Assert.assertEquals(2, model.numFeatures());
0051 Map<Integer, Map<Double, Integer>> categoryMaps = model.javaCategoryMaps();
0052 Assert.assertEquals(1, categoryMaps.size());
0053 Dataset<Row> indexedData = model.transform(data);
0054 }
0055 }