/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

0018 package org.apache.spark.examples.ml;
0019
0020 import org.apache.spark.sql.SparkSession;
0021
0022
0023 import java.util.Arrays;
0024
0025 import org.apache.spark.ml.feature.VectorAssembler;
0026 import org.apache.spark.ml.feature.VectorSizeHint;
0027 import org.apache.spark.ml.linalg.VectorUDT;
0028 import org.apache.spark.ml.linalg.Vectors;
0029 import org.apache.spark.sql.Dataset;
0030 import org.apache.spark.sql.Row;
0031 import org.apache.spark.sql.RowFactory;
0032 import org.apache.spark.sql.types.StructField;
0033 import org.apache.spark.sql.types.StructType;
0034 import static org.apache.spark.sql.types.DataTypes.*;
0035
0036
0037 public class JavaVectorSizeHintExample {
0038 public static void main(String[] args) {
0039 SparkSession spark = SparkSession
0040 .builder()
0041 .appName("JavaVectorSizeHintExample")
0042 .getOrCreate();
0043
0044
0045 StructType schema = createStructType(new StructField[]{
0046 createStructField("id", IntegerType, false),
0047 createStructField("hour", IntegerType, false),
0048 createStructField("mobile", DoubleType, false),
0049 createStructField("userFeatures", new VectorUDT(), false),
0050 createStructField("clicked", DoubleType, false)
0051 });
0052 Row row0 = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0);
0053 Row row1 = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0), 0.0);
0054 Dataset<Row> dataset = spark.createDataFrame(Arrays.asList(row0, row1), schema);
0055
0056 VectorSizeHint sizeHint = new VectorSizeHint()
0057 .setInputCol("userFeatures")
0058 .setHandleInvalid("skip")
0059 .setSize(3);
0060
0061 Dataset<Row> datasetWithSize = sizeHint.transform(dataset);
0062 System.out.println("Rows where 'userFeatures' is not the right size are filtered out");
0063 datasetWithSize.show(false);
0064
0065 VectorAssembler assembler = new VectorAssembler()
0066 .setInputCols(new String[]{"hour", "mobile", "userFeatures"})
0067 .setOutputCol("features");
0068
0069
0070 Dataset<Row> output = assembler.transform(datasetWithSize);
0071 System.out.println("Assembled columns 'hour', 'mobile', 'userFeatures' to vector column " +
0072 "'features'");
0073 output.select("features", "clicked").show(false);
0074
0075
0076 spark.stop();
0077 }
0078 }
0079