0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.clustering;
0019
0020 import java.io.IOException;
0021 import java.util.ArrayList;
0022 import java.util.Arrays;
0023 import java.util.List;
0024
0025 import scala.Tuple2;
0026 import scala.Tuple3;
0027
0028 import org.junit.Test;
0029 import static org.junit.Assert.*;
0030
0031 import org.apache.spark.SharedSparkSession;
0032 import org.apache.spark.api.java.JavaPairRDD;
0033 import org.apache.spark.api.java.JavaRDD;
0034 import org.apache.spark.mllib.linalg.Matrix;
0035 import org.apache.spark.mllib.linalg.Vector;
0036 import org.apache.spark.mllib.linalg.Vectors;
0037
0038 public class JavaLDASuite extends SharedSparkSession {
0039 @Override
0040 public void setUp() throws IOException {
0041 super.setUp();
0042 List<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<>();
0043 for (int i = 0; i < LDASuite.tinyCorpus().length; i++) {
0044 tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(),
0045 LDASuite.tinyCorpus()[i]._2()));
0046 }
0047 JavaRDD<Tuple2<Long, Vector>> tmpCorpus = jsc.parallelize(tinyCorpus, 2);
0048 corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
0049 }
0050
0051 @Test
0052 public void localLDAModel() {
0053 Matrix topics = LDASuite.tinyTopics();
0054 double[] topicConcentration = new double[topics.numRows()];
0055 Arrays.fill(topicConcentration, 1.0D / topics.numRows());
0056 LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(topicConcentration), 1.0, 100.0);
0057
0058
0059 assertEquals(model.k(), tinyK);
0060 assertEquals(model.vocabSize(), tinyVocabSize);
0061 assertEquals(model.topicsMatrix(), tinyTopics);
0062
0063
0064 Tuple2<int[], double[]>[] fullTopicSummary = model.describeTopics();
0065 assertEquals(fullTopicSummary.length, tinyK);
0066 for (int i = 0; i < fullTopicSummary.length; i++) {
0067 assertArrayEquals(fullTopicSummary[i]._1(), tinyTopicDescription[i]._1());
0068 assertArrayEquals(fullTopicSummary[i]._2(), tinyTopicDescription[i]._2(), 1e-5);
0069 }
0070 }
0071
0072 @Test
0073 public void distributedLDAModel() {
0074 int k = 3;
0075 double topicSmoothing = 1.2;
0076 double termSmoothing = 1.2;
0077
0078
0079 LDA lda = new LDA();
0080 lda.setK(k)
0081 .setDocConcentration(topicSmoothing)
0082 .setTopicConcentration(termSmoothing)
0083 .setMaxIterations(5)
0084 .setSeed(12345);
0085
0086 DistributedLDAModel model = (DistributedLDAModel) lda.run(corpus);
0087
0088
0089 LocalLDAModel localModel = model.toLocal();
0090 assertEquals(k, model.k());
0091 assertEquals(k, localModel.k());
0092 assertEquals(tinyVocabSize, model.vocabSize());
0093 assertEquals(tinyVocabSize, localModel.vocabSize());
0094 assertEquals(localModel.topicsMatrix(), model.topicsMatrix());
0095
0096
0097 Tuple2<int[], double[]>[] roundedTopicSummary = model.describeTopics();
0098 assertEquals(k, roundedTopicSummary.length);
0099 Tuple2<int[], double[]>[] roundedLocalTopicSummary = localModel.describeTopics();
0100 assertEquals(k, roundedLocalTopicSummary.length);
0101
0102
0103 assertTrue(model.logLikelihood() < 0.0);
0104 assertTrue(model.logPrior() < 0.0);
0105
0106
0107 JavaPairRDD<Long, Vector> topicDistributions = model.javaTopicDistributions();
0108
0109
0110 JavaPairRDD<Long, Vector> nonEmptyCorpus =
0111 corpus.filter(tuple2 -> Vectors.norm(tuple2._2(), 1.0) != 0.0);
0112 assertEquals(topicDistributions.count(), nonEmptyCorpus.count());
0113
0114
0115 Tuple3<Long, int[], double[]> topTopics = model.javaTopTopicsPerDocument(3).first();
0116 Long docId = topTopics._1();
0117 int[] topicIndices = topTopics._2();
0118 double[] topicWeights = topTopics._3();
0119 assertEquals(3, topicIndices.length);
0120 assertEquals(3, topicWeights.length);
0121
0122
0123 Tuple3<Long, int[], int[]> topicAssignment = model.javaTopicAssignments().first();
0124 Long docId2 = topicAssignment._1();
0125 int[] termIndices2 = topicAssignment._2();
0126 int[] topicIndices2 = topicAssignment._3();
0127 assertEquals(termIndices2.length, topicIndices2.length);
0128 }
0129
0130 @Test
0131 public void onlineOptimizerCompatibility() {
0132 int k = 3;
0133 double topicSmoothing = 1.2;
0134 double termSmoothing = 1.2;
0135
0136
0137 OnlineLDAOptimizer op = new OnlineLDAOptimizer()
0138 .setTau0(1024)
0139 .setKappa(0.51)
0140 .setGammaShape(1e40)
0141 .setMiniBatchFraction(0.5);
0142
0143 LDA lda = new LDA();
0144 lda.setK(k)
0145 .setDocConcentration(topicSmoothing)
0146 .setTopicConcentration(termSmoothing)
0147 .setMaxIterations(5)
0148 .setSeed(12345)
0149 .setOptimizer(op);
0150
0151 LDAModel model = lda.run(corpus);
0152
0153
0154 assertEquals(k, model.k());
0155 assertEquals(tinyVocabSize, model.vocabSize());
0156
0157
0158 Tuple2<int[], double[]>[] roundedTopicSummary = model.describeTopics();
0159 assertEquals(k, roundedTopicSummary.length);
0160 Tuple2<int[], double[]>[] roundedLocalTopicSummary = model.describeTopics();
0161 assertEquals(k, roundedLocalTopicSummary.length);
0162 }
0163
0164 @Test
0165 public void localLdaMethods() {
0166 JavaRDD<Tuple2<Long, Vector>> docs = jsc.parallelize(toyData, 2);
0167 JavaPairRDD<Long, Vector> pairedDocs = JavaPairRDD.fromJavaRDD(docs);
0168
0169
0170 assertEquals(toyModel.topicDistributions(pairedDocs).count(), pairedDocs.count());
0171
0172
0173 double logPerplexity = toyModel.logPerplexity(pairedDocs);
0174
0175
0176 List<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<>();
0177 docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0)));
0178 JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(jsc.parallelize(docsSingleWord));
0179 double logLikelihood = toyModel.logLikelihood(single);
0180 }
0181
0182 private static int tinyK = LDASuite.tinyK();
0183 private static int tinyVocabSize = LDASuite.tinyVocabSize();
0184 private static Matrix tinyTopics = LDASuite.tinyTopics();
0185 private static Tuple2<int[], double[]>[] tinyTopicDescription =
0186 LDASuite.tinyTopicDescription();
0187 private JavaPairRDD<Long, Vector> corpus;
0188 private LocalLDAModel toyModel = LDASuite.toyModel();
0189 private List<Tuple2<Long, Vector>> toyData = LDASuite.javaToyData();
0190
0191 }