0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import sys
0019
0020 from collections import namedtuple
0021
0022 from pyspark import since
0023 from pyspark.rdd import ignore_unicode_prefix
0024 from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
0025 from pyspark.mllib.util import JavaSaveable, JavaLoader, inherit_doc
0026
# Public API of this module: the two mining algorithms and their fitted models.
__all__ = ['FPGrowth', 'FPGrowthModel', 'PrefixSpan', 'PrefixSpanModel']
0028
0029
@inherit_doc
@ignore_unicode_prefix
class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader):
    """
    Model produced by the Parallel FP-Growth algorithm for mining
    frequent itemsets; wraps the corresponding JVM model object.

    >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]]
    >>> rdd = sc.parallelize(data, 2)
    >>> model = FPGrowth.train(rdd, 0.6, 2)
    >>> sorted(model.freqItemsets().collect())
    [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ...
    >>> model_path = temp_path + "/fpm"
    >>> model.save(sc, model_path)
    >>> sameModel = FPGrowthModel.load(sc, model_path)
    >>> sorted(model.freqItemsets().collect()) == sorted(sameModel.freqItemsets().collect())
    True

    .. versionadded:: 1.4.0
    """

    @since("1.4.0")
    def freqItemsets(self):
        """
        Returns an RDD of the frequent itemsets of this model, as
        :py:class:`FPGrowth.FreqItemset` tuples.
        """
        # The JVM side hands back (items, freq) pairs; wrap each in the
        # named tuple so callers get attribute access.
        java_itemsets = self.call("getFreqItemsets")
        return java_itemsets.map(lambda pair: FPGrowth.FreqItemset(pair[0], pair[1]))

    @classmethod
    @since("2.0.0")
    def load(cls, sc, path):
        """
        Load an FPGrowthModel previously saved at the given path.
        """
        java_model = cls._load_java(sc, path)
        # Re-wrap the raw JVM model in the Python-facing wrapper class.
        py_wrapper = sc._jvm.org.apache.spark.mllib.api.python.FPGrowthModelWrapper(java_model)
        return FPGrowthModel(py_wrapper)
0067
0068
class FPGrowth(object):
    """
    A Parallel FP-growth algorithm to mine frequent itemsets.

    .. versionadded:: 1.4.0
    """

    @classmethod
    @since("1.4.0")
    def train(cls, data, minSupport=0.3, numPartitions=-1):
        """
        Computes an FP-Growth model that contains frequent itemsets.

        :param data:
          The input data set, each element contains a transaction.
        :param minSupport:
          The minimal support level.
          (default: 0.3)
        :param numPartitions:
          The number of partitions used by parallel FP-growth. A value
          of -1 will use the same number as input data.
          (default: -1)
        """
        # Cast explicitly so the JVM side always receives double/int.
        java_model = callMLlibFunc(
            "trainFPGrowthModel", data, float(minSupport), int(numPartitions))
        return FPGrowthModel(java_model)

    class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])):
        """
        Represents an (items, freq) tuple.

        .. versionadded:: 1.4.0
        """
0101
0102
@inherit_doc
@ignore_unicode_prefix
class PrefixSpanModel(JavaModelWrapper):
    """
    Model fitted by PrefixSpan; wraps the corresponding JVM model object.

    >>> data = [
    ...    [["a", "b"], ["c"]],
    ...    [["a"], ["c", "b"], ["a", "b"]],
    ...    [["a", "b"], ["e"]],
    ...    [["f"]]]
    >>> rdd = sc.parallelize(data, 2)
    >>> model = PrefixSpan.train(rdd)
    >>> sorted(model.freqSequences().collect())
    [FreqSequence(sequence=[[u'a']], freq=3), FreqSequence(sequence=[[u'a'], [u'a']], freq=1), ...

    .. versionadded:: 1.6.0
    """

    @since("1.6.0")
    def freqSequences(self):
        """Gets the frequent sequences of this model as an RDD of
        :py:class:`PrefixSpan.FreqSequence` tuples."""
        # Wrap the (sequence, freq) pairs from the JVM in the named tuple.
        java_sequences = self.call("getFreqSequences")
        return java_sequences.map(lambda pair: PrefixSpan.FreqSequence(pair[0], pair[1]))
0126
0127
class PrefixSpan(object):
    """
    A parallel PrefixSpan algorithm to mine frequent sequential patterns.
    The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan:
    Mining Sequential Patterns Efficiently by Prefix-Projected Pattern Growth
    ([[https://doi.org/10.1109/ICDE.2001.914830]]).

    .. versionadded:: 1.6.0
    """

    @classmethod
    @since("1.6.0")
    def train(cls, data, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000):
        """
        Finds the complete set of frequent sequential patterns in the
        input sequences of itemsets.

        :param data:
          The input data set, each element contains a sequence of
          itemsets.
        :param minSupport:
          The minimal support level of the sequential pattern, any
          pattern that appears more than (minSupport *
          size-of-the-dataset) times will be output.
          (default: 0.1)
        :param maxPatternLength:
          The maximal length of the sequential pattern; patterns longer
          than this are not included in the results.
          (default: 10)
        :param maxLocalProjDBSize:
          The maximum number of items (including delimiters used in the
          internal storage format) allowed in a projected database before
          local processing. If a projected database exceeds this size,
          another iteration of distributed prefix growth is run.
          (default: 32000000)
        """
        # Cast explicitly so the JVM side always receives double/int,
        # consistent with FPGrowth.train above.
        model = callMLlibFunc("trainPrefixSpanModel", data, float(minSupport),
                              int(maxPatternLength), int(maxLocalProjDBSize))
        return PrefixSpanModel(model)

    class FreqSequence(namedtuple("FreqSequence", ["sequence", "freq"])):
        """
        Represents a (sequence, freq) tuple.

        .. versionadded:: 1.6.0
        """
0174
0175
def _test():
    """Run this module's doctests against a local SparkSession.

    Exits the process with a nonzero status if any doctest fails.
    """
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.mllib.fpm
    globs = pyspark.mllib.fpm.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("mllib.fpm tests")\
        .getOrCreate()
    globs['sc'] = spark.sparkContext
    import tempfile

    # The FPGrowthModel save/load doctest needs a scratch directory.
    temp_path = tempfile.mkdtemp()
    globs['temp_path'] = temp_path
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Stop the session even if testmod raised (previously it leaked
        # on exception), and always clean up the scratch directory.
        spark.stop()
        from shutil import rmtree
        try:
            rmtree(temp_path)
        except OSError:
            pass
    if failure_count:
        sys.exit(-1)
0201
0202
# Run the doctest suite when executed as a script.
if __name__ == "__main__":
    _test()