Back to home page

OSCL-LXR

 
 

    


0001 """
0002 Copyright (c) 2011, Douban Inc. <http://www.douban.com/>
0003 All rights reserved.
0004 
0005 Redistribution and use in source and binary forms, with or without
0006 modification, are permitted provided that the following conditions are
0007 met:
0008 
0009     * Redistributions of source code must retain the above copyright
0010 notice, this list of conditions and the following disclaimer.
0011 
0012     * Redistributions in binary form must reproduce the above
0013 copyright notice, this list of conditions and the following disclaimer
0014 in the documentation and/or other materials provided with the
0015 distribution.
0016 
0017     * Neither the name of the Douban Inc. nor the names of its
0018 contributors may be used to endorse or promote products derived from
0019 this software without specific prior written permission.
0020 
0021 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
0022 "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
0023 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
0024 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
0025 OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
0026 SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
0027 LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
0028 DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
0029 THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
0030 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
0031 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
0032 """
0033 
0034 from pyspark.resultiterable import ResultIterable
0035 from functools import reduce
0036 
0037 
def _do_python_join(rdd, other, numPartitions, dispatch):
    """Shared driver for the join variants in this module.

    Values of ``rdd`` are wrapped as ``(1, v)`` and values of ``other`` as
    ``(2, v)``. The two tagged datasets are unioned and grouped by key into
    ``numPartitions`` partitions, and ``dispatch`` is applied to an iterator
    over each key's tagged values. Whatever ``dispatch`` returns is flattened
    by ``flatMapValues``.

    :param rdd: left-hand dataset; must provide ``mapValues``.
    :param other: right-hand dataset; must provide ``mapValues``.
    :param numPartitions: forwarded unchanged to ``groupByKey``.
    :param dispatch: callable taking an iterator of ``(tag, value)`` pairs
        and returning an iterable of results for that key.
    :return: the result of the final ``flatMapValues`` call.
    """
    def tag_with(side):
        # Build the value-wrapper for one side of the join.
        return lambda value: (side, value)

    def expand(grouped):
        # Hand the dispatcher a plain iterator over the grouped values.
        return dispatch(grouped.__iter__())

    left_tagged = rdd.mapValues(tag_with(1))
    right_tagged = other.mapValues(tag_with(2))
    combined = left_tagged.union(right_tagged)
    return combined.groupByKey(numPartitions).flatMapValues(expand)
0042 
0043 
def python_join(rdd, other, numPartitions):
    """Inner join of ``rdd`` with ``other``.

    For each key, every value tagged 1 (from ``rdd``) is paired with every
    value tagged 2 (from ``other``). A key that has values on only one side
    produces no pairs.

    :param rdd: left-hand dataset.
    :param other: right-hand dataset.
    :param numPartitions: forwarded to ``_do_python_join``.
    :return: whatever ``_do_python_join`` returns for this dispatcher.
    """
    def dispatch(seq):
        lefts = []
        rights = []
        # Sort the tagged values into per-side buffers; other tags are skipped.
        for tag, item in seq:
            if tag == 1:
                target = lefts
            elif tag == 2:
                target = rights
            else:
                continue
            target.append(item)

        def cross():
            for left in lefts:
                for right in rights:
                    yield (left, right)

        return cross()

    return _do_python_join(rdd, other, numPartitions, dispatch)
0054 
0055 
def python_right_outer_join(rdd, other, numPartitions):
    """Right outer join of ``rdd`` with ``other``.

    For each key, every value tagged 1 (from ``rdd``) is paired with every
    value tagged 2 (from ``other``). When a key has no tag-1 values, ``None``
    stands in for the left element, so each tag-2 value yields ``(None, w)``.
    A key with no tag-2 values produces no pairs.

    :param rdd: left-hand dataset.
    :param other: right-hand dataset.
    :param numPartitions: forwarded to ``_do_python_join``.
    :return: whatever ``_do_python_join`` returns for this dispatcher.
    """
    def dispatch(seq):
        lefts = []
        rights = []
        for tag, item in seq:
            if tag == 1:
                target = lefts
            elif tag == 2:
                target = rights
            else:
                continue
            target.append(item)

        # An empty left side is padded with a single None placeholder.
        lefts = lefts or [None]

        def cross():
            for left in lefts:
                for right in rights:
                    yield (left, right)

        return cross()

    return _do_python_join(rdd, other, numPartitions, dispatch)
0068 
0069 
def python_left_outer_join(rdd, other, numPartitions):
    """Left outer join of ``rdd`` with ``other``.

    For each key, every value tagged 1 (from ``rdd``) is paired with every
    value tagged 2 (from ``other``). When a key has no tag-2 values, ``None``
    stands in for the right element, so each tag-1 value yields ``(v, None)``.
    A key with no tag-1 values produces no pairs.

    :param rdd: left-hand dataset.
    :param other: right-hand dataset.
    :param numPartitions: forwarded to ``_do_python_join``.
    :return: whatever ``_do_python_join`` returns for this dispatcher.
    """
    def dispatch(seq):
        lefts = []
        rights = []
        for tag, item in seq:
            if tag == 1:
                target = lefts
            elif tag == 2:
                target = rights
            else:
                continue
            target.append(item)

        # An empty right side is padded with a single None placeholder.
        rights = rights or [None]

        def cross():
            for left in lefts:
                for right in rights:
                    yield (left, right)

        return cross()

    return _do_python_join(rdd, other, numPartitions, dispatch)
0082 
0083 
def python_full_outer_join(rdd, other, numPartitions):
    """Full outer join of ``rdd`` with ``other``.

    For each key, every value tagged 1 (from ``rdd``) is paired with every
    value tagged 2 (from ``other``). Whichever side has no values for a key
    is replaced by a single ``None``, so one-sided keys yield ``(v, None)``
    or ``(None, w)`` pairs.

    :param rdd: left-hand dataset.
    :param other: right-hand dataset.
    :param numPartitions: forwarded to ``_do_python_join``.
    :return: whatever ``_do_python_join`` returns for this dispatcher.
    """
    def dispatch(seq):
        lefts = []
        rights = []
        for tag, item in seq:
            if tag == 1:
                target = lefts
            elif tag == 2:
                target = rights
            else:
                continue
            target.append(item)

        # Pad whichever side is empty with a single None placeholder.
        lefts = lefts or [None]
        rights = rights or [None]

        def cross():
            for left in lefts:
                for right in rights:
                    yield (left, right)

        return cross()

    return _do_python_join(rdd, other, numPartitions, dispatch)
0098 
0099 
def python_cogroup(rdds, numPartitions):
    """Group the values of several datasets by key, keeping sources apart.

    Each value is tagged with the position of its source in ``rdds``. The
    tagged datasets are unioned and grouped by key into ``numPartitions``
    partitions. Each key's values are then split back out by source position.

    :param rdds: iterable of datasets, each providing ``mapValues``.
    :param numPartitions: forwarded unchanged to ``groupByKey``.
    :return: the result of ``mapValues`` on the grouped union, where each
        value is a tuple holding one ``ResultIterable`` per input dataset,
        in input order.
    """
    def tagger(position):
        # A factory, so each lambda captures its own position.
        return lambda value: (position, value)

    tagged = []
    for position, source in enumerate(rdds):
        tagged.append(source.mapValues(tagger(position)))

    merged = reduce(lambda so_far, nxt: so_far.union(nxt), tagged)
    source_count = len(tagged)

    def dispatch(seq):
        # One bucket per source; the tag is the index of the bucket.
        buckets = tuple([] for _ in range(source_count))
        for index, value in seq:
            buckets[index].append(value)
        return tuple(map(ResultIterable, buckets))

    grouped = merged.groupByKey(numPartitions)
    return grouped.mapValues(dispatch)