#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

from collections import namedtuple

from pyspark import since
from pyspark.rdd import ignore_unicode_prefix
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
from pyspark.mllib.util import JavaSaveable, JavaLoader, inherit_doc

__all__ = ['FPGrowth', 'FPGrowthModel', 'PrefixSpan', 'PrefixSpanModel']


@inherit_doc
@ignore_unicode_prefix
class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader):
    """
    An FP-Growth model for mining frequent itemsets
    using the Parallel FP-Growth algorithm.

    >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]]
    >>> rdd = sc.parallelize(data, 2)
    >>> model = FPGrowth.train(rdd, 0.6, 2)
    >>> sorted(model.freqItemsets().collect())
    [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ...
    >>> model_path = temp_path + "/fpm"
    >>> model.save(sc, model_path)
    >>> sameModel = FPGrowthModel.load(sc, model_path)
    >>> sorted(model.freqItemsets().collect()) == sorted(sameModel.freqItemsets().collect())
    True

    .. versionadded:: 1.4.0
    """

    @since("1.4.0")
    def freqItemsets(self):
        """
        Returns the frequent itemsets of this model.
        """
        return self.call("getFreqItemsets").map(lambda x: FPGrowth.FreqItemset(x[0], x[1]))

    @classmethod
    @since("2.0.0")
    def load(cls, sc, path):
        """
        Load a model from the given path.
        """
        model = cls._load_java(sc, path)
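        # Re-wrap the loaded Java model in the JVM-side FPGrowthModelWrapper so
        # that methods such as freqItemsets() can be forwarded through the
        # Py4J gateway by JavaModelWrapper.call.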
        wrapper = sc._jvm.org.apache.spark.mllib.api.python.FPGrowthModelWrapper(model)
        return FPGrowthModel(wrapper)


class FPGrowth(object):
    """
    A Parallel FP-growth algorithm to mine frequent itemsets.

    .. versionadded:: 1.4.0
    """

    @classmethod
    @since("1.4.0")
    def train(cls, data, minSupport=0.3, numPartitions=-1):
        """
        Computes an FP-Growth model that contains frequent itemsets.

        :param data:
          The input data set; each element contains a transaction.
        :param minSupport:
          The minimal support level.
          (default: 0.3)
        :param numPartitions:
          The number of partitions used by parallel FP-growth. A value
          of -1 uses the same number of partitions as the input data.
          (default: -1)
        """
        model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions))
        return FPGrowthModel(model)

    class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])):
        """
        Represents an (items, freq) tuple.

        .. versionadded:: 1.4.0
        """


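# The sketch below is illustrative and not part of the original module: it
# shows how FPGrowth.train and FPGrowthModel.freqItemsets are typically
# combined. A live SparkContext `sc` is assumed, as in the doctests above;
# the helper name and the toy transactions are invented for this example.
def _example_fpgrowth_usage(sc):
    # Each RDD element is one transaction (a list of items).
    transactions = sc.parallelize(
        [["a", "b", "c"], ["a", "b"], ["a", "c"], ["b", "c"]], 2)
    # minSupport=0.5 keeps only itemsets present in at least half of the
    # four transactions, i.e. with freq >= 2.
    model = FPGrowth.train(transactions, minSupport=0.5, numPartitions=2)
    # Each result is an FPGrowth.FreqItemset namedtuple: (items, freq).
    return sorted(model.freqItemsets().collect())

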
@inherit_doc
@ignore_unicode_prefix
class PrefixSpanModel(JavaModelWrapper):
    """
    Model fitted by PrefixSpan

    >>> data = [
    ...    [["a", "b"], ["c"]],
    ...    [["a"], ["c", "b"], ["a", "b"]],
    ...    [["a", "b"], ["e"]],
    ...    [["f"]]]
    >>> rdd = sc.parallelize(data, 2)
    >>> model = PrefixSpan.train(rdd)
    >>> sorted(model.freqSequences().collect())
    [FreqSequence(sequence=[[u'a']], freq=3), FreqSequence(sequence=[[u'a'], [u'a']], freq=1), ...

    .. versionadded:: 1.6.0
    """

    @since("1.6.0")
    def freqSequences(self):
        """Gets frequent sequences"""
        return self.call("getFreqSequences").map(lambda x: PrefixSpan.FreqSequence(x[0], x[1]))


class PrefixSpan(object):
    """
    A parallel PrefixSpan algorithm to mine frequent sequential patterns.
    The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan:
    Mining Sequential Patterns Efficiently by Prefix-Projected Pattern Growth
    (https://doi.org/10.1109/ICDE.2001.914830).

    .. versionadded:: 1.6.0
    """

    @classmethod
    @since("1.6.0")
    def train(cls, data, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000):
        """
        Finds the complete set of frequent sequential patterns in the
        input sequences of itemsets.

        :param data:
          The input data set; each element contains a sequence of
          itemsets.
        :param minSupport:
          The minimal support level of the sequential pattern; any
          pattern that appears more than (minSupport *
          size-of-the-dataset) times will be output.
          (default: 0.1)
        :param maxPatternLength:
          The maximal length of the sequential pattern; any pattern
          whose length does not exceed maxPatternLength will be output.
          (default: 10)
        :param maxLocalProjDBSize:
          The maximum number of items (including delimiters used in the
          internal storage format) allowed in a projected database before
          local processing. If a projected database exceeds this size,
          another iteration of distributed prefix growth is run.
          (default: 32000000)
        """
        model = callMLlibFunc("trainPrefixSpanModel",
                              data, minSupport, maxPatternLength, maxLocalProjDBSize)
        return PrefixSpanModel(model)

    class FreqSequence(namedtuple("FreqSequence", ["sequence", "freq"])):
        """
        Represents a (sequence, freq) tuple.

        .. versionadded:: 1.6.0
        """


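# Another illustrative sketch, not part of the original module: it shows the
# input shape PrefixSpan.train expects (an RDD of sequences, each sequence a
# list of itemsets) and how the mined patterns come back. A live SparkContext
# `sc` is assumed; the helper name and the toy data are invented here.
def _example_prefixspan_usage(sc):
    # Each element is one sequence; each inner list is one itemset in that
    # sequence (the order of itemsets matters, the order within one does not).
    sequences = sc.parallelize(
        [[["a"], ["a", "b"]], [["a"], ["b"]], [["b"], ["a"]]], 2)
    # minSupport=0.5 keeps patterns occurring in at least half of the three
    # sequences; maxPatternLength bounds how long a mined pattern may be.
    model = PrefixSpan.train(sequences, minSupport=0.5, maxPatternLength=5)
    # Each result is a PrefixSpan.FreqSequence namedtuple: (sequence, freq).
    return sorted(model.freqSequences().collect())

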
def _test():
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.mllib.fpm
    globs = pyspark.mllib.fpm.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("mllib.fpm tests")\
        .getOrCreate()
    globs['sc'] = spark.sparkContext
    import tempfile

    temp_path = tempfile.mkdtemp()
    globs['temp_path'] = temp_path
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
        spark.stop()
    finally:
        from shutil import rmtree
        try:
            rmtree(temp_path)
        except OSError:
            pass
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()