/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.feature;

import java.util.Arrays;
import java.util.List;

import org.junit.Assert;
import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.linalg.Vector;

public class JavaTfIdfSuite extends SharedSparkSession {

  @Test
  public void tfIdf() {
    // Build term-frequency vectors with the hashing trick via the Java API.
    HashingTF tf = new HashingTF();
    @SuppressWarnings("unchecked")
    JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList(
      Arrays.asList("this is a sentence".split(" ")),
      Arrays.asList("this is another sentence".split(" ")),
      Arrays.asList("this is still a sentence".split(" "))), 2);
    JavaRDD<Vector> termFreqs = tf.transform(documents);
    termFreqs.collect();
    // Fit IDF on the term-frequency vectors and rescale them to TF-IDF.
    IDF idf = new IDF();
    JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
    List<Vector> localTfIdfs = tfIdfs.collect();
    // "this" appears in every document, so its IDF (and hence its TF-IDF weight) is 0.
    int indexOfThis = tf.indexOf("this");
    for (Vector v : localTfIdfs) {
      Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
    }
  }

  @Test
  public void tfIdfMinimumDocumentFrequency() {
    // Same pipeline, but terms appearing in fewer than 2 documents get an IDF of 0.
    HashingTF tf = new HashingTF();
    @SuppressWarnings("unchecked")
    JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList(
      Arrays.asList("this is a sentence".split(" ")),
      Arrays.asList("this is another sentence".split(" ")),
      Arrays.asList("this is still a sentence".split(" "))), 2);
    JavaRDD<Vector> termFreqs = tf.transform(documents);
    termFreqs.collect();
    IDF idf = new IDF(2);
    JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
    List<Vector> localTfIdfs = tfIdfs.collect();
    // "this" still appears in all three documents (>= minDocFreq), so its TF-IDF stays 0.
    int indexOfThis = tf.indexOf("this");
    for (Vector v : localTfIdfs) {
      Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
    }
  }

}