Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.mllib.fpm;
0019 
0020 import java.io.File;
0021 import java.util.Arrays;
0022 import java.util.List;
0023 
0024 import static org.junit.Assert.assertEquals;
0025 
0026 import org.junit.Test;
0027 
0028 import org.apache.spark.SharedSparkSession;
0029 import org.apache.spark.api.java.JavaRDD;
0030 import org.apache.spark.util.Utils;
0031 
0032 public class JavaFPGrowthSuite extends SharedSparkSession {
0033 
0034   @Test
0035   public void runFPGrowth() {
0036 
0037     @SuppressWarnings("unchecked")
0038     JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
0039       Arrays.asList("r z h k p".split(" ")),
0040       Arrays.asList("z y x w v u t s".split(" ")),
0041       Arrays.asList("s x o n r".split(" ")),
0042       Arrays.asList("x z y m t s q e".split(" ")),
0043       Arrays.asList("z".split(" ")),
0044       Arrays.asList("x z y r q t p".split(" "))), 2);
0045 
0046     FPGrowthModel<String> model = new FPGrowth()
0047       .setMinSupport(0.5)
0048       .setNumPartitions(2)
0049       .run(rdd);
0050 
0051     List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
0052     assertEquals(18, freqItemsets.size());
0053 
0054     for (FPGrowth.FreqItemset<String> itemset : freqItemsets) {
0055       // Test return types.
0056       List<String> items = itemset.javaItems();
0057       long freq = itemset.freq();
0058     }
0059   }
0060 
0061   @Test
0062   public void runFPGrowthSaveLoad() {
0063 
0064     @SuppressWarnings("unchecked")
0065     JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
0066       Arrays.asList("r z h k p".split(" ")),
0067       Arrays.asList("z y x w v u t s".split(" ")),
0068       Arrays.asList("s x o n r".split(" ")),
0069       Arrays.asList("x z y m t s q e".split(" ")),
0070       Arrays.asList("z".split(" ")),
0071       Arrays.asList("x z y r q t p".split(" "))), 2);
0072 
0073     FPGrowthModel<String> model = new FPGrowth()
0074       .setMinSupport(0.5)
0075       .setNumPartitions(2)
0076       .run(rdd);
0077 
0078     File tempDir = Utils.createTempDir(
0079       System.getProperty("java.io.tmpdir"), "JavaFPGrowthSuite");
0080     String outputPath = tempDir.getPath();
0081 
0082     try {
0083       model.save(spark.sparkContext(), outputPath);
0084       @SuppressWarnings("unchecked")
0085       FPGrowthModel<String> newModel =
0086         (FPGrowthModel<String>) FPGrowthModel.load(spark.sparkContext(), outputPath);
0087       List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD()
0088         .collect();
0089       assertEquals(18, freqItemsets.size());
0090 
0091       for (FPGrowth.FreqItemset<String> itemset : freqItemsets) {
0092         // Test return types.
0093         List<String> items = itemset.javaItems();
0094         long freq = itemset.freq();
0095       }
0096     } finally {
0097       Utils.deleteRecursively(tempDir);
0098     }
0099   }
0100 }