/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
0018 package org.apache.spark.examples.ml;
0019
0020 import org.apache.spark.sql.SparkSession;
0021
0022
0023 import java.util.Arrays;
0024 import java.util.List;
0025
0026 import org.apache.spark.ml.feature.Bucketizer;
0027 import org.apache.spark.sql.Dataset;
0028 import org.apache.spark.sql.Row;
0029 import org.apache.spark.sql.RowFactory;
0030 import org.apache.spark.sql.types.DataTypes;
0031 import org.apache.spark.sql.types.Metadata;
0032 import org.apache.spark.sql.types.StructField;
0033 import org.apache.spark.sql.types.StructType;
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043 public class JavaBucketizerExample {
0044 public static void main(String[] args) {
0045 SparkSession spark = SparkSession
0046 .builder()
0047 .appName("JavaBucketizerExample")
0048 .getOrCreate();
0049
0050
0051 double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY};
0052
0053 List<Row> data = Arrays.asList(
0054 RowFactory.create(-999.9),
0055 RowFactory.create(-0.5),
0056 RowFactory.create(-0.3),
0057 RowFactory.create(0.0),
0058 RowFactory.create(0.2),
0059 RowFactory.create(999.9)
0060 );
0061 StructType schema = new StructType(new StructField[]{
0062 new StructField("features", DataTypes.DoubleType, false, Metadata.empty())
0063 });
0064 Dataset<Row> dataFrame = spark.createDataFrame(data, schema);
0065
0066 Bucketizer bucketizer = new Bucketizer()
0067 .setInputCol("features")
0068 .setOutputCol("bucketedFeatures")
0069 .setSplits(splits);
0070
0071
0072 Dataset<Row> bucketedData = bucketizer.transform(dataFrame);
0073
0074 System.out.println("Bucketizer output with " + (bucketizer.getSplits().length-1) + " buckets");
0075 bucketedData.show();
0076
0077
0078
0079
0080 double[][] splitsArray = {
0081 {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY},
0082 {Double.NEGATIVE_INFINITY, -0.3, 0.0, 0.3, Double.POSITIVE_INFINITY}
0083 };
0084
0085 List<Row> data2 = Arrays.asList(
0086 RowFactory.create(-999.9, -999.9),
0087 RowFactory.create(-0.5, -0.2),
0088 RowFactory.create(-0.3, -0.1),
0089 RowFactory.create(0.0, 0.0),
0090 RowFactory.create(0.2, 0.4),
0091 RowFactory.create(999.9, 999.9)
0092 );
0093 StructType schema2 = new StructType(new StructField[]{
0094 new StructField("features1", DataTypes.DoubleType, false, Metadata.empty()),
0095 new StructField("features2", DataTypes.DoubleType, false, Metadata.empty())
0096 });
0097 Dataset<Row> dataFrame2 = spark.createDataFrame(data2, schema2);
0098
0099 Bucketizer bucketizer2 = new Bucketizer()
0100 .setInputCols(new String[] {"features1", "features2"})
0101 .setOutputCols(new String[] {"bucketedFeatures1", "bucketedFeatures2"})
0102 .setSplitsArray(splitsArray);
0103
0104 Dataset<Row> bucketedData2 = bucketizer2.transform(dataFrame2);
0105
0106 System.out.println("Bucketizer output with [" +
0107 (bucketizer2.getSplitsArray()[0].length-1) + ", " +
0108 (bucketizer2.getSplitsArray()[1].length-1) + "] buckets for each input column");
0109 bucketedData2.show();
0110
0111
0112 spark.stop();
0113 }
0114 }
0115
0116