/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

0018 package org.apache.spark.ml.feature;
0019
0020 import java.util.Arrays;
0021
0022 import org.junit.Assert;
0023 import org.junit.Test;
0024
0025 import org.apache.spark.SharedSparkSession;
0026 import org.apache.spark.ml.linalg.Vector;
0027 import org.apache.spark.sql.Dataset;
0028 import org.apache.spark.sql.Row;
0029 import org.apache.spark.sql.RowFactory;
0030 import org.apache.spark.sql.types.*;
0031
0032 public class JavaWord2VecSuite extends SharedSparkSession {
0033
0034 @Test
0035 public void testJavaWord2Vec() {
0036 StructType schema = new StructType(new StructField[]{
0037 new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
0038 });
0039 Dataset<Row> documentDF = spark.createDataFrame(
0040 Arrays.asList(
0041 RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
0042 RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
0043 RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))),
0044 schema);
0045
0046 Word2Vec word2Vec = new Word2Vec()
0047 .setInputCol("text")
0048 .setOutputCol("result")
0049 .setVectorSize(3)
0050 .setMinCount(0);
0051 Word2VecModel model = word2Vec.fit(documentDF);
0052 Dataset<Row> result = model.transform(documentDF);
0053
0054 for (Row r : result.select("result").collectAsList()) {
0055 double[] polyFeatures = ((Vector) r.get(0)).toArray();
0056 Assert.assertEquals(3, polyFeatures.length);
0057 }
0058 }
0059 }