#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import platform
import shutil
import warnings
import gc
import itertools
import operator
import random
import sys

import pyspark.heapq3 as heapq
from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
    CompressedSerializer, AutoBatchedSerializer
from pyspark.util import fail_on_stopiteration


try:
    import psutil

    process = None

    def get_used_memory():
        """ Return the used memory in MiB """
        global process
        if process is None or process._pid != os.getpid():
            process = psutil.Process(os.getpid())
        if hasattr(process, "memory_info"):
            info = process.memory_info()
        else:
            info = process.get_memory_info()
        return info.rss >> 20

except ImportError:

    def get_used_memory():
        """ Return the used memory in MiB """
        if platform.system() == 'Linux':
            for line in open('/proc/self/status'):
                if line.startswith('VmRSS:'):
                    return int(line.split()[1]) >> 10

        else:
            warnings.warn("Please install psutil to have better "
                          "support for spilling")
            if platform.system() == "Darwin":
                import resource
                rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
                return rss >> 20
            # TODO: support windows

        return 0


def _get_local_dirs(sub):
    """ Get all the local directories for spilling """
    path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp")
    dirs = path.split(",")
    if len(dirs) > 1:
        # different order in different processes and instances
        rnd = random.Random(os.getpid() + id(dirs))
        rnd.shuffle(dirs)
    return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]


# global stats
MemoryBytesSpilled = 0
DiskBytesSpilled = 0


class Aggregator(object):

    """
    Aggregator has three functions to merge values into a combiner.

    createCombiner:  (value) -> combiner
    mergeValue:      (combiner, value) -> combiner
    mergeCombiners:  (combiner, combiner) -> combiner
    """

    def __init__(self, createCombiner, mergeValue, mergeCombiners):
        self.createCombiner = fail_on_stopiteration(createCombiner)
        self.mergeValue = fail_on_stopiteration(mergeValue)
        self.mergeCombiners = fail_on_stopiteration(mergeCombiners)


class SimpleAggregator(Aggregator):

    """
    SimpleAggregator is useful for cases where the combiner has the
    same type as the values
    """

    def __init__(self, combiner):
        Aggregator.__init__(self, lambda x: x, combiner, combiner)


class Merger(object):

    """
    Merge shuffled data together using an aggregator
    """

    def __init__(self, aggregator):
        self.agg = aggregator

    def mergeValues(self, iterator):
        """ Combine the items by creator and combiner """
        raise NotImplementedError

    def mergeCombiners(self, iterator):
        """ Merge the combined items by mergeCombiners """
        raise NotImplementedError

    def items(self):
        """ Return the merged items as an iterator """
        raise NotImplementedError


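# Illustrative sketch (not part of the original module): a minimal in-memory
# Merger, under the hypothetical name `InMemoryMerger`, that implements the
# interface above without any spilling. It only shows how mergeValues and
# mergeCombiners are expected to use the aggregator's three functions.
class InMemoryMerger(Merger):

    def __init__(self, aggregator):
        Merger.__init__(self, aggregator)
        self.data = {}

    def mergeValues(self, iterator):
        # createCombiner turns the first value for a key into a combiner,
        # mergeValue folds further values into the existing combiner
        creator, comb = self.agg.createCombiner, self.agg.mergeValue
        d = self.data
        for k, v in iterator:
            d[k] = comb(d[k], v) if k in d else creator(v)

    def mergeCombiners(self, iterator):
        # mergeCombiners merges two combiners built for the same key
        comb = self.agg.mergeCombiners
        d = self.data
        for k, v in iterator:
            d[k] = comb(d[k], v) if k in d else v

    def items(self):
        return iter(self.data.items())

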
def _compressed_serializer(serializer=None):
    # always use PickleSerializer to simplify implementation
    # (the serializer argument is accepted for interface compatibility but ignored)
    ser = PickleSerializer()
    return AutoBatchedSerializer(CompressedSerializer(ser))


class ExternalMerger(Merger):

    """
    ExternalMerger dumps the aggregated data to disk when memory usage
    goes above the limit, then merges the spilled files back together.

    This class works as follows:

    - It repeatedly combines the items and saves them in a single dict
      in memory.

    - When the used memory goes above the memory limit, it splits the
      combined data into partitions by hash code and dumps them to
      disk, one file per partition.

    - Then it goes through the rest of the iterator, combining items
      into different dicts by hash. Whenever the used memory goes over
      the memory limit, it dumps all the dicts to disk, one file per
      dict. This repeats until all the items have been combined.

    - Before returning any items, it loads each partition and combines
      its items separately, yielding them before loading the next
      partition.

    - While loading a partition, if memory goes over the limit, it
      partitions the loaded data, dumps it to disk, and loads it back
      partition by partition again.

    `data` and `pdata` are used to hold the merged items in memory.
    At first, all the items are merged into `data`. Once the used
    memory goes over the limit, the items in `data` are dumped to
    disk and `data` is cleared; the remaining items are merged into
    `pdata` and then dumped to disk. Before returning, all the items
    in `pdata` are dumped to disk.

    Finally, if any items were spilled to disk, each partition is
    merged into `data`, yielded, and then cleared.

    >>> agg = SimpleAggregator(lambda x, y: x + y)
    >>> merger = ExternalMerger(agg, 10)
    >>> N = 10000
    >>> merger.mergeValues(zip(range(N), range(N)))
    >>> assert merger.spills > 0
    >>> sum(v for k,v in merger.items())
    49995000

    >>> merger = ExternalMerger(agg, 10)
    >>> merger.mergeCombiners(zip(range(N), range(N)))
    >>> assert merger.spills > 0
    >>> sum(v for k,v in merger.items())
    49995000
    """

    # the max total number of partitions created recursively
    MAX_TOTAL_PARTITIONS = 4096

    def __init__(self, aggregator, memory_limit=512, serializer=None,
                 localdirs=None, scale=1, partitions=59, batch=1000):
        Merger.__init__(self, aggregator)
        self.memory_limit = memory_limit
        self.serializer = _compressed_serializer(serializer)
        self.localdirs = localdirs or _get_local_dirs(str(id(self)))
        # number of partitions used when spilling data to disk
        self.partitions = partitions
        # check the memory usage after this many items have been merged
        self.batch = batch
        # scale is used to scale down the hash of keys for recursive hash maps
        self.scale = scale
        # un-partitioned merged data
        self.data = {}
        # partitioned merged data, list of dicts
        self.pdata = []
        # number of chunks dumped to disk
        self.spills = 0
        # randomize the hash of keys; id(o) is the address of o (aligned by 8)
        self._seed = id(self) + 7

    def _get_spill_dir(self, n):
        """ Choose one directory for spill by number n """
        return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))

    def _next_limit(self):
        """
        Return the next memory limit. If the memory is not released
        after spilling, it will dump the data only when the used memory
        starts to increase.
        """
        return max(self.memory_limit, get_used_memory() * 1.05)

    def mergeValues(self, iterator):
        """ Combine the items by creator and combiner """
        # speed up attribute lookup
        creator, comb = self.agg.createCombiner, self.agg.mergeValue
        c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch
        limit = self.memory_limit

        for k, v in iterator:
            d = pdata[hfun(k)] if pdata else data
            d[k] = comb(d[k], v) if k in d else creator(v)

            c += 1
            if c >= batch:
                if get_used_memory() >= limit:
                    self._spill()
                    limit = self._next_limit()
                    batch /= 2
                    c = 0
                else:
                    batch *= 1.5

        if get_used_memory() >= limit:
            self._spill()

    def _partition(self, key):
        """ Return the partition for key """
        return hash((key, self._seed)) % self.partitions

    def _object_size(self, obj):
        """ Estimate the memory used by obj, assuming that all objects
        consume a similar number of bytes
        """
        return 1

    def mergeCombiners(self, iterator, limit=None):
        """ Merge (K, V) pairs by mergeCombiners """
        if limit is None:
            limit = self.memory_limit
        # speed up attribute lookup
        comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size
        c, data, pdata, batch = 0, self.data, self.pdata, self.batch
        for k, v in iterator:
            d = pdata[hfun(k)] if pdata else data
            d[k] = comb(d[k], v) if k in d else v
            if not limit:
                continue

            c += objsize(v)
            if c > batch:
                if get_used_memory() > limit:
                    self._spill()
                    limit = self._next_limit()
                    batch /= 2
                    c = 0
                else:
                    batch *= 1.5

        if limit and get_used_memory() >= limit:
            self._spill()

    def _spill(self):
        """
        Dump the already partitioned data to disk.

        The data is dumped in batches for better performance.
        """
        global MemoryBytesSpilled, DiskBytesSpilled
        path = self._get_spill_dir(self.spills)
        if not os.path.exists(path):
            os.makedirs(path)

        used_memory = get_used_memory()
        if not self.pdata:
            # The data has not been partitioned yet: iterate over the
            # dataset once and write the items into different files,
            # using no additional memory. This only happens the first
            # time memory goes above the limit.

            # open all the files for writing
            streams = [open(os.path.join(path, str(i)), 'wb')
                       for i in range(self.partitions)]

            for k, v in self.data.items():
                h = self._partition(k)
                # put one item per batch to stay compatible with load_stream;
                # dumping them in larger batches would increase memory usage
                self.serializer.dump_stream([(k, v)], streams[h])

            for s in streams:
                DiskBytesSpilled += s.tell()
                s.close()

            self.data.clear()
            self.pdata.extend([{} for i in range(self.partitions)])

        else:
            for i in range(self.partitions):
                p = os.path.join(path, str(i))
                with open(p, "wb") as f:
                    # dump items in batch
                    self.serializer.dump_stream(iter(self.pdata[i].items()), f)
                self.pdata[i].clear()
                DiskBytesSpilled += os.path.getsize(p)

        self.spills += 1
        gc.collect()  # release as much memory as possible
        MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20

    def items(self):
        """ Return all merged items as an iterator """
        if not self.pdata and not self.spills:
            return iter(self.data.items())
        return self._external_items()

    def _external_items(self):
        """ Return all partitioned items as an iterator """
        assert not self.data
        if any(self.pdata):
            self._spill()
        # disable partitioning and spilling when merging combiners from disk
        self.pdata = []

        try:
            for i in range(self.partitions):
                for v in self._merged_items(i):
                    yield v
                self.data.clear()

                # remove the merged partition
                for j in range(self.spills):
                    path = self._get_spill_dir(j)
                    os.remove(os.path.join(path, str(i)))
        finally:
            self._cleanup()

    def _merged_items(self, index):
        self.data = {}
        limit = self._next_limit()
        for j in range(self.spills):
            path = self._get_spill_dir(j)
            p = os.path.join(path, str(index))
            # do not check memory during merging
            with open(p, "rb") as f:
                self.mergeCombiners(self.serializer.load_stream(f), 0)

            # limit the total number of partitions
            if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
                    and j < self.spills - 1
                    and get_used_memory() > limit):
                self.data.clear()  # will read from disk again
                gc.collect()  # release as much memory as possible
                return self._recursive_merged_items(index)

        return self.data.items()

    def _recursive_merged_items(self, index):
        """
        Merge the partitioned items and return them as an iterator.

        If one partition cannot fit in memory, it will be partitioned
        and merged recursively.
        """
        subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs]
        m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs,
                           self.scale * self.partitions, self.partitions, self.batch)
        m.pdata = [{} for _ in range(self.partitions)]
        limit = self._next_limit()

        for j in range(self.spills):
            path = self._get_spill_dir(j)
            p = os.path.join(path, str(index))
            with open(p, 'rb') as f:
                m.mergeCombiners(self.serializer.load_stream(f), 0)

            if get_used_memory() > limit:
                m._spill()
                limit = self._next_limit()

        return m._external_items()

    def _cleanup(self):
        """ Clean up all the files on disk """
        for d in self.localdirs:
            shutil.rmtree(d, True)


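# Illustrative sketch (not part of the original module): driving ExternalMerger
# with an Aggregator whose combiner type (an int count) differs from the value
# type. The helper name `_example_count_by_key` is hypothetical and exists only
# for this example, e.g.
# _example_count_by_key([("a", 1), ("a", 2), ("b", 3)]) -> {'a': 2, 'b': 1}
def _example_count_by_key(pairs, memory_limit=512):
    # count how many values were seen per key; the combiner is an int
    agg = Aggregator(createCombiner=lambda v: 1,
                     mergeValue=lambda c, v: c + 1,
                     mergeCombiners=lambda c1, c2: c1 + c2)
    merger = ExternalMerger(agg, memory_limit)
    merger.mergeValues(pairs)
    # items() transparently merges back any partitions that were spilled to disk
    return dict(merger.items())

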
class ExternalSorter(object):
    """
    ExternalSorter divides the elements into chunks, sorts them in
    memory, dumps them to disk, and finally merges them back together.

    The spilling will only happen when the used memory goes above
    the limit.


    >>> sorter = ExternalSorter(1)  # 1M
    >>> import random
    >>> l = list(range(1024))
    >>> random.shuffle(l)
    >>> sorted(l) == list(sorter.sorted(l))
    True
    >>> sorted(l) == list(sorter.sorted(l, key=lambda x: -x, reverse=True))
    True
    """
    def __init__(self, memory_limit, serializer=None):
        self.memory_limit = memory_limit
        self.local_dirs = _get_local_dirs("sort")
        self.serializer = _compressed_serializer(serializer)

    def _get_path(self, n):
        """ Choose one directory for spill by number n """
        d = self.local_dirs[n % len(self.local_dirs)]
        if not os.path.exists(d):
            os.makedirs(d)
        return os.path.join(d, str(n))

    def _next_limit(self):
        """
        Return the next memory limit. If the memory is not released
        after spilling, it will dump the data only when the used memory
        starts to increase.
        """
        return max(self.memory_limit, get_used_memory() * 1.05)

    def sorted(self, iterator, key=None, reverse=False):
        """
        Sort the elements in iterator, doing an external sort when the
        memory goes above the limit.
        """
        global MemoryBytesSpilled, DiskBytesSpilled
        batch, limit = 100, self._next_limit()
        chunks, current_chunk = [], []
        iterator = iter(iterator)
        while True:
            # pick elements in batch
            chunk = list(itertools.islice(iterator, batch))
            current_chunk.extend(chunk)
            if len(chunk) < batch:
                break

            used_memory = get_used_memory()
            if used_memory > limit:
                # sorting in place saves memory
                current_chunk.sort(key=key, reverse=reverse)
                path = self._get_path(len(chunks))
                with open(path, 'wb') as f:
                    self.serializer.dump_stream(current_chunk, f)

                def load(f):
                    for v in self.serializer.load_stream(f):
                        yield v
                    # close the file explicitly once we have consumed all the
                    # items, to avoid ResourceWarning in Python 3
                    f.close()
                chunks.append(load(open(path, 'rb')))
                current_chunk = []
                MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20
                DiskBytesSpilled += os.path.getsize(path)
                os.unlink(path)  # data will be deleted after close

            elif not chunks:
                batch = min(int(batch * 1.5), 10000)

        current_chunk.sort(key=key, reverse=reverse)
        if not chunks:
            return current_chunk

        if current_chunk:
            chunks.append(iter(current_chunk))

        return heapq.merge(chunks, key=key, reverse=reverse)


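# Illustrative sketch (not part of the original module): sorting (key, value)
# pairs by key with ExternalSorter, similar to how the sort-based group-by path
# further below uses it. The memory limit is in MiB and the helper name is
# hypothetical, e.g.
# list(_example_external_sort_by_key([("b", 2), ("a", 1)])) -> [('a', 1), ('b', 2)]
def _example_external_sort_by_key(pairs, memory_limit=10):
    sorter = ExternalSorter(memory_limit)
    # sorted() returns the pairs ordered by key, spilling chunks to disk if needed
    return sorter.sorted(pairs, key=operator.itemgetter(0))

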
class ExternalList(object):
    """
    ExternalList can hold many items which cannot all be held in
    memory at the same time.

    >>> l = ExternalList(list(range(100)))
    >>> len(l)
    100
    >>> l.append(10)
    >>> len(l)
    101
    >>> for i in range(20240):
    ...     l.append(i)
    >>> len(l)
    20341
    >>> import pickle
    >>> l2 = pickle.loads(pickle.dumps(l))
    >>> len(l2)
    20341
    >>> list(l2)[100]
    10
    """
    LIMIT = 10240

    def __init__(self, values):
        self.values = values
        self.count = len(values)
        self._file = None
        self._ser = None

    def __getstate__(self):
        if self._file is not None:
            self._file.flush()
            with os.fdopen(os.dup(self._file.fileno()), "rb") as f:
                f.seek(0)
                serialized = f.read()
        else:
            serialized = b''
        return self.values, self.count, serialized

    def __setstate__(self, item):
        self.values, self.count, serialized = item
        if serialized:
            self._open_file()
            self._file.write(serialized)
        else:
            self._file = None
            self._ser = None

    def __iter__(self):
        if self._file is not None:
            self._file.flush()
            # read all items from disk first
            with os.fdopen(os.dup(self._file.fileno()), 'rb') as f:
                f.seek(0)
                for v in self._ser.load_stream(f):
                    yield v

        for v in self.values:
            yield v

    def __len__(self):
        return self.count

    def append(self, value):
        self.values.append(value)
        self.count += 1
        # dump the values to disk once the in-memory list gets big
        if len(self.values) >= self.LIMIT:
            self._spill()

    def _open_file(self):
        dirs = _get_local_dirs("objects")
        d = dirs[id(self) % len(dirs)]
        if not os.path.exists(d):
            os.makedirs(d)
        p = os.path.join(d, str(id(self)))
        self._file = open(p, "w+b", 65536)
        self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
        os.unlink(p)

    def __del__(self):
        if self._file:
            self._file.close()
            self._file = None

    def _spill(self):
        """ Dump the values to disk """
        global MemoryBytesSpilled, DiskBytesSpilled
        if self._file is None:
            self._open_file()

        used_memory = get_used_memory()
        pos = self._file.tell()
        self._ser.dump_stream(self.values, self._file)
        self.values = []
        gc.collect()
        DiskBytesSpilled += self._file.tell() - pos
        MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20


class ExternalListOfList(ExternalList):
    """
    An external list of lists.

    >>> l = ExternalListOfList([[i, i] for i in range(100)])
    >>> len(l)
    200
    >>> l.append(range(10))
    >>> len(l)
    210
    >>> len(list(l))
    210
    """

    def __init__(self, values):
        ExternalList.__init__(self, values)
        self.count = sum(len(i) for i in values)

    def append(self, value):
        ExternalList.append(self, value)
        # already counted as 1 in ExternalList.append
        self.count += len(value) - 1

    def __iter__(self):
        for values in ExternalList.__iter__(self):
            for v in values:
                yield v


class GroupByKey(object):
    """
    Group a sorted iterator as [(k1, it1), (k2, it2), ...]

    >>> k = [i // 3 for i in range(6)]
    >>> v = [[i] for i in range(6)]
    >>> g = GroupByKey(zip(k, v))
    >>> [(k, list(it)) for k, it in g]
    [(0, [0, 1, 2]), (1, [3, 4, 5])]
    """

    def __init__(self, iterator):
        self.iterator = iterator

    def __iter__(self):
        key, values = None, None
        for k, v in self.iterator:
            if values is not None and k == key:
                values.append(v)
            else:
                if values is not None:
                    yield (key, values)
                key = k
                values = ExternalListOfList([v])
        if values is not None:
            yield (key, values)


class ExternalGroupBy(ExternalMerger):

    """
    Group the items by key. If any partition of them cannot be held
    in memory, it falls back to a sort-based group-by.

    This class works as follows:

    - It repeatedly groups the items by key and saves them in one dict
      in memory.

    - When the used memory goes above the memory limit, it splits the
      combined data into partitions by hash code and dumps them to
      disk, one file per partition. If the number of keys in memory
      is smaller than 1000, it sorts them by key before dumping to
      disk.

    - Then it goes through the rest of the iterator, grouping items
      by key into different dicts by hash. Whenever the used memory
      goes over the memory limit, it dumps all the dicts to disk, one
      file per dict. This repeats until all the items have been
      combined. It also tries to sort the items by key in each
      partition before dumping to disk.

    - It yields the grouped items partition by partition. If the data
      in one partition can be held in memory, it loads and combines
      them in memory and yields the result.

    - If the dataset in one partition cannot be held in memory, it is
      sorted first. If all the spilled files are already sorted, they
      are merged with heapq.merge(); otherwise an external sort is
      done over all the files.

    - After sorting, the `GroupByKey` class gathers all the consecutive
      items with the same key into a group and yields the values as
      an iterator.
    """
    SORT_KEY_LIMIT = 1000

    def flattened_serializer(self):
        assert isinstance(self.serializer, BatchedSerializer)
        ser = self.serializer
        return FlattenedValuesSerializer(ser, 20)

    def _object_size(self, obj):
        return len(obj)

    def _spill(self):
        """
        Dump the already partitioned data to disk.
        """
        global MemoryBytesSpilled, DiskBytesSpilled
        path = self._get_spill_dir(self.spills)
        if not os.path.exists(path):
            os.makedirs(path)

        used_memory = get_used_memory()
        if not self.pdata:
            # The data has not been partitioned yet: iterate over the
            # data once and write the items into different files,
            # using no additional memory. This only happens the first
            # time memory goes above the limit.

            # open all the files for writing
            streams = [open(os.path.join(path, str(i)), 'wb')
                       for i in range(self.partitions)]

            # If the number of keys is small, the overhead of sorting is small;
            # sort them before dumping to disk
            self._sorted = len(self.data) < self.SORT_KEY_LIMIT
            if self._sorted:
                self.serializer = self.flattened_serializer()
                for k in sorted(self.data.keys()):
                    h = self._partition(k)
                    self.serializer.dump_stream([(k, self.data[k])], streams[h])
            else:
                for k, v in self.data.items():
                    h = self._partition(k)
                    self.serializer.dump_stream([(k, v)], streams[h])

            for s in streams:
                DiskBytesSpilled += s.tell()
                s.close()

            self.data.clear()
            # self.pdata is cached in `mergeValues` and `mergeCombiners`
            self.pdata.extend([{} for i in range(self.partitions)])

        else:
            for i in range(self.partitions):
                p = os.path.join(path, str(i))
                with open(p, "wb") as f:
                    # dump items in batch
                    if self._sorted:
                        # sort by key only (stable)
                        sorted_items = sorted(self.pdata[i].items(), key=operator.itemgetter(0))
                        self.serializer.dump_stream(sorted_items, f)
                    else:
                        self.serializer.dump_stream(self.pdata[i].items(), f)
                self.pdata[i].clear()
                DiskBytesSpilled += os.path.getsize(p)

        self.spills += 1
        gc.collect()  # release as much memory as possible
        MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20

    def _merged_items(self, index):
        size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
                   for j in range(self.spills))
        # If memory cannot hold the whole partition, use a sort-based
        # merge. Because of compression, the data on disk is much
        # smaller than the memory it needs once loaded.
        if size >= self.memory_limit << 17:  # * 1M / 8
            return self._merge_sorted_items(index)

        self.data = {}
        for j in range(self.spills):
            path = self._get_spill_dir(j)
            p = os.path.join(path, str(index))
            # do not check memory during merging
            with open(p, "rb") as f:
                self.mergeCombiners(self.serializer.load_stream(f), 0)
        return self.data.items()

    def _merge_sorted_items(self, index):
        """ Load a partition from disk, then sort and group by key """
        def load_partition(j):
            path = self._get_spill_dir(j)
            p = os.path.join(path, str(index))
            with open(p, 'rb', 65536) as f:
                for v in self.serializer.load_stream(f):
                    yield v

        disk_items = [load_partition(j) for j in range(self.spills)]

        if self._sorted:
            # all the partitions are already sorted
            sorted_items = heapq.merge(disk_items, key=operator.itemgetter(0))

        else:
            # Flatten the combined values, so they will not consume huge
            # amounts of memory during the merge sort.
            ser = self.flattened_serializer()
            sorter = ExternalSorter(self.memory_limit, ser)
            sorted_items = sorter.sorted(itertools.chain(*disk_items),
                                         key=operator.itemgetter(0))
        return ((k, vs) for k, vs in GroupByKey(sorted_items))


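# Illustrative sketch (not part of the original module): a group-by-key built on
# ExternalGroupBy with a hypothetical list-building Aggregator; it only shows the
# shape of the API, not how PySpark itself wires it up, e.g.
# _example_group_by_key([("a", 1), ("b", 2), ("a", 3)]) -> [("a", [1, 3]), ("b", [2])]
def _example_group_by_key(pairs, memory_limit=512):
    agg = Aggregator(createCombiner=lambda v: [v],
                     mergeValue=lambda lst, v: lst.append(v) or lst,
                     mergeCombiners=lambda a, b: a.extend(b) or a)
    grouper = ExternalGroupBy(agg, memory_limit)
    grouper.mergeValues(pairs)
    # each value is a list (or a lazily loaded group, if data was spilled to disk)
    return [(k, list(vs)) for k, vs in grouper.items()]

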
if __name__ == "__main__":
    import doctest
    (failure_count, test_count) = doctest.testmod()
    if failure_count:
        sys.exit(-1)