0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.ml.feature;
0019
0020 import java.util.Arrays;
0021 import java.util.List;
0022
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.sql.Dataset;
0028 import org.apache.spark.sql.Row;
0029 import org.apache.spark.sql.RowFactory;
0030 import org.apache.spark.sql.types.DataTypes;
0031 import org.apache.spark.sql.types.Metadata;
0032 import org.apache.spark.sql.types.StructField;
0033 import org.apache.spark.sql.types.StructType;
0034
0035 public class JavaBucketizerSuite extends SharedSparkSession {
0036
0037 @Test
0038 public void bucketizerTest() {
0039 double[] splits = {-0.5, 0.0, 0.5};
0040
0041 StructType schema = new StructType(new StructField[]{
0042 new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
0043 });
0044 Dataset<Row> dataset = spark.createDataFrame(
0045 Arrays.asList(
0046 RowFactory.create(-0.5),
0047 RowFactory.create(-0.3),
0048 RowFactory.create(0.0),
0049 RowFactory.create(0.2)),
0050 schema);
0051
0052 Bucketizer bucketizer = new Bucketizer()
0053 .setInputCol("feature")
0054 .setOutputCol("result")
0055 .setSplits(splits);
0056
0057 List<Row> result = bucketizer.transform(dataset).select("result").collectAsList();
0058
0059 for (Row r : result) {
0060 double index = r.getDouble(0);
0061 Assert.assertTrue((index >= 0) && (index <= 1));
0062 }
0063 }
0064
0065 @Test
0066 public void bucketizerMultipleColumnsTest() {
0067 double[][] splitsArray = {
0068 {-0.5, 0.0, 0.5},
0069 {-0.5, 0.0, 0.2, 0.5}
0070 };
0071
0072 StructType schema = new StructType(new StructField[]{
0073 new StructField("feature1", DataTypes.DoubleType, false, Metadata.empty()),
0074 new StructField("feature2", DataTypes.DoubleType, false, Metadata.empty()),
0075 });
0076 Dataset<Row> dataset = spark.createDataFrame(
0077 Arrays.asList(
0078 RowFactory.create(-0.5, -0.5),
0079 RowFactory.create(-0.3, -0.3),
0080 RowFactory.create(0.0, 0.0),
0081 RowFactory.create(0.2, 0.3)),
0082 schema);
0083
0084 Bucketizer bucketizer = new Bucketizer()
0085 .setInputCols(new String[] {"feature1", "feature2"})
0086 .setOutputCols(new String[] {"result1", "result2"})
0087 .setSplitsArray(splitsArray);
0088
0089 List<Row> result = bucketizer.transform(dataset).select("result1", "result2").collectAsList();
0090
0091 for (Row r : result) {
0092 double index1 = r.getDouble(0);
0093 Assert.assertTrue((index1 >= 0) && (index1 <= 1));
0094
0095 double index2 = r.getDouble(1);
0096 Assert.assertTrue((index2 >= 0) && (index2 <= 2));
0097 }
0098 }
0099 }