0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.ml;
0019
0020
0021 import java.util.Arrays;
0022 import java.util.List;
0023
0024 import org.apache.spark.ml.feature.Word2Vec;
0025 import org.apache.spark.ml.feature.Word2VecModel;
0026 import org.apache.spark.ml.linalg.Vector;
0027 import org.apache.spark.sql.Dataset;
0028 import org.apache.spark.sql.Row;
0029 import org.apache.spark.sql.RowFactory;
0030 import org.apache.spark.sql.SparkSession;
0031 import org.apache.spark.sql.types.*;
0032
0033
0034 public class JavaWord2VecExample {
0035 public static void main(String[] args) {
0036 SparkSession spark = SparkSession
0037 .builder()
0038 .appName("JavaWord2VecExample")
0039 .getOrCreate();
0040
0041
0042
0043 List<Row> data = Arrays.asList(
0044 RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
0045 RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
0046 RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))
0047 );
0048 StructType schema = new StructType(new StructField[]{
0049 new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
0050 });
0051 Dataset<Row> documentDF = spark.createDataFrame(data, schema);
0052
0053
0054 Word2Vec word2Vec = new Word2Vec()
0055 .setInputCol("text")
0056 .setOutputCol("result")
0057 .setVectorSize(3)
0058 .setMinCount(0);
0059
0060 Word2VecModel model = word2Vec.fit(documentDF);
0061 Dataset<Row> result = model.transform(documentDF);
0062
0063 for (Row row : result.collectAsList()) {
0064 List<String> text = row.getList(0);
0065 Vector vector = (Vector) row.get(1);
0066 System.out.println("Text: " + text + " => \nVector: " + vector + "\n");
0067 }
0068
0069
0070 spark.stop();
0071 }
0072 }