0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.feature;
0019
0020 import java.util.Arrays;
0021 import java.util.List;
0022
0023 import com.google.common.base.Strings;
0024
0025 import scala.Tuple2;
0026
0027 import org.junit.Assert;
0028 import org.junit.Test;
0029
0030 import org.apache.spark.SharedSparkSession;
0031 import org.apache.spark.api.java.JavaRDD;
0032
0033 public class JavaWord2VecSuite extends SharedSparkSession {
0034
0035 @Test
0036 @SuppressWarnings("unchecked")
0037 public void word2Vec() {
0038
0039 String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10);
0040 List<String> words = Arrays.asList(sentence.split(" "));
0041 List<List<String>> localDoc = Arrays.asList(words, words);
0042 JavaRDD<List<String>> doc = jsc.parallelize(localDoc);
0043 Word2Vec word2vec = new Word2Vec()
0044 .setVectorSize(10)
0045 .setSeed(42L);
0046 Word2VecModel model = word2vec.fit(doc);
0047 Tuple2<String, Object>[] syms = model.findSynonyms("a", 2);
0048 Assert.assertEquals(2, syms.length);
0049 Assert.assertEquals("b", syms[0]._1());
0050 Assert.assertEquals("c", syms[1]._1());
0051 }
0052 }