0018 import copy
0019 import sys
0020 import os
0021 import re
0022 import operator
0023 import shlex
0024 import warnings
0025 import heapq
0026 import bisect
0027 import random
0028 from subprocess import Popen, PIPE
0029 from tempfile import NamedTemporaryFile
0030 from threading import Thread
0031 from collections import defaultdict
0032 from itertools import chain
0033 from functools import reduce
0034 from math import sqrt, log, isinf, isnan, pow, ceil
0035
0036 if sys.version > '3':
0037 basestring = unicode = str
0038 else:
0039 from itertools import imap as map, ifilter as filter
0040
0041 from pyspark.java_gateway import local_connect_and_auth
0042 from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, NoOpSerializer, \
0043 CartesianDeserializer, CloudPickleSerializer, PairDeserializer, PickleSerializer, \
0044 UTF8Deserializer, pack_long, read_int, write_int
0045 from pyspark.join import python_join, python_left_outer_join, \
0046 python_right_outer_join, python_full_outer_join, python_cogroup
0047 from pyspark.statcounter import StatCounter
0048 from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
0049 from pyspark.storagelevel import StorageLevel
0050 from pyspark.resultiterable import ResultIterable
0051 from pyspark.shuffle import Aggregator, ExternalMerger, \
0052 get_used_memory, ExternalSorter, ExternalGroupBy
0053 from pyspark.traceback_utils import SCCallSiteSync
0054 from pyspark.util import fail_on_stopiteration
0055
0056
0057 __all__ = ["RDD"]
0058
0059
0060 class PythonEvalType(object):
0061 """
0062 Evaluation type of python rdd.
0063
0064 These values are internal to PySpark.
0065
0066 These values should match values in org.apache.spark.api.python.PythonEvalType.
0067 """
0068 NON_UDF = 0
0069
0070 SQL_BATCHED_UDF = 100
0071
0072 SQL_SCALAR_PANDAS_UDF = 200
0073 SQL_GROUPED_MAP_PANDAS_UDF = 201
0074 SQL_GROUPED_AGG_PANDAS_UDF = 202
0075 SQL_WINDOW_AGG_PANDAS_UDF = 203
0076 SQL_SCALAR_PANDAS_ITER_UDF = 204
0077 SQL_MAP_PANDAS_ITER_UDF = 205
0078 SQL_COGROUPED_MAP_PANDAS_UDF = 206
0079
0080
0081 def portable_hash(x):
0082 """
0083 This function returns a consistent hash code for builtin types, especially
0084 for None and tuples with None.
0085
0086 The algorithm is similar to the one used by CPython 2.7.
0087
0088 >>> portable_hash(None)
0089 0
0090 >>> portable_hash((None, 1)) & 0xffffffff
0091 219750521
0092 """
0093
0094 if sys.version_info >= (3, 2, 3) and 'PYTHONHASHSEED' not in os.environ:
0095 raise Exception("Randomness of hash of string should be disabled via PYTHONHASHSEED")
0096
0097 if x is None:
0098 return 0
0099 if isinstance(x, tuple):
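# Hash a tuple by XOR-ing in each element's portable hash, multiplying by the prime
# 1000003 and masking to the machine word size; the tuple length is mixed in at the end,
# and -1 is remapped to -2 because CPython reserves -1 as an error value.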
0100 h = 0x345678
0101 for i in x:
0102 h ^= portable_hash(i)
0103 h *= 1000003
0104 h &= sys.maxsize
0105 h ^= len(x)
0106 if h == -1:
0107 h = -2
0108 return int(h)
0109 return hash(x)
0110
0111
0112 class BoundedFloat(float):
0113 """
0114 A bounded value generated by an approximate job, with a confidence level
0115 and low and high bounds.
0116
0117 >>> BoundedFloat(100.0, 0.95, 95.0, 105.0)
0118 100.0
0119 """
0120 def __new__(cls, mean, confidence, low, high):
0121 obj = float.__new__(cls, mean)
0122 obj.confidence = confidence
0123 obj.low = low
0124 obj.high = high
0125 return obj
0126
0127
0128 def _parse_memory(s):
0129 """
0130 Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
0131 return the value in MiB
0132
0133 >>> _parse_memory("256m")
0134 256
0135 >>> _parse_memory("2g")
0136 2048
0137 """
0138 units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024}
0139 if s[-1].lower() not in units:
0140 raise ValueError("invalid format: " + s)
0141 return int(float(s[:-1]) * units[s[-1].lower()])
0142
0143
0144 def _create_local_socket(sock_info):
0145 """
0146 Create a local socket that can be used to load deserialized data from the JVM
0147
0148 :param sock_info: Tuple containing port number and authentication secret for a local socket.
0149 :return: sockfile file descriptor of the local socket
0150 """
0151 port = sock_info[0]
0152 auth_secret = sock_info[1]
0153 sockfile, sock = local_connect_and_auth(port, auth_secret)
0154
0155
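# Disable the socket read timeout: how long the JVM takes to materialize the data is
# unpredictable, and a timeout here would abort otherwise healthy jobs.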
0156 sock.settimeout(None)
0157 return sockfile
0158
0159
0160 def _load_from_socket(sock_info, serializer):
0161 """
0162 Connect to a local socket described by sock_info and use the given serializer to yield data
0163
0164 :param sock_info: Tuple containing port number and authentication secret for a local socket.
0165 :param serializer: The PySpark serializer to use
0166 :return: result of Serializer.load_stream, usually a generator that yields deserialized data
0167 """
0168 sockfile = _create_local_socket(sock_info)
0169
0170 return serializer.load_stream(sockfile)
0171
0172
0173 def _local_iterator_from_socket(sock_info, serializer):
0174
0175 class PyLocalIterable(object):
0176 """ Create a synchronous local iterable over a socket """
0177
0178 def __init__(self, _sock_info, _serializer):
0179 port, auth_secret, self.jsocket_auth_server = _sock_info
0180 self._sockfile = _create_local_socket((port, auth_secret))
0181 self._serializer = _serializer
0182 self._read_iter = iter([])
0183 self._read_status = 1
0184
0185 def __iter__(self):
0186 while self._read_status == 1:
0187
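# Ask the JVM serving thread for the next block of data.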
0188 write_int(1, self._sockfile)
0189 self._sockfile.flush()
0190
0191
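# A reply of 1 means another partition follows; -1 signals an error raised on the
# JVM side; any other value ends the iteration.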
0192 self._read_status = read_int(self._sockfile)
0193 if self._read_status == 1:
0194
0195
0196 self._read_iter = self._serializer.load_stream(self._sockfile)
0197 for item in self._read_iter:
0198 yield item
0199
0200
0201 elif self._read_status == -1:
0202 self.jsocket_auth_server.getResult()
0203
0204 def __del__(self):
0205
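# If the JVM side may still be serving data, finish reading the current partition and
# write 0 to ask it to stop and close the connection.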
0206 if self._read_status == 1:
0207 try:
0208
0209 for _ in self._read_iter:
0210 pass
0211
0212 write_int(0, self._sockfile)
0213 self._sockfile.flush()
0214 except Exception:
0215
0216 pass
0217
0218 return iter(PyLocalIterable(sock_info, serializer))
0219
0220
0221 def ignore_unicode_prefix(f):
0222 """
0223 Ignore the 'u' prefix of strings in doctests, to make them work
0224 in both Python 2 and 3.
0225 """
0226 if sys.version >= '3':
0227
0228
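# Python 3 string reprs carry no u'' prefix, so remove the prefix from the doctest
# expected output.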
0229 literal_re = re.compile(r"(\W|^)[uU](['])", re.UNICODE)
0230 f.__doc__ = literal_re.sub(r'\1\2', f.__doc__)
0231 return f
0232
0233
0234 class Partitioner(object):
0235 def __init__(self, numPartitions, partitionFunc):
0236 self.numPartitions = numPartitions
0237 self.partitionFunc = partitionFunc
0238
0239 def __eq__(self, other):
0240 return (isinstance(other, Partitioner) and self.numPartitions == other.numPartitions
0241 and self.partitionFunc == other.partitionFunc)
0242
0243 def __call__(self, k):
0244 return self.partitionFunc(k) % self.numPartitions
0245
0246
0247 class RDD(object):
0248
0249 """
0250 A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
0251 Represents an immutable, partitioned collection of elements that can be
0252 operated on in parallel.
0253 """
0254
0255 def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSerializer())):
0256 self._jrdd = jrdd
0257 self.is_cached = False
0258 self.is_checkpointed = False
0259 self.ctx = ctx
0260 self._jrdd_deserializer = jrdd_deserializer
0261 self._id = jrdd.id()
0262 self.partitioner = None
0263
0264 def _pickled(self):
0265 return self._reserialize(AutoBatchedSerializer(PickleSerializer()))
0266
0267 def id(self):
0268 """
0269 A unique ID for this RDD (within its SparkContext).
0270 """
0271 return self._id
0272
0273 def __repr__(self):
0274 return self._jrdd.toString()
0275
0276 def __getnewargs__(self):
0277
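# Pickling an RDD is always an error: it only happens when an RDD is referenced from
# inside a transformation or action closure, or broadcast.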
0278 raise Exception(
0279 "It appears that you are attempting to broadcast an RDD or reference an RDD from an "
0280 "action or transformation. RDD transformations and actions can only be invoked by the "
0281 "driver, not inside of other transformations; for example, "
0282 "rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values "
0283 "transformation and count action cannot be performed inside of the rdd1.map "
0284 "transformation. For more information, see SPARK-5063."
0285 )
0286
0287 @property
0288 def context(self):
0289 """
0290 The :class:`SparkContext` that this RDD was created on.
0291 """
0292 return self.ctx
0293
0294 def cache(self):
0295 """
0296 Persist this RDD with the default storage level (`MEMORY_ONLY`).
0297 """
0298 self.is_cached = True
0299 self.persist(StorageLevel.MEMORY_ONLY)
0300 return self
0301
0302 def persist(self, storageLevel=StorageLevel.MEMORY_ONLY):
0303 """
0304 Set this RDD's storage level to persist its values across operations
0305 after the first time it is computed. This can only be used to assign
0306 a new storage level if the RDD does not have a storage level set yet.
0307 If no storage level is specified, this defaults to `MEMORY_ONLY`.
0308
0309 >>> rdd = sc.parallelize(["b", "a", "c"])
0310 >>> rdd.persist().is_cached
0311 True
0312 """
0313 self.is_cached = True
0314 javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
0315 self._jrdd.persist(javaStorageLevel)
0316 return self
0317
0318 def unpersist(self, blocking=False):
0319 """
0320 Mark the RDD as non-persistent, and remove all blocks for it from
0321 memory and disk.
0322
0323 .. versionchanged:: 3.0.0
0324 Added optional argument `blocking` to specify whether to block until all
0325 blocks are deleted.
0326 """
0327 self.is_cached = False
0328 self._jrdd.unpersist(blocking)
0329 return self
0330
0331 def checkpoint(self):
0332 """
0333 Mark this RDD for checkpointing. It will be saved to a file inside the
0334 checkpoint directory set with :meth:`SparkContext.setCheckpointDir` and
0335 all references to its parent RDDs will be removed. This function must
0336 be called before any job has been executed on this RDD. It is strongly
0337 recommended that this RDD is persisted in memory, otherwise saving it
0338 on a file will require recomputation.
0339 """
0340 self.is_checkpointed = True
0341 self._jrdd.rdd().checkpoint()
0342
0343 def isCheckpointed(self):
0344 """
0345 Return whether this RDD is checkpointed and materialized, either reliably or locally.
0346 """
0347 return self._jrdd.rdd().isCheckpointed()
0348
0349 def localCheckpoint(self):
0350 """
0351 Mark this RDD for local checkpointing using Spark's existing caching layer.
0352
0353 This method is for users who wish to truncate RDD lineages while skipping the expensive
0354 step of replicating the materialized data in a reliable distributed file system. This is
0355 useful for RDDs with long lineages that need to be truncated periodically (e.g. GraphX).
0356
0357 Local checkpointing sacrifices fault-tolerance for performance. In particular, checkpointed
0358 data is written to ephemeral local storage in the executors instead of to a reliable,
0359 fault-tolerant storage. The effect is that if an executor fails during the computation,
0360 the checkpointed data may no longer be accessible, causing an irrecoverable job failure.
0361
0362 This is NOT safe to use with dynamic allocation, which removes executors along
0363 with their cached blocks. If you must use both features, you are advised to set
0364 `spark.dynamicAllocation.cachedExecutorIdleTimeout` to a high value.
0365
0366 The checkpoint directory set through :meth:`SparkContext.setCheckpointDir` is not used.
0367 """
0368 self._jrdd.rdd().localCheckpoint()
0369
0370 def isLocallyCheckpointed(self):
0371 """
0372 Return whether this RDD is marked for local checkpointing.
0373
0374 Exposed for testing.
0375 """
0376 return self._jrdd.rdd().isLocallyCheckpointed()
0377
0378 def getCheckpointFile(self):
0379 """
0380 Gets the name of the file to which this RDD was checkpointed.
0381
0382 Not defined if the RDD is checkpointed locally.
0383 """
0384 checkpointFile = self._jrdd.rdd().getCheckpointFile()
0385 if checkpointFile.isDefined():
0386 return checkpointFile.get()
0387
0388 def map(self, f, preservesPartitioning=False):
0389 """
0390 Return a new RDD by applying a function to each element of this RDD.
0391
0392 >>> rdd = sc.parallelize(["b", "a", "c"])
0393 >>> sorted(rdd.map(lambda x: (x, 1)).collect())
0394 [('a', 1), ('b', 1), ('c', 1)]
0395 """
0396 def func(_, iterator):
0397 return map(fail_on_stopiteration(f), iterator)
0398 return self.mapPartitionsWithIndex(func, preservesPartitioning)
0399
0400 def flatMap(self, f, preservesPartitioning=False):
0401 """
0402 Return a new RDD by first applying a function to all elements of this
0403 RDD, and then flattening the results.
0404
0405 >>> rdd = sc.parallelize([2, 3, 4])
0406 >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect())
0407 [1, 1, 1, 2, 2, 3]
0408 >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
0409 [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
0410 """
0411 def func(s, iterator):
0412 return chain.from_iterable(map(fail_on_stopiteration(f), iterator))
0413 return self.mapPartitionsWithIndex(func, preservesPartitioning)
0414
0415 def mapPartitions(self, f, preservesPartitioning=False):
0416 """
0417 Return a new RDD by applying a function to each partition of this RDD.
0418
0419 >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
0420 >>> def f(iterator): yield sum(iterator)
0421 >>> rdd.mapPartitions(f).collect()
0422 [3, 7]
0423 """
0424 def func(s, iterator):
0425 return f(iterator)
0426 return self.mapPartitionsWithIndex(func, preservesPartitioning)
0427
0428 def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
0429 """
0430 Return a new RDD by applying a function to each partition of this RDD,
0431 while tracking the index of the original partition.
0432
0433 >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
0434 >>> def f(splitIndex, iterator): yield splitIndex
0435 >>> rdd.mapPartitionsWithIndex(f).sum()
0436 6
0437 """
0438 return PipelinedRDD(self, f, preservesPartitioning)
0439
0440 def mapPartitionsWithSplit(self, f, preservesPartitioning=False):
0441 """
0442 Deprecated: use mapPartitionsWithIndex instead.
0443
0444 Return a new RDD by applying a function to each partition of this RDD,
0445 while tracking the index of the original partition.
0446
0447 >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
0448 >>> def f(splitIndex, iterator): yield splitIndex
0449 >>> rdd.mapPartitionsWithSplit(f).sum()
0450 6
0451 """
0452 warnings.warn("mapPartitionsWithSplit is deprecated; "
0453 "use mapPartitionsWithIndex instead", DeprecationWarning, stacklevel=2)
0454 return self.mapPartitionsWithIndex(f, preservesPartitioning)
0455
0456 def getNumPartitions(self):
0457 """
0458 Returns the number of partitions in this RDD.
0459
0460 >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
0461 >>> rdd.getNumPartitions()
0462 2
0463 """
0464 return self._jrdd.partitions().size()
0465
0466 def filter(self, f):
0467 """
0468 Return a new RDD containing only the elements that satisfy a predicate.
0469
0470 >>> rdd = sc.parallelize([1, 2, 3, 4, 5])
0471 >>> rdd.filter(lambda x: x % 2 == 0).collect()
0472 [2, 4]
0473 """
0474 def func(iterator):
0475 return filter(fail_on_stopiteration(f), iterator)
0476 return self.mapPartitions(func, True)
0477
0478 def distinct(self, numPartitions=None):
0479 """
0480 Return a new RDD containing the distinct elements in this RDD.
0481
0482 >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
0483 [1, 2, 3]
0484 """
0485 return self.map(lambda x: (x, None)) \
0486 .reduceByKey(lambda x, _: x, numPartitions) \
0487 .map(lambda x: x[0])
0488
0489 def sample(self, withReplacement, fraction, seed=None):
0490 """
0491 Return a sampled subset of this RDD.
0492
0493 :param withReplacement: can elements be sampled multiple times (replaced when sampled out)
0494 :param fraction: expected size of the sample as a fraction of this RDD's size
0495 without replacement: probability that each element is chosen; fraction must be [0, 1]
0496 with replacement: expected number of times each element is chosen; fraction must be >= 0
0497 :param seed: seed for the random number generator
0498
0499 .. note:: This is not guaranteed to provide exactly the fraction specified of the total
0500 count of this :class:`RDD`.
0501
0502 >>> rdd = sc.parallelize(range(100), 4)
0503 >>> 6 <= rdd.sample(False, 0.1, 81).count() <= 14
0504 True
0505 """
0506 assert fraction >= 0.0, "Negative fraction value: %s" % fraction
0507 return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
0508
0509 def randomSplit(self, weights, seed=None):
0510 """
0511 Randomly splits this RDD with the provided weights.
0512
0513 :param weights: weights for splits, will be normalized if they don't sum to 1
0514 :param seed: random seed
0515 :return: split RDDs in a list
0516
0517 >>> rdd = sc.parallelize(range(500), 1)
0518 >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17)
0519 >>> len(rdd1.collect() + rdd2.collect())
0520 500
0521 >>> 150 < rdd1.count() < 250
0522 True
0523 >>> 250 < rdd2.count() < 350
0524 True
0525 """
0526 s = float(sum(weights))
0527 cweights = [0.0]
0528 for w in weights:
0529 cweights.append(cweights[-1] + w / s)
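# cweights now holds cumulative, normalized boundaries in [0, 1]; each split keeps the
# records whose random draw (from the shared seed) falls inside its [lower, upper) range.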
0530 if seed is None:
0531 seed = random.randint(0, 2 ** 32 - 1)
0532 return [self.mapPartitionsWithIndex(RDDRangeSampler(lb, ub, seed).func, True)
0533 for lb, ub in zip(cweights, cweights[1:])]
0534
0535
0536 def takeSample(self, withReplacement, num, seed=None):
0537 """
0538 Return a fixed-size sampled subset of this RDD.
0539
0540 .. note:: This method should only be used if the resulting array is expected
0541 to be small, as all the data is loaded into the driver's memory.
0542
0543 >>> rdd = sc.parallelize(range(0, 10))
0544 >>> len(rdd.takeSample(True, 20, 1))
0545 20
0546 >>> len(rdd.takeSample(False, 5, 2))
0547 5
0548 >>> len(rdd.takeSample(False, 15, 3))
0549 10
0550 """
0551 numStDev = 10.0
0552
0553 if num < 0:
0554 raise ValueError("Sample size cannot be negative.")
0555 elif num == 0:
0556 return []
0557
0558 initialCount = self.count()
0559 if initialCount == 0:
0560 return []
0561
0562 rand = random.Random(seed)
0563
0564 if (not withReplacement) and num >= initialCount:
0565
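# Asked for at least as many elements as the RDD holds (without replacement):
# just shuffle everything and return it.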
0566 samples = self.collect()
0567 rand.shuffle(samples)
0568 return samples
0569
0570 maxSampleSize = sys.maxsize - int(numStDev * sqrt(sys.maxsize))
0571 if num > maxSampleSize:
0572 raise ValueError(
0573 "Sample size cannot be greater than %d." % maxSampleSize)
0574
0575 fraction = RDD._computeFractionForSampleSize(
0576 num, initialCount, withReplacement)
0577 samples = self.sample(withReplacement, fraction, seed).collect()
0578
0579
0580
0581
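# Random sampling may come up short; re-sample with a fresh seed until at least
# `num` elements are collected.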
0582 while len(samples) < num:
0583
0584 seed = rand.randint(0, sys.maxsize)
0585 samples = self.sample(withReplacement, fraction, seed).collect()
0586
0587 rand.shuffle(samples)
0588
0589 return samples[0:num]
0590
0591 @staticmethod
0592 def _computeFractionForSampleSize(sampleSizeLowerBound, total, withReplacement):
0593 """
0594 Returns a sampling rate that guarantees a sample of
0595 size >= sampleSizeLowerBound 99.99% of the time.
0596
0597 How the sampling rate is determined:
0598 Let p = num / total, where num is the sample size and total is the
0599 total number of data points in the RDD. We're trying to compute
0600 q > p such that
0601 - when sampling with replacement, we're drawing each data point
0602 with prob_i ~ Pois(q), where we want to guarantee
0603 Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to
0604 total), i.e. the failure rate of not having a sufficiently large
0605 sample < 0.0001. Setting q = p + 5 * sqrt(p/total) is sufficient
0606 to guarantee 0.9999 success rate for num > 12, but we need a
0607 slightly larger q (9 empirically determined).
0608 - when sampling without replacement, we're drawing each data point
0609 with prob_i ~ Binomial(total, fraction) and our choice of q
0610 guarantees 1-delta, or 0.9999 success rate, where success rate is
0611 defined the same as in sampling with replacement.
0612 """
0613 fraction = float(sampleSizeLowerBound) / total
0614 if withReplacement:
0615 numStDev = 5
0616 if (sampleSizeLowerBound < 12):
0617 numStDev = 9
0618 return fraction + numStDev * sqrt(fraction / total)
0619 else:
0620 delta = 0.00005
0621 gamma = - log(delta) / total
0622 return min(1, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction))
0623
0624 def union(self, other):
0625 """
0626 Return the union of this RDD and another one.
0627
0628 >>> rdd = sc.parallelize([1, 1, 2, 3])
0629 >>> rdd.union(rdd).collect()
0630 [1, 1, 2, 3, 1, 1, 2, 3]
0631 """
0632 if self._jrdd_deserializer == other._jrdd_deserializer:
0633 rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
0634 self._jrdd_deserializer)
0635 else:
0636
0637
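# The two RDDs are serialized differently; reserialize both to the context's default
# serializer before taking the union on the JVM side.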
0638 self_copy = self._reserialize()
0639 other_copy = other._reserialize()
0640 rdd = RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
0641 self.ctx.serializer)
0642 if (self.partitioner == other.partitioner and
0643 self.getNumPartitions() == rdd.getNumPartitions()):
0644 rdd.partitioner = self.partitioner
0645 return rdd
0646
0647 def intersection(self, other):
0648 """
0649 Return the intersection of this RDD and another one. The output will
0650 not contain any duplicate elements, even if the input RDDs did.
0651
0652 .. note:: This method performs a shuffle internally.
0653
0654 >>> rdd1 = sc.parallelize([1, 10, 2, 3, 4, 5])
0655 >>> rdd2 = sc.parallelize([1, 6, 2, 3, 7, 8])
0656 >>> rdd1.intersection(rdd2).collect()
0657 [1, 2, 3]
0658 """
0659 return self.map(lambda v: (v, None)) \
0660 .cogroup(other.map(lambda v: (v, None))) \
0661 .filter(lambda k_vs: all(k_vs[1])) \
0662 .keys()
0663
0664 def _reserialize(self, serializer=None):
0665 serializer = serializer or self.ctx.serializer
0666 if self._jrdd_deserializer != serializer:
0667 self = self.map(lambda x: x, preservesPartitioning=True)
0668 self._jrdd_deserializer = serializer
0669 return self
0670
0671 def __add__(self, other):
0672 """
0673 Return the union of this RDD and another one.
0674
0675 >>> rdd = sc.parallelize([1, 1, 2, 3])
0676 >>> (rdd + rdd).collect()
0677 [1, 1, 2, 3, 1, 1, 2, 3]
0678 """
0679 if not isinstance(other, RDD):
0680 raise TypeError
0681 return self.union(other)
0682
0683 def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=portable_hash,
0684 ascending=True, keyfunc=lambda x: x):
0685 """
0686 Repartition the RDD according to the given partitioner and, within each resulting partition,
0687 sort records by their keys.
0688
0689 >>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)])
0690 >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, True)
0691 >>> rdd2.glom().collect()
0692 [[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]]
0693 """
0694 if numPartitions is None:
0695 numPartitions = self._defaultReducePartitions()
0696
0697 memory = self._memory_limit()
0698 serializer = self._jrdd_deserializer
0699
0700 def sortPartition(iterator):
0701 sort = ExternalSorter(memory * 0.9, serializer).sorted
0702 return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending)))
0703
0704 return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True)
0705
0706 def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
0707 """
0708 Sorts this RDD, which is assumed to consist of (key, value) pairs.
0709
0710 >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
0711 >>> sc.parallelize(tmp).sortByKey().first()
0712 ('1', 3)
0713 >>> sc.parallelize(tmp).sortByKey(True, 1).collect()
0714 [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
0715 >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
0716 [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
0717 >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
0718 >>> tmp2.extend([('whose', 6), ('fleece', 7), ('was', 8), ('white', 9)])
0719 >>> sc.parallelize(tmp2).sortByKey(True, 3, keyfunc=lambda k: k.lower()).collect()
0720 [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5),...('white', 9), ('whose', 6)]
0721 """
0722 if numPartitions is None:
0723 numPartitions = self._defaultReducePartitions()
0724
0725 memory = self._memory_limit()
0726 serializer = self._jrdd_deserializer
0727
0728 def sortPartition(iterator):
0729 sort = ExternalSorter(memory * 0.9, serializer).sorted
0730 return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending)))
0731
0732 if numPartitions == 1:
0733 if self.getNumPartitions() > 1:
0734 self = self.coalesce(1)
0735 return self.mapPartitions(sortPartition, True)
0736
0737
0738
0739
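# With more than one output partition, sample the RDD to pick range boundaries that give
# each partition roughly the same number of records.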
0740 rddSize = self.count()
0741 if not rddSize:
0742 return self
0743 maxSampleSize = numPartitions * 20.0
0744 fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
0745 samples = self.sample(False, fraction, 1).map(lambda kv: kv[0]).collect()
0746 samples = sorted(samples, key=keyfunc)
0747
0748
0749
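# numPartitions - 1 boundary keys, taken at evenly spaced positions in the sorted sample,
# define the partition ranges.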
0750 bounds = [samples[int(len(samples) * (i + 1) / numPartitions)]
0751 for i in range(0, numPartitions - 1)]
0752
0753 def rangePartitioner(k):
0754 p = bisect.bisect_left(bounds, keyfunc(k))
0755 if ascending:
0756 return p
0757 else:
0758 return numPartitions - 1 - p
0759
0760 return self.partitionBy(numPartitions, rangePartitioner).mapPartitions(sortPartition, True)
0761
0762 def sortBy(self, keyfunc, ascending=True, numPartitions=None):
0763 """
0764 Sorts this RDD by the given keyfunc.
0765
0766 >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
0767 >>> sc.parallelize(tmp).sortBy(lambda x: x[0]).collect()
0768 [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
0769 >>> sc.parallelize(tmp).sortBy(lambda x: x[1]).collect()
0770 [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
0771 """
0772 return self.keyBy(keyfunc).sortByKey(ascending, numPartitions).values()
0773
0774 def glom(self):
0775 """
0776 Return an RDD created by coalescing all elements within each partition
0777 into a list.
0778
0779 >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
0780 >>> sorted(rdd.glom().collect())
0781 [[1, 2], [3, 4]]
0782 """
0783 def func(iterator):
0784 yield list(iterator)
0785 return self.mapPartitions(func)
0786
0787 def cartesian(self, other):
0788 """
0789 Return the Cartesian product of this RDD and another one, that is, the
0790 RDD of all pairs of elements ``(a, b)`` where ``a`` is in `self` and
0791 ``b`` is in `other`.
0792
0793 >>> rdd = sc.parallelize([1, 2])
0794 >>> sorted(rdd.cartesian(rdd).collect())
0795 [(1, 1), (1, 2), (2, 1), (2, 2)]
0796 """
0797
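# Elements cross the JVM boundary in batches, so the result needs a CartesianDeserializer
# that pairs up items from the two input streams.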
0798 deserializer = CartesianDeserializer(self._jrdd_deserializer,
0799 other._jrdd_deserializer)
0800 return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
0801
0802 def groupBy(self, f, numPartitions=None, partitionFunc=portable_hash):
0803 """
0804 Return an RDD of grouped items.
0805
0806 >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8])
0807 >>> result = rdd.groupBy(lambda x: x % 2).collect()
0808 >>> sorted([(x, sorted(y)) for (x, y) in result])
0809 [(0, [2, 8]), (1, [1, 1, 3, 5])]
0810 """
0811 return self.map(lambda x: (f(x), x)).groupByKey(numPartitions, partitionFunc)
0812
0813 @ignore_unicode_prefix
0814 def pipe(self, command, env=None, checkCode=False):
0815 """
0816 Return an RDD created by piping elements to a forked external process.
0817
0818 >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect()
0819 [u'1', u'2', u'', u'3']
0820
0821 :param checkCode: whether or not to check the return value of the shell command.
0822 """
0823 if env is None:
0824 env = dict()
0825
0826 def func(iterator):
0827 pipe = Popen(
0828 shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
0829
0830 def pipe_objs(out):
0831 for obj in iterator:
0832 s = unicode(obj).rstrip('\n') + '\n'
0833 out.write(s.encode('utf-8'))
0834 out.close()
0835 Thread(target=pipe_objs, args=[pipe.stdin]).start()
0836
0837 def check_return_code():
0838 pipe.wait()
0839 if checkCode and pipe.returncode:
0840 raise Exception("Pipe function `%s' exited "
0841 "with error code %d" % (command, pipe.returncode))
0842 else:
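# Yield nothing: writing check_return_code as a generator lets it be chained after the
# process output below, so the return code is only checked once stdout is drained.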
0843 for i in range(0):
0844 yield i
0845 return (x.rstrip(b'\n').decode('utf-8') for x in
0846 chain(iter(pipe.stdout.readline, b''), check_return_code()))
0847 return self.mapPartitions(func)
0848
0849 def foreach(self, f):
0850 """
0851 Applies a function to all elements of this RDD.
0852
0853 >>> def f(x): print(x)
0854 >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
0855 """
0856 f = fail_on_stopiteration(f)
0857
0858 def processPartition(iterator):
0859 for x in iterator:
0860 f(x)
0861 return iter([])
0862 self.mapPartitions(processPartition).count()
0863
0864 def foreachPartition(self, f):
0865 """
0866 Applies a function to each partition of this RDD.
0867
0868 >>> def f(iterator):
0869 ... for x in iterator:
0870 ... print(x)
0871 >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f)
0872 """
0873 def func(it):
0874 r = f(it)
0875 try:
0876 return iter(r)
0877 except TypeError:
0878 return iter([])
0879 self.mapPartitions(func).count()
0880
0881 def collect(self):
0882 """
0883 Return a list that contains all of the elements in this RDD.
0884
0885 .. note:: This method should only be used if the resulting array is expected
0886 to be small, as all the data is loaded into the driver's memory.
0887 """
0888 with SCCallSiteSync(self.context) as css:
0889 sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
0890 return list(_load_from_socket(sock_info, self._jrdd_deserializer))
0891
0892 def collectWithJobGroup(self, groupId, description, interruptOnCancel=False):
0893 """
0894 .. note:: Experimental
0895
0896 When collecting the RDD, use this method to specify a job group.
0897
0898 .. versionadded:: 3.0.0
0899 """
0900 with SCCallSiteSync(self.context) as css:
0901 sock_info = self.ctx._jvm.PythonRDD.collectAndServeWithJobGroup(
0902 self._jrdd.rdd(), groupId, description, interruptOnCancel)
0903 return list(_load_from_socket(sock_info, self._jrdd_deserializer))
0904
0905 def reduce(self, f):
0906 """
0907 Reduces the elements of this RDD using the specified commutative and
0908 associative binary operator. Currently reduces partitions locally.
0909
0910 >>> from operator import add
0911 >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
0912 15
0913 >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
0914 10
0915 >>> sc.parallelize([]).reduce(add)
0916 Traceback (most recent call last):
0917 ...
0918 ValueError: Can not reduce() empty RDD
0919 """
0920 f = fail_on_stopiteration(f)
0921
0922 def func(iterator):
0923 iterator = iter(iterator)
0924 try:
0925 initial = next(iterator)
0926 except StopIteration:
0927 return
0928 yield reduce(f, iterator, initial)
0929
0930 vals = self.mapPartitions(func).collect()
0931 if vals:
0932 return reduce(f, vals)
0933 raise ValueError("Can not reduce() empty RDD")
0934
0935 def treeReduce(self, f, depth=2):
0936 """
0937 Reduces the elements of this RDD in a multi-level tree pattern.
0938
0939 :param depth: suggested depth of the tree (default: 2)
0940
0941 >>> add = lambda x, y: x + y
0942 >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
0943 >>> rdd.treeReduce(add)
0944 -5
0945 >>> rdd.treeReduce(add, 1)
0946 -5
0947 >>> rdd.treeReduce(add, 2)
0948 -5
0949 >>> rdd.treeReduce(add, 5)
0950 -5
0951 >>> rdd.treeReduce(add, 10)
0952 -5
0953 """
0954 if depth < 1:
0955 raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
0956
0957 zeroValue = None, True
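# Tag every element with a flag: (value, False) for real data and (None, True) for the
# zero value, so `op` can skip the placeholder produced by empty partitions.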
0958
0959 def op(x, y):
0960 if x[1]:
0961 return y
0962 elif y[1]:
0963 return x
0964 else:
0965 return f(x[0], y[0]), False
0966
0967 reduced = self.map(lambda x: (x, False)).treeAggregate(zeroValue, op, op, depth)
0968 if reduced[1]:
0969 raise ValueError("Cannot reduce empty RDD.")
0970 return reduced[0]
0971
0972 def fold(self, zeroValue, op):
0973 """
0974 Aggregate the elements of each partition, and then the results for all
0975 the partitions, using a given associative function and a neutral "zero value."
0976
0977 The function ``op(t1, t2)`` is allowed to modify ``t1`` and return it
0978 as its result value to avoid object allocation; however, it should not
0979 modify ``t2``.
0980
0981 This behaves somewhat differently from fold operations implemented
0982 for non-distributed collections in functional languages like Scala.
0983 This fold operation may be applied to partitions individually, and then
0984 fold those results into the final result, rather than apply the fold
0985 to each element sequentially in some defined ordering. For functions
0986 that are not commutative, the result may differ from that of a fold
0987 applied to a non-distributed collection.
0988
0989 >>> from operator import add
0990 >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
0991 15
0992 """
0993 op = fail_on_stopiteration(op)
0994
0995 def func(iterator):
0996 acc = zeroValue
0997 for obj in iterator:
0998 acc = op(acc, obj)
0999 yield acc
1000
1001
1002
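# Collect the partial fold result of every partition to the driver, then fold those
# partials once more, starting from the zero value.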
1003 vals = self.mapPartitions(func).collect()
1004 return reduce(op, vals, zeroValue)
1005
1006 def aggregate(self, zeroValue, seqOp, combOp):
1007 """
1008 Aggregate the elements of each partition, and then the results for all
1009 the partitions, using a given combine functions and a neutral "zero
1010 value."
1011
1012 The functions ``op(t1, t2)`` are allowed to modify ``t1`` and return it
1013 as their result value to avoid object allocation; however, they should not
1014 modify ``t2``.
1015
1016 The first function (seqOp) can return a different result type, U, than
1017 the type of this RDD. Thus, we need one operation for merging a T into
1018 a U and one operation for merging two U's.
1019
1020 >>> seqOp = (lambda x, y: (x[0] + y, x[1] + 1))
1021 >>> combOp = (lambda x, y: (x[0] + y[0], x[1] + y[1]))
1022 >>> sc.parallelize([1, 2, 3, 4]).aggregate((0, 0), seqOp, combOp)
1023 (10, 4)
1024 >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp)
1025 (0, 0)
1026 """
1027 seqOp = fail_on_stopiteration(seqOp)
1028 combOp = fail_on_stopiteration(combOp)
1029
1030 def func(iterator):
1031 acc = zeroValue
1032 for obj in iterator:
1033 acc = seqOp(acc, obj)
1034 yield acc
1035
1036
1037
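# Collect the per-partition aggregates to the driver and merge them with combOp,
# starting from the zero value.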
1038 vals = self.mapPartitions(func).collect()
1039 return reduce(combOp, vals, zeroValue)
1040
1041 def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
1042 """
1043 Aggregates the elements of this RDD in a multi-level tree
1044 pattern.
1045
1046 :param depth: suggested depth of the tree (default: 2)
1047
1048 >>> add = lambda x, y: x + y
1049 >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
1050 >>> rdd.treeAggregate(0, add, add)
1051 -5
1052 >>> rdd.treeAggregate(0, add, add, 1)
1053 -5
1054 >>> rdd.treeAggregate(0, add, add, 2)
1055 -5
1056 >>> rdd.treeAggregate(0, add, add, 5)
1057 -5
1058 >>> rdd.treeAggregate(0, add, add, 10)
1059 -5
1060 """
1061 if depth < 1:
1062 raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
1063
1064 if self.getNumPartitions() == 0:
1065 return zeroValue
1066
1067 def aggregatePartition(iterator):
1068 acc = zeroValue
1069 for obj in iterator:
1070 acc = seqOp(acc, obj)
1071 yield acc
1072
1073 partiallyAggregated = self.mapPartitions(aggregatePartition)
1074 numPartitions = partiallyAggregated.getNumPartitions()
1075 scale = max(int(ceil(pow(numPartitions, 1.0 / depth))), 2)
1076
1077
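# While another level of the tree would still shrink the work meaningfully, move the
# partial aggregates onto fewer partitions and merge them with reduceByKey.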
1078 while numPartitions > scale + numPartitions / scale:
1079 numPartitions /= scale
1080 curNumPartitions = int(numPartitions)
1081
1082 def mapPartition(i, iterator):
1083 for obj in iterator:
1084 yield (i % curNumPartitions, obj)
1085
1086 partiallyAggregated = partiallyAggregated \
1087 .mapPartitionsWithIndex(mapPartition) \
1088 .reduceByKey(combOp, curNumPartitions) \
1089 .values()
1090
1091 return partiallyAggregated.reduce(combOp)
1092
1093 def max(self, key=None):
1094 """
1095 Find the maximum item in this RDD.
1096
1097 :param key: A function used to generate the key for comparison
1098
1099 >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0])
1100 >>> rdd.max()
1101 43.0
1102 >>> rdd.max(key=str)
1103 5.0
1104 """
1105 if key is None:
1106 return self.reduce(max)
1107 return self.reduce(lambda a, b: max(a, b, key=key))
1108
1109 def min(self, key=None):
1110 """
1111 Find the minimum item in this RDD.
1112
1113 :param key: A function used to generate the key for comparison
1114
1115 >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0])
1116 >>> rdd.min()
1117 2.0
1118 >>> rdd.min(key=str)
1119 10.0
1120 """
1121 if key is None:
1122 return self.reduce(min)
1123 return self.reduce(lambda a, b: min(a, b, key=key))
1124
1125 def sum(self):
1126 """
1127 Add up the elements in this RDD.
1128
1129 >>> sc.parallelize([1.0, 2.0, 3.0]).sum()
1130 6.0
1131 """
1132 return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add)
1133
1134 def count(self):
1135 """
1136 Return the number of elements in this RDD.
1137
1138 >>> sc.parallelize([2, 3, 4]).count()
1139 3
1140 """
1141 return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
1142
1143 def stats(self):
1144 """
1145 Return a :class:`StatCounter` object that captures the mean, variance
1146 and count of the RDD's elements in one operation.
1147 """
1148 def redFunc(left_counter, right_counter):
1149 return left_counter.mergeStats(right_counter)
1150
1151 return self.mapPartitions(lambda i: [StatCounter(i)]).reduce(redFunc)
1152
1153 def histogram(self, buckets):
1154 """
1155 Compute a histogram using the provided buckets. The buckets
1156 are all open to the right except for the last which is closed.
1157 e.g. [1,10,20,50] means the buckets are [1,10) [10,20) [20,50],
1158 i.e. 1<=x<10, 10<=x<20, 20<=x<=50. For an input containing 1
1159 and 50 we would have a histogram of 1,0,1.
1160
1161 If your buckets are evenly spaced (e.g. [0, 10, 20, 30]),
1162 this can be switched from an O(log n) insertion to O(1) per
1163 element (where n is the number of buckets).
1164
1165 Buckets must be sorted, not contain any duplicates, and have
1166 at least two elements.
1167
1168 If `buckets` is a number, it will generate buckets which are
1169 evenly spaced between the minimum and maximum of the RDD. For
1170 example, if the min value is 0 and the max is 100, given `buckets`
1171 as 2, the resulting buckets will be [0,50) [50,100]. `buckets` must
1172 be at least 1. An exception is raised if the RDD contains infinity.
1173 If the elements in the RDD do not vary (max == min), a single bucket
1174 will be used.
1175
1176 The return value is a tuple of buckets and histogram.
1177
1178 >>> rdd = sc.parallelize(range(51))
1179 >>> rdd.histogram(2)
1180 ([0, 25, 50], [25, 26])
1181 >>> rdd.histogram([0, 5, 25, 50])
1182 ([0, 5, 25, 50], [5, 20, 26])
1183 >>> rdd.histogram([0, 15, 30, 45, 60]) # evenly spaced buckets
1184 ([0, 15, 30, 45, 60], [15, 15, 15, 6])
1185 >>> rdd = sc.parallelize(["ab", "ac", "b", "bd", "ef"])
1186 >>> rdd.histogram(("a", "b", "c"))
1187 (('a', 'b', 'c'), [2, 2])
1188 """
1189
1190 if isinstance(buckets, int):
1191 if buckets < 1:
1192 raise ValueError("number of buckets must be >= 1")
1193
1194
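# Drop elements that cannot take part in min/max comparisons (None and NaN).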
1195 def comparable(x):
1196 if x is None:
1197 return False
1198 if type(x) is float and isnan(x):
1199 return False
1200 return True
1201
1202 filtered = self.filter(comparable)
1203
1204
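# Find the global min and max in one reduce pass.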
1205 def minmax(a, b):
1206 return min(a[0], b[0]), max(a[1], b[1])
1207 try:
1208 minv, maxv = filtered.map(lambda x: (x, x)).reduce(minmax)
1209 except TypeError as e:
1210 if " empty " in str(e):
1211 raise ValueError("can not generate buckets from empty RDD")
1212 raise
1213
1214 if minv == maxv or buckets == 1:
1215 return [minv, maxv], [filtered.count()]
1216
1217 try:
1218 inc = (maxv - minv) / buckets
1219 except TypeError:
1220 raise TypeError("Can not generate buckets with non-number in RDD")
1221
1222 if isinf(inc):
1223 raise ValueError("Can not generate buckets with infinite value")
1224
1225
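# Prefer an integer bucket width when it divides the range exactly; otherwise fall back
# to a floating-point width.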
1226 inc = int(inc)
1227 if inc * buckets != maxv - minv:
1228 inc = (maxv - minv) * 1.0 / buckets
1229
1230 buckets = [i * inc + minv for i in range(buckets)]
1231 buckets.append(maxv)
1232 even = True
1233
1234 elif isinstance(buckets, (list, tuple)):
1235 if len(buckets) < 2:
1236 raise ValueError("buckets should have more than one value")
1237
1238 if any(i is None or isinstance(i, float) and isnan(i) for i in buckets):
1239 raise ValueError("can not have None or NaN in buckets")
1240
1241 if sorted(buckets) != list(buckets):
1242 raise ValueError("buckets should be sorted")
1243
1244 if len(set(buckets)) != len(buckets):
1245 raise ValueError("buckets should not contain duplicated values")
1246
1247 minv = buckets[0]
1248 maxv = buckets[-1]
1249 even = False
1250 inc = None
1251 try:
1252 steps = [buckets[i + 1] - buckets[i] for i in range(len(buckets) - 1)]
1253 except TypeError:
1254 pass
1255 else:
1256 if max(steps) - min(steps) < 1e-10:
1257 even = True
1258 inc = (maxv - minv) / (len(buckets) - 1)
1259
1260 else:
1261 raise TypeError("buckets should be a list or tuple or number(int or long)")
1262
1263 def histogram(iterator):
1264 counters = [0] * len(buckets)
1265 for i in iterator:
1266 if i is None or (type(i) is float and isnan(i)) or i > maxv or i < minv:
1267 continue
1268 t = (int((i - minv) / inc) if even
1269 else bisect.bisect_right(buckets, i) - 1)
1270 counters[t] += 1
1271
1272 last = counters.pop()
1273 counters[-1] += last
1274 return [counters]
1275
1276 def mergeCounters(a, b):
1277 return [i + j for i, j in zip(a, b)]
1278
1279 return buckets, self.mapPartitions(histogram).reduce(mergeCounters)
1280
1281 def mean(self):
1282 """
1283 Compute the mean of this RDD's elements.
1284
1285 >>> sc.parallelize([1, 2, 3]).mean()
1286 2.0
1287 """
1288 return self.stats().mean()
1289
1290 def variance(self):
1291 """
1292 Compute the variance of this RDD's elements.
1293
1294 >>> sc.parallelize([1, 2, 3]).variance()
1295 0.666...
1296 """
1297 return self.stats().variance()
1298
1299 def stdev(self):
1300 """
1301 Compute the standard deviation of this RDD's elements.
1302
1303 >>> sc.parallelize([1, 2, 3]).stdev()
1304 0.816...
1305 """
1306 return self.stats().stdev()
1307
1308 def sampleStdev(self):
1309 """
1310 Compute the sample standard deviation of this RDD's elements (which
1311 corrects for bias in estimating the standard deviation by dividing by
1312 N-1 instead of N).
1313
1314 >>> sc.parallelize([1, 2, 3]).sampleStdev()
1315 1.0
1316 """
1317 return self.stats().sampleStdev()
1318
1319 def sampleVariance(self):
1320 """
1321 Compute the sample variance of this RDD's elements (which corrects
1322 for bias in estimating the variance by dividing by N-1 instead of N).
1323
1324 >>> sc.parallelize([1, 2, 3]).sampleVariance()
1325 1.0
1326 """
1327 return self.stats().sampleVariance()
1328
1329 def countByValue(self):
1330 """
1331 Return the count of each unique value in this RDD as a dictionary of
1332 (value, count) pairs.
1333
1334 >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items())
1335 [(1, 2), (2, 3)]
1336 """
1337 def countPartition(iterator):
1338 counts = defaultdict(int)
1339 for obj in iterator:
1340 counts[obj] += 1
1341 yield counts
1342
1343 def mergeMaps(m1, m2):
1344 for k, v in m2.items():
1345 m1[k] += v
1346 return m1
1347 return self.mapPartitions(countPartition).reduce(mergeMaps)
1348
1349 def top(self, num, key=None):
1350 """
1351 Get the top N elements from an RDD.
1352
1353 .. note:: This method should only be used if the resulting array is expected
1354 to be small, as all the data is loaded into the driver's memory.
1355
1356 .. note:: It returns the list sorted in descending order.
1357
1358 >>> sc.parallelize([10, 4, 2, 12, 3]).top(1)
1359 [12]
1360 >>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2)
1361 [6, 5]
1362 >>> sc.parallelize([10, 4, 2, 12, 3]).top(3, key=str)
1363 [4, 3, 2]
1364 """
1365 def topIterator(iterator):
1366 yield heapq.nlargest(num, iterator, key=key)
1367
1368 def merge(a, b):
1369 return heapq.nlargest(num, a + b, key=key)
1370
1371 return self.mapPartitions(topIterator).reduce(merge)
1372
1373 def takeOrdered(self, num, key=None):
1374 """
1375 Get the N elements from an RDD ordered in ascending order or as
1376 specified by the optional key function.
1377
1378 .. note:: this method should only be used if the resulting array is expected
1379 to be small, as all the data is loaded into the driver's memory.
1380
1381 >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6)
1382 [1, 2, 3, 4, 5, 6]
1383 >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x)
1384 [10, 9, 7, 6, 5, 4]
1385 """
1386
1387 def merge(a, b):
1388 return heapq.nsmallest(num, a + b, key)
1389
1390 return self.mapPartitions(lambda it: [heapq.nsmallest(num, it, key)]).reduce(merge)
1391
1392 def take(self, num):
1393 """
1394 Take the first num elements of the RDD.
1395
1396 It works by first scanning one partition, and using the results from
1397 that partition to estimate the number of additional partitions needed
1398 to satisfy the limit.
1399
1400 Translated from the Scala implementation in RDD#take().
1401
1402 .. note:: this method should only be used if the resulting array is expected
1403 to be small, as all the data is loaded into the driver's memory.
1404
1405 >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
1406 [2, 3]
1407 >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
1408 [2, 3, 4, 5, 6]
1409 >>> sc.parallelize(range(100), 100).filter(lambda x: x > 90).take(3)
1410 [91, 92, 93]
1411 """
1412 items = []
1413 totalParts = self.getNumPartitions()
1414 partsScanned = 0
1415
1416 while len(items) < num and partsScanned < totalParts:
1417
1418
1419
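# How many partitions to scan in this round; the range below caps it at totalParts.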
1420 numPartsToTry = 1
1421 if partsScanned > 0:
1422
1423
1424
1425
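# If the last round produced nothing, quadruple the number of partitions to scan;
# otherwise interpolate how many more partitions should be enough and overestimate
# by 50%, capped at 4x what has already been scanned.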
1426 if len(items) == 0:
1427 numPartsToTry = partsScanned * 4
1428 else:
1429
1430 numPartsToTry = int(1.5 * num * partsScanned / len(items)) - partsScanned
1431 numPartsToTry = min(max(numPartsToTry, 1), partsScanned * 4)
1432
1433 left = num - len(items)
1434
1435 def takeUpToNumLeft(iterator):
1436 iterator = iter(iterator)
1437 taken = 0
1438 while taken < left:
1439 try:
1440 yield next(iterator)
1441 except StopIteration:
1442 return
1443 taken += 1
1444
1445 p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
1446 res = self.context.runJob(self, takeUpToNumLeft, p)
1447
1448 items += res
1449 partsScanned += numPartsToTry
1450
1451 return items[:num]
1452
1453 def first(self):
1454 """
1455 Return the first element in this RDD.
1456
1457 >>> sc.parallelize([2, 3, 4]).first()
1458 2
1459 >>> sc.parallelize([]).first()
1460 Traceback (most recent call last):
1461 ...
1462 ValueError: RDD is empty
1463 """
1464 rs = self.take(1)
1465 if rs:
1466 return rs[0]
1467 raise ValueError("RDD is empty")
1468
1469 def isEmpty(self):
1470 """
1471 Returns true if and only if the RDD contains no elements at all.
1472
1473 .. note:: an RDD may be empty even when it has at least 1 partition.
1474
1475 >>> sc.parallelize([]).isEmpty()
1476 True
1477 >>> sc.parallelize([1]).isEmpty()
1478 False
1479 """
1480 return self.getNumPartitions() == 0 or len(self.take(1)) == 0
1481
1482 def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
1483 """
1484 Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file
1485 system, using the new Hadoop OutputFormat API (mapreduce package). Keys/values are
1486 converted for output using either user specified converters or, by default,
1487 "org.apache.spark.api.python.JavaToWritableConverter".
1488
1489 :param conf: Hadoop job configuration, passed in as a dict
1490 :param keyConverter: (None by default)
1491 :param valueConverter: (None by default)
1492 """
1493 jconf = self.ctx._dictToJavaMap(conf)
1494 pickledRDD = self._pickled()
1495 self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf,
1496 keyConverter, valueConverter, True)
1497
1498 def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None,
1499 keyConverter=None, valueConverter=None, conf=None):
1500 """
1501 Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file
1502 system, using the new Hadoop OutputFormat API (mapreduce package). Key and value types
1503 will be inferred if not specified. Keys and values are converted for output using either
1504 user specified converters or "org.apache.spark.api.python.JavaToWritableConverter". The
1505 `conf` is applied on top of the base Hadoop conf associated with the SparkContext
1506 of this RDD to create a merged Hadoop MapReduce job configuration for saving the data.
1507
1508 :param path: path to Hadoop file
1509 :param outputFormatClass: fully qualified classname of Hadoop OutputFormat
1510 (e.g. "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")
1511 :param keyClass: fully qualified classname of key Writable class
1512 (e.g. "org.apache.hadoop.io.IntWritable", None by default)
1513 :param valueClass: fully qualified classname of value Writable class
1514 (e.g. "org.apache.hadoop.io.Text", None by default)
1515 :param keyConverter: (None by default)
1516 :param valueConverter: (None by default)
1517 :param conf: Hadoop job configuration, passed in as a dict (None by default)
1518 """
1519 jconf = self.ctx._dictToJavaMap(conf)
1520 pickledRDD = self._pickled()
1521 self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, True, path,
1522 outputFormatClass,
1523 keyClass, valueClass,
1524 keyConverter, valueConverter, jconf)
1525
1526 def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
1527 """
1528 Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file
1529 system, using the old Hadoop OutputFormat API (mapred package). Keys/values are
1530 converted for output using either user specified converters or, by default,
1531 "org.apache.spark.api.python.JavaToWritableConverter".
1532
1533 :param conf: Hadoop job configuration, passed in as a dict
1534 :param keyConverter: (None by default)
1535 :param valueConverter: (None by default)
1536 """
1537 jconf = self.ctx._dictToJavaMap(conf)
1538 pickledRDD = self._pickled()
1539 self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf,
1540 keyConverter, valueConverter, False)
1541
1542 def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None,
1543 keyConverter=None, valueConverter=None, conf=None,
1544 compressionCodecClass=None):
1545 """
1546 Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file
1547 system, using the old Hadoop OutputFormat API (mapred package). Key and value types
1548 will be inferred if not specified. Keys and values are converted for output using either
1549 user specified converters or "org.apache.spark.api.python.JavaToWritableConverter". The
1550 `conf` is applied on top of the base Hadoop conf associated with the SparkContext
1551 of this RDD to create a merged Hadoop MapReduce job configuration for saving the data.
1552
1553 :param path: path to Hadoop file
1554 :param outputFormatClass: fully qualified classname of Hadoop OutputFormat
1555 (e.g. "org.apache.hadoop.mapred.SequenceFileOutputFormat")
1556 :param keyClass: fully qualified classname of key Writable class
1557 (e.g. "org.apache.hadoop.io.IntWritable", None by default)
1558 :param valueClass: fully qualified classname of value Writable class
1559 (e.g. "org.apache.hadoop.io.Text", None by default)
1560 :param keyConverter: (None by default)
1561 :param valueConverter: (None by default)
1562 :param conf: (None by default)
1563 :param compressionCodecClass: (None by default)
1564 """
1565 jconf = self.ctx._dictToJavaMap(conf)
1566 pickledRDD = self._pickled()
1567 self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, True, path,
1568 outputFormatClass,
1569 keyClass, valueClass,
1570 keyConverter, valueConverter,
1571 jconf, compressionCodecClass)
1572
1573 def saveAsSequenceFile(self, path, compressionCodecClass=None):
1574 """
1575 Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file
1576 system, using the "org.apache.hadoop.io.Writable" types that we convert from the
1577 RDD's key and value types. The mechanism is as follows:
1578
1579 1. Pyrolite is used to convert pickled Python RDD into RDD of Java objects.
1580 2. Keys and values of this Java RDD are converted to Writables and written out.
1581
1582 :param path: path to sequence file
1583 :param compressionCodecClass: (None by default)
1584 """
1585 pickledRDD = self._pickled()
1586 self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, True,
1587 path, compressionCodecClass)
1588
1589 def saveAsPickleFile(self, path, batchSize=10):
1590 """
1591 Save this RDD as a SequenceFile of serialized objects. The serializer
1592 used is :class:`pyspark.serializers.PickleSerializer`, default batch size
1593 is 10.
1594
1595 >>> tmpFile = NamedTemporaryFile(delete=True)
1596 >>> tmpFile.close()
1597 >>> sc.parallelize([1, 2, 'spark', 'rdd']).saveAsPickleFile(tmpFile.name, 3)
1598 >>> sorted(sc.pickleFile(tmpFile.name, 5).map(str).collect())
1599 ['1', '2', 'rdd', 'spark']
1600 """
1601 if batchSize == 0:
1602 ser = AutoBatchedSerializer(PickleSerializer())
1603 else:
1604 ser = BatchedSerializer(PickleSerializer(), batchSize)
1605 self._reserialize(ser)._jrdd.saveAsObjectFile(path)
1606
1607 @ignore_unicode_prefix
1608 def saveAsTextFile(self, path, compressionCodecClass=None):
1609 """
1610 Save this RDD as a text file, using string representations of elements.
1611
1612 :param path: path to text file
1613 :param compressionCodecClass: (None by default) string i.e.
1614 "org.apache.hadoop.io.compress.GzipCodec"
1615
1616 >>> tempFile = NamedTemporaryFile(delete=True)
1617 >>> tempFile.close()
1618 >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name)
1619 >>> from fileinput import input
1620 >>> from glob import glob
1621 >>> ''.join(sorted(input(glob(tempFile.name + "/part-0000*"))))
1622 '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n'
1623
1624 Empty lines are tolerated when saving to text files.
1625
1626 >>> tempFile2 = NamedTemporaryFile(delete=True)
1627 >>> tempFile2.close()
1628 >>> sc.parallelize(['', 'foo', '', 'bar', '']).saveAsTextFile(tempFile2.name)
1629 >>> ''.join(sorted(input(glob(tempFile2.name + "/part-0000*"))))
1630 '\\n\\n\\nbar\\nfoo\\n'
1631
1632 Using compressionCodecClass
1633
1634 >>> tempFile3 = NamedTemporaryFile(delete=True)
1635 >>> tempFile3.close()
1636 >>> codec = "org.apache.hadoop.io.compress.GzipCodec"
1637 >>> sc.parallelize(['foo', 'bar']).saveAsTextFile(tempFile3.name, codec)
1638 >>> from fileinput import input, hook_compressed
1639 >>> result = sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed))
1640 >>> b''.join(result).decode('utf-8')
1641 u'bar\\nfoo\\n'
1642 """
1643 def func(split, iterator):
1644 for x in iterator:
1645 if not isinstance(x, (unicode, bytes)):
1646 x = unicode(x)
1647 if isinstance(x, unicode):
1648 x = x.encode("utf-8")
1649 yield x
1650 keyed = self.mapPartitionsWithIndex(func)
1651 keyed._bypass_serializer = True
1652 if compressionCodecClass:
1653 compressionCodec = self.ctx._jvm.java.lang.Class.forName(compressionCodecClass)
1654 keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path, compressionCodec)
1655 else:
1656 keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)
1657
1658
1659
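# Key-value (pair) RDD operations follow.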
1660 def collectAsMap(self):
1661 """
1662 Return the key-value pairs in this RDD to the master as a dictionary.
1663
1664 .. note:: this method should only be used if the resulting data is expected
1665 to be small, as all the data is loaded into the driver's memory.
1666
1667 >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap()
1668 >>> m[1]
1669 2
1670 >>> m[3]
1671 4
1672 """
1673 return dict(self.collect())
1674
1675 def keys(self):
1676 """
1677 Return an RDD with the keys of each tuple.
1678
1679 >>> m = sc.parallelize([(1, 2), (3, 4)]).keys()
1680 >>> m.collect()
1681 [1, 3]
1682 """
1683 return self.map(lambda x: x[0])
1684
1685 def values(self):
1686 """
1687 Return an RDD with the values of each tuple.
1688
1689 >>> m = sc.parallelize([(1, 2), (3, 4)]).values()
1690 >>> m.collect()
1691 [2, 4]
1692 """
1693 return self.map(lambda x: x[1])
1694
1695 def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash):
1696 """
1697 Merge the values for each key using an associative and commutative reduce function.
1698
1699 This will also perform the merging locally on each mapper before
1700 sending results to a reducer, similarly to a "combiner" in MapReduce.
1701
1702 Output will be partitioned with `numPartitions` partitions, or
1703 the default parallelism level if `numPartitions` is not specified.
1704 The default partitioner is hash-partitioning.
1705
1706 >>> from operator import add
1707 >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
1708 >>> sorted(rdd.reduceByKey(add).collect())
1709 [('a', 2), ('b', 1)]
1710 """
1711 return self.combineByKey(lambda x: x, func, func, numPartitions, partitionFunc)
1712
1713 def reduceByKeyLocally(self, func):
1714 """
1715 Merge the values for each key using an associative and commutative reduce function, but
1716 return the results immediately to the master as a dictionary.
1717
1718 This will also perform the merging locally on each mapper before
1719 sending results to a reducer, similarly to a "combiner" in MapReduce.
1720
1721 >>> from operator import add
1722 >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
1723 >>> sorted(rdd.reduceByKeyLocally(add).items())
1724 [('a', 2), ('b', 1)]
1725 """
1726 func = fail_on_stopiteration(func)
1727
1728 def reducePartition(iterator):
1729 m = {}
1730 for k, v in iterator:
1731 m[k] = func(m[k], v) if k in m else v
1732 yield m
1733
1734 def mergeMaps(m1, m2):
1735 for k, v in m2.items():
1736 m1[k] = func(m1[k], v) if k in m1 else v
1737 return m1
1738 return self.mapPartitions(reducePartition).reduce(mergeMaps)
1739
1740 def countByKey(self):
1741 """
1742 Count the number of elements for each key, and return the result to the
1743 master as a dictionary.
1744
1745 >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
1746 >>> sorted(rdd.countByKey().items())
1747 [('a', 2), ('b', 1)]
1748 """
1749 return self.map(lambda x: x[0]).countByValue()
1750
1751 def join(self, other, numPartitions=None):
1752 """
1753 Return an RDD containing all pairs of elements with matching keys in
1754 `self` and `other`.
1755
1756 Each pair of elements will be returned as a (k, (v1, v2)) tuple, where
1757 (k, v1) is in `self` and (k, v2) is in `other`.
1758
1759 Performs a hash join across the cluster.
1760
1761 >>> x = sc.parallelize([("a", 1), ("b", 4)])
1762 >>> y = sc.parallelize([("a", 2), ("a", 3)])
1763 >>> sorted(x.join(y).collect())
1764 [('a', (1, 2)), ('a', (1, 3))]
1765 """
1766 return python_join(self, other, numPartitions)
1767
1768 def leftOuterJoin(self, other, numPartitions=None):
1769 """
1770 Perform a left outer join of `self` and `other`.
1771
1772 For each element (k, v) in `self`, the resulting RDD will either
1773 contain all pairs (k, (v, w)) for w in `other`, or the pair
1774 (k, (v, None)) if no elements in `other` have key k.
1775
1776 Hash-partitions the resulting RDD into the given number of partitions.
1777
1778 >>> x = sc.parallelize([("a", 1), ("b", 4)])
1779 >>> y = sc.parallelize([("a", 2)])
1780 >>> sorted(x.leftOuterJoin(y).collect())
1781 [('a', (1, 2)), ('b', (4, None))]
1782 """
1783 return python_left_outer_join(self, other, numPartitions)
1784
1785 def rightOuterJoin(self, other, numPartitions=None):
1786 """
1787 Perform a right outer join of `self` and `other`.
1788
1789 For each element (k, w) in `other`, the resulting RDD will either
1790 contain all pairs (k, (v, w)) for v in `self`, or the pair (k, (None, w))
1791 if no elements in `self` have key k.
1792
1793 Hash-partitions the resulting RDD into the given number of partitions.
1794
1795 >>> x = sc.parallelize([("a", 1), ("b", 4)])
1796 >>> y = sc.parallelize([("a", 2)])
1797 >>> sorted(y.rightOuterJoin(x).collect())
1798 [('a', (2, 1)), ('b', (None, 4))]
1799 """
1800 return python_right_outer_join(self, other, numPartitions)
1801
1802 def fullOuterJoin(self, other, numPartitions=None):
1803 """
1804 Perform a full outer join of `self` and `other`.
1805
1806 For each element (k, v) in `self`, the resulting RDD will either
1807 contain all pairs (k, (v, w)) for w in `other`, or the pair
1808 (k, (v, None)) if no elements in `other` have key k.
1809
1810 Similarly, for each element (k, w) in `other`, the resulting RDD will
1811 either contain all pairs (k, (v, w)) for v in `self`, or the pair
1812 (k, (None, w)) if no elements in `self` have key k.
1813
1814 Hash-partitions the resulting RDD into the given number of partitions.
1815
1816 >>> x = sc.parallelize([("a", 1), ("b", 4)])
1817 >>> y = sc.parallelize([("a", 2), ("c", 8)])
1818 >>> sorted(x.fullOuterJoin(y).collect())
1819 [('a', (1, 2)), ('b', (4, None)), ('c', (None, 8))]
1820 """
1821 return python_full_outer_join(self, other, numPartitions)
1822
1823 # portable_hash is used as the default partition function because the
1824 # builtin hash of None is not consistent across machines.
1826 def partitionBy(self, numPartitions, partitionFunc=portable_hash):
1827 """
1828 Return a copy of the RDD partitioned using the specified partitioner.
1829
1830 >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
1831 >>> sets = pairs.partitionBy(2).glom().collect()
1832 >>> len(set(sets[0]).intersection(set(sets[1])))
1833 0
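
A custom partition function can be supplied as well (a small sketch):

>>> pairs.partitionBy(2, partitionFunc=lambda k: k % 2).getNumPartitions()
2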
1834 """
1835 if numPartitions is None:
1836 numPartitions = self._defaultReducePartitions()
1837 partitioner = Partitioner(numPartitions, partitionFunc)
1838 if self.partitioner == partitioner:
1839 return self
1840
1841 # Transferring O(n) objects to Java is too expensive.
1842 # Instead, we form the hash buckets in Python, transferring only
1843 # O(numPartitions) objects to Java.
1844 # Each object is a (splitNumber, [objects]) pair.
1845 # To avoid building overly large objects, the objects are grouped
1846 # into chunks.
1847 outputSerializer = self.ctx._unbatched_serializer
1848
1849 limit = (self._memory_limit() / 2)
1850
1851 def add_shuffle_key(split, iterator):
1852
1853 buckets = defaultdict(list)
1854 c, batch = 0, min(10 * numPartitions, 1000)
1855
1856 for k, v in iterator:
1857 buckets[partitionFunc(k) % numPartitions].append((k, v))
1858 c += 1
1859
1860 # periodically check used memory and the average chunk size
1861 if (c % 1000 == 0 and get_used_memory() > limit
1862 or c > batch):
1863 n, size = len(buckets), 0
1864 for split in list(buckets.keys()):
1865 yield pack_long(split)
1866 d = outputSerializer.dumps(buckets[split])
1867 del buckets[split]
1868 yield d
1869 size += len(d)
1870
1871 avg = int(size / n) >> 20
1872 # keep the average serialized chunk between roughly 1 MiB and 10 MiB
1873 if avg < 1:
1874 batch *= 1.5
1875 elif avg > 10:
1876 batch = max(int(batch / 1.5), 1)
1877 c = 0
1878
1879 for split, items in buckets.items():
1880 yield pack_long(split)
1881 yield outputSerializer.dumps(items)
1882
1883 keyed = self.mapPartitionsWithIndex(add_shuffle_key, preservesPartitioning=True)
1884 keyed._bypass_serializer = True
1885 with SCCallSiteSync(self.context) as css:
1886 pairRDD = self.ctx._jvm.PairwiseRDD(
1887 keyed._jrdd.rdd()).asJavaPairRDD()
1888 jpartitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
1889 id(partitionFunc))
1890 jrdd = self.ctx._jvm.PythonRDD.valueOfPair(pairRDD.partitionBy(jpartitioner))
1891 rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
1892 rdd.partitioner = partitioner
1893 return rdd
1894
1896 def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
1897 numPartitions=None, partitionFunc=portable_hash):
1898 """
1899 Generic function to combine the elements for each key using a custom
1900 set of aggregation functions.
1901
1902 Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined
1903 type" C.
1904
1905 Users provide three functions:
1906
1907 - `createCombiner`, which turns a V into a C (e.g., creates
1908 a one-element list)
1909 - `mergeValue`, to merge a V into a C (e.g., adds it to the end of
1910 a list)
1911 - `mergeCombiners`, to combine two C's into a single one (e.g., merges
1912 the lists)
1913
1914 To avoid memory allocation, both mergeValue and mergeCombiners are allowed to
1915 modify and return their first argument instead of creating a new C.
1916
1917 In addition, users can control the partitioning of the output RDD.
1918
1919 .. note:: V and C can be different -- for example, one might group an RDD of type
1920 (Int, Int) into an RDD of type (Int, List[Int]).
1921
1922 >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 2)])
1923 >>> def to_list(a):
1924 ... return [a]
1925 ...
1926 >>> def append(a, b):
1927 ... a.append(b)
1928 ... return a
1929 ...
1930 >>> def extend(a, b):
1931 ... a.extend(b)
1932 ... return a
1933 ...
1934 >>> sorted(x.combineByKey(to_list, append, extend).collect())
1935 [('a', [1, 2]), ('b', [1])]
1936 """
1937 if numPartitions is None:
1938 numPartitions = self._defaultReducePartitions()
1939
1940 serializer = self.ctx.serializer
1941 memory = self._memory_limit()
1942 agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
1943
1944 def combineLocally(iterator):
1945 merger = ExternalMerger(agg, memory * 0.9, serializer)
1946 merger.mergeValues(iterator)
1947 return merger.items()
1948
1949 locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True)
1950 shuffled = locally_combined.partitionBy(numPartitions, partitionFunc)
1951
1952 def _mergeCombiners(iterator):
1953 merger = ExternalMerger(agg, memory, serializer)
1954 merger.mergeCombiners(iterator)
1955 return merger.items()
1956
1957 return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True)
1958
1959 def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None,
1960 partitionFunc=portable_hash):
1961 """
1962 Aggregate the values of each key, using given combine functions and a neutral
1963 "zero value". This function can return a different result type, U, than the type
1964 of the values in this RDD, V. Thus, we need one operation for merging a V into
1965 a U and one operation for merging two U's. The former operation is used for merging
1966 values within a partition, and the latter is used for merging values between
1967 partitions. To avoid memory allocation, both of these functions are
1968 allowed to modify and return their first argument instead of creating a new U.
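
A minimal sketch, using integer addition for both the within-partition and
cross-partition merges:

>>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 2)])
>>> seqOp = (lambda acc, v: acc + v)
>>> combOp = (lambda acc1, acc2: acc1 + acc2)
>>> sorted(rdd.aggregateByKey(0, seqOp, combOp).collect())
[('a', 3), ('b', 1)]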
1969 """
1970 def createZero():
1971 return copy.deepcopy(zeroValue)
1972
1973 return self.combineByKey(
1974 lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions, partitionFunc)
1975
1976 def foldByKey(self, zeroValue, func, numPartitions=None, partitionFunc=portable_hash):
1977 """
1978 Merge the values for each key using an associative function "func"
1979 and a neutral "zeroValue" which may be added to the result an
1980 arbitrary number of times, and must not change the result
1981 (e.g., 0 for addition, or 1 for multiplication).
1982
1983 >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
1984 >>> from operator import add
1985 >>> sorted(rdd.foldByKey(0, add).collect())
1986 [('a', 2), ('b', 1)]
1987 """
1988 def createZero():
1989 return copy.deepcopy(zeroValue)
1990
1991 return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions,
1992 partitionFunc)
1993
1994 def _memory_limit(self):
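"""Return the memory limit (in MiB) configured for the Python worker via spark.python.worker.memory."""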
1995 return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
1996
1998 def groupByKey(self, numPartitions=None, partitionFunc=portable_hash):
1999 """
2000 Group the values for each key in the RDD into a single sequence.
2001 Hash-partitions the resulting RDD with numPartitions partitions.
2002
2003 .. note:: If you are grouping in order to perform an aggregation (such as a
2004 sum or average) over each key, using reduceByKey or aggregateByKey will
2005 provide much better performance.
2006
2007 >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
2008 >>> sorted(rdd.groupByKey().mapValues(len).collect())
2009 [('a', 2), ('b', 1)]
2010 >>> sorted(rdd.groupByKey().mapValues(list).collect())
2011 [('a', [1, 1]), ('b', [1])]
2012 """
2013 def createCombiner(x):
2014 return [x]
2015
2016 def mergeValue(xs, x):
2017 xs.append(x)
2018 return xs
2019
2020 def mergeCombiners(a, b):
2021 a.extend(b)
2022 return a
2023
2024 memory = self._memory_limit()
2025 serializer = self._jrdd_deserializer
2026 agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
2027
2028 def combine(iterator):
2029 merger = ExternalMerger(agg, memory * 0.9, serializer)
2030 merger.mergeValues(iterator)
2031 return merger.items()
2032
2033 locally_combined = self.mapPartitions(combine, preservesPartitioning=True)
2034 shuffled = locally_combined.partitionBy(numPartitions, partitionFunc)
2035
2036 def groupByKey(it):
2037 merger = ExternalGroupBy(agg, memory, serializer)
2038 merger.mergeCombiners(it)
2039 return merger.items()
2040
2041 return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable)
2042
2043 def flatMapValues(self, f):
2044 """
2045 Pass each value in the key-value pair RDD through a flatMap function
2046 without changing the keys; this also retains the original RDD's
2047 partitioning.
2048
2049 >>> x = sc.parallelize([("a", ["x", "y", "z"]), ("b", ["p", "r"])])
2050 >>> def f(x): return x
2051 >>> x.flatMapValues(f).collect()
2052 [('a', 'x'), ('a', 'y'), ('a', 'z'), ('b', 'p'), ('b', 'r')]
2053 """
2054 flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1]))
2055 return self.flatMap(flat_map_fn, preservesPartitioning=True)
2056
2057 def mapValues(self, f):
2058 """
2059 Pass each value in the key-value pair RDD through a map function
2060 without changing the keys; this also retains the original RDD's
2061 partitioning.
2062
2063 >>> x = sc.parallelize([("a", ["apple", "banana", "lemon"]), ("b", ["grapes"])])
2064 >>> def f(x): return len(x)
2065 >>> x.mapValues(f).collect()
2066 [('a', 3), ('b', 1)]
2067 """
2068 map_values_fn = lambda kv: (kv[0], f(kv[1]))
2069 return self.map(map_values_fn, preservesPartitioning=True)
2070
2071 def groupWith(self, other, *others):
2072 """
2073 Alias for cogroup but with support for multiple RDDs.
2074
2075 >>> w = sc.parallelize([("a", 5), ("b", 6)])
2076 >>> x = sc.parallelize([("a", 1), ("b", 4)])
2077 >>> y = sc.parallelize([("a", 2)])
2078 >>> z = sc.parallelize([("b", 42)])
2079 >>> [(x, tuple(map(list, y))) for x, y in sorted(list(w.groupWith(x, y, z).collect()))]
2080 [('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))]
2081
2082 """
2083 return python_cogroup((self, other) + others, numPartitions=None)
2084
2086 def cogroup(self, other, numPartitions=None):
2087 """
2088 For each key k in `self` or `other`, return a resulting RDD that
2089 contains a tuple with the list of values for that key in `self` as
2090 well as `other`.
2091
2092 >>> x = sc.parallelize([("a", 1), ("b", 4)])
2093 >>> y = sc.parallelize([("a", 2)])
2094 >>> [(x, tuple(map(list, y))) for x, y in sorted(list(x.cogroup(y).collect()))]
2095 [('a', ([1], [2])), ('b', ([4], []))]
2096 """
2097 return python_cogroup((self, other), numPartitions)
2098
2099 def sampleByKey(self, withReplacement, fractions, seed=None):
2100 """
2101 Return a subset of this RDD sampled by key (via stratified sampling).
2102 Create a sample of this RDD using variable sampling rates for
2103 different keys as specified by fractions, a key to sampling rate map.
2104
2105 >>> fractions = {"a": 0.2, "b": 0.1}
2106 >>> rdd = sc.parallelize(fractions.keys()).cartesian(sc.parallelize(range(0, 1000)))
2107 >>> sample = dict(rdd.sampleByKey(False, fractions, 2).groupByKey().collect())
2108 >>> 100 < len(sample["a"]) < 300 and 50 < len(sample["b"]) < 150
2109 True
2110 >>> max(sample["a"]) <= 999 and min(sample["a"]) >= 0
2111 True
2112 >>> max(sample["b"]) <= 999 and min(sample["b"]) >= 0
2113 True
2114 """
2115 for fraction in fractions.values():
2116 assert fraction >= 0.0, "Negative fraction value: %s" % fraction
2117 return self.mapPartitionsWithIndex(
2118 RDDStratifiedSampler(withReplacement, fractions, seed).func, True)
2119
2120 def subtractByKey(self, other, numPartitions=None):
2121 """
2122 Return each (key, value) pair in `self` that has no pair with matching
2123 key in `other`.
2124
2125 >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 2)])
2126 >>> y = sc.parallelize([("a", 3), ("c", None)])
2127 >>> sorted(x.subtractByKey(y).collect())
2128 [('b', 4), ('b', 5)]
2129 """
2130 def filter_func(pair):
2131 key, (val1, val2) = pair
2132 return val1 and not val2
2133 return self.cogroup(other, numPartitions).filter(filter_func).flatMapValues(lambda x: x[0])
2134
2135 def subtract(self, other, numPartitions=None):
2136 """
2137 Return each value in `self` that is not contained in `other`.
2138
2139 >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 3)])
2140 >>> y = sc.parallelize([("a", 3), ("c", None)])
2141 >>> sorted(x.subtract(y).collect())
2142 [('a', 1), ('b', 4), ('b', 5)]
2143 """
2144
2145 rdd = other.map(lambda x: (x, True))
2146 return self.map(lambda x: (x, True)).subtractByKey(rdd, numPartitions).keys()
2147
2148 def keyBy(self, f):
2149 """
2150 Creates tuples of the elements in this RDD by applying `f`.
2151
2152 >>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x)
2153 >>> y = sc.parallelize(zip(range(0,5), range(0,5)))
2154 >>> [(x, list(map(list, y))) for x, y in sorted(x.cogroup(y).collect())]
2155 [(0, [[0], [0]]), (1, [[1], [1]]), (2, [[], [2]]), (3, [[], [3]]), (4, [[2], [4]])]
2156 """
2157 return self.map(lambda x: (f(x), x))
2158
2159 def repartition(self, numPartitions):
2160 """
2161 Return a new RDD that has exactly numPartitions partitions.
2162
2163 Can increase or decrease the level of parallelism in this RDD.
2164 Internally, this uses a shuffle to redistribute data.
2165 If you are decreasing the number of partitions in this RDD, consider
2166 using `coalesce`, which can avoid performing a shuffle.
2167
2168 >>> rdd = sc.parallelize([1,2,3,4,5,6,7], 4)
2169 >>> sorted(rdd.glom().collect())
2170 [[1], [2, 3], [4, 5], [6, 7]]
2171 >>> len(rdd.repartition(2).glom().collect())
2172 2
2173 >>> len(rdd.repartition(10).glom().collect())
2174 10
2175 """
2176 return self.coalesce(numPartitions, shuffle=True)
2177
2178 def coalesce(self, numPartitions, shuffle=False):
2179 """
2180 Return a new RDD that is reduced into `numPartitions` partitions.
2181
2182 >>> sc.parallelize([1, 2, 3, 4, 5], 3).glom().collect()
2183 [[1], [2, 3], [4, 5]]
2184 >>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(1).glom().collect()
2185 [[1, 2, 3, 4, 5]]
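
With shuffle=True the number of partitions can also be increased (a small
sketch; some of the resulting partitions may be empty):

>>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(5, shuffle=True).getNumPartitions()
5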
2186 """
2187 if shuffle:
2188 # Decrease the batch size so elements are distributed evenly across the
2189 # output partitions; otherwise repartition may produce highly skewed partitions.
2190 batchSize = min(10, self.ctx._batchSize or 1024)
2191 ser = BatchedSerializer(PickleSerializer(), batchSize)
2192 selfCopy = self._reserialize(ser)
2193 jrdd_deserializer = selfCopy._jrdd_deserializer
2194 jrdd = selfCopy._jrdd.coalesce(numPartitions, shuffle)
2195 else:
2196 jrdd_deserializer = self._jrdd_deserializer
2197 jrdd = self._jrdd.coalesce(numPartitions, shuffle)
2198 return RDD(jrdd, self.ctx, jrdd_deserializer)
2199
2200 def zip(self, other):
2201 """
2202 Zips this RDD with another one, returning key-value pairs with the
2203 first element in each RDD, second element in each RDD, etc. Assumes
2204 that the two RDDs have the same number of partitions and the same
2205 number of elements in each partition (e.g. one was made through
2206 a map on the other).
2207
2208 >>> x = sc.parallelize(range(0,5))
2209 >>> y = sc.parallelize(range(1000, 1005))
2210 >>> x.zip(y).collect()
2211 [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
2212 """
2213 def get_batch_size(ser):
2214 if isinstance(ser, BatchedSerializer):
2215 return ser.batchSize
2216 return 1
2217
2218 def batch_as(rdd, batchSize):
2219 return rdd._reserialize(BatchedSerializer(PickleSerializer(), batchSize))
2220
2221 my_batch = get_batch_size(self._jrdd_deserializer)
2222 other_batch = get_batch_size(other._jrdd_deserializer)
2223 if my_batch != other_batch or not my_batch:
2224 # use the smaller batch size for both RDDs
2225 batchSize = min(my_batch, other_batch)
2226 if batchSize <= 0:
2227 # auto-batched or unlimited; fall back to a fixed batch size
2228 batchSize = 100
2229 other = batch_as(other, batchSize)
2230 self = batch_as(self, batchSize)
2231
2232 if self.getNumPartitions() != other.getNumPartitions():
2233 raise ValueError("Can only zip with RDD which has the same number of partitions")
2234
2235 # The JVM will raise an exception if the two RDDs have a different
2236 # number of items in any partition.
2237 pairRDD = self._jrdd.zip(other._jrdd)
2238 deserializer = PairDeserializer(self._jrdd_deserializer,
2239 other._jrdd_deserializer)
2240 return RDD(pairRDD, self.ctx, deserializer)
2241
2242 def zipWithIndex(self):
2243 """
2244 Zips this RDD with its element indices.
2245
2246 The ordering is first based on the partition index and then the
2247 ordering of items within each partition. So the first item in
2248 the first partition gets index 0, and the last item in the last
2249 partition receives the largest index.
2250
2251 This method needs to trigger a Spark job when this RDD contains
2252 more than one partition.
2253
2254 >>> sc.parallelize(["a", "b", "c", "d"], 3).zipWithIndex().collect()
2255 [('a', 0), ('b', 1), ('c', 2), ('d', 3)]
2256 """
2257 starts = [0]
2258 if self.getNumPartitions() > 1:
2259 nums = self.mapPartitions(lambda it: [sum(1 for i in it)]).collect()
2260 for i in range(len(nums) - 1):
2261 starts.append(starts[-1] + nums[i])
2262
2263 def func(k, it):
2264 for i, v in enumerate(it, starts[k]):
2265 yield v, i
2266
2267 return self.mapPartitionsWithIndex(func)
2268
2269 def zipWithUniqueId(self):
2270 """
2271 Zips this RDD with generated unique Long ids.
2272
2273 Items in the kth partition will get ids k, n+k, 2*n+k, ..., where
2274 n is the number of partitions. So there may exist gaps, but this
2275 method won't trigger a Spark job, which is different from
2276 :meth:`zipWithIndex`.
2277
2278 >>> sc.parallelize(["a", "b", "c", "d", "e"], 3).zipWithUniqueId().collect()
2279 [('a', 0), ('b', 1), ('c', 4), ('d', 2), ('e', 5)]
2280 """
2281 n = self.getNumPartitions()
2282
2283 def func(k, it):
2284 for i, v in enumerate(it):
2285 yield v, i * n + k
2286
2287 return self.mapPartitionsWithIndex(func)
2288
2289 def name(self):
2290 """
2291 Return the name of this RDD.
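
Returns None when no name has been set (a small sketch):

>>> rdd1 = sc.parallelize([1, 2])
>>> rdd1.name() is None
True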
2292 """
2293 n = self._jrdd.name()
2294 if n:
2295 return n
2296
2297 @ignore_unicode_prefix
2298 def setName(self, name):
2299 """
2300 Assign a name to this RDD.
2301
2302 >>> rdd1 = sc.parallelize([1, 2])
2303 >>> rdd1.setName('RDD1').name()
2304 u'RDD1'
2305 """
2306 self._jrdd.setName(name)
2307 return self
2308
2309 def toDebugString(self):
2310 """
2311 A description of this RDD and its recursive dependencies for debugging.
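
A sketch of typical use; the lineage text depends on the environment, so it
is not checked here:

>>> rdd = sc.parallelize([1, 2]).map(str)
>>> print(rdd.toDebugString().decode())  # doctest: +SKIP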
2312 """
2313 debug_string = self._jrdd.toDebugString()
2314 if debug_string:
2315 return debug_string.encode('utf-8')
2316
2317 def getStorageLevel(self):
2318 """
2319 Get the RDD's current storage level.
2320
2321 >>> rdd1 = sc.parallelize([1,2])
2322 >>> rdd1.getStorageLevel()
2323 StorageLevel(False, False, False, False, 1)
2324 >>> print(rdd1.getStorageLevel())
2325 Serialized 1x Replicated
2326 """
2327 java_storage_level = self._jrdd.getStorageLevel()
2328 storage_level = StorageLevel(java_storage_level.useDisk(),
2329 java_storage_level.useMemory(),
2330 java_storage_level.useOffHeap(),
2331 java_storage_level.deserialized(),
2332 java_storage_level.replication())
2333 return storage_level
2334
2335 def _defaultReducePartitions(self):
2336 """
2337 Returns the default number of partitions to use during reduce tasks (e.g., groupBy).
2338 If spark.default.parallelism is set, then we'll use the value from SparkContext
2339 defaultParallelism, otherwise we'll use the number of partitions in this RDD.
2340
2341 This mirrors the behavior of the Scala Partitioner#defaultPartitioner, intended to reduce
2342 the likelihood of OOMs. Once PySpark adopts Partitioner-based APIs, this behavior will
2343 be inherent.
2344 """
2345 if self.ctx._conf.contains("spark.default.parallelism"):
2346 return self.ctx.defaultParallelism
2347 else:
2348 return self.getNumPartitions()
2349
2350 def lookup(self, key):
2351 """
2352 Return the list of values in the RDD for key `key`. This operation
2353 is done efficiently if the RDD has a known partitioner by only
2354 searching the partition that the key maps to.
2355
2356 >>> l = range(1000)
2357 >>> rdd = sc.parallelize(zip(l, l), 10)
2358 >>> rdd.lookup(42) # slow
2359 [42]
2360 >>> sorted = rdd.sortByKey()
2361 >>> sorted.lookup(42) # fast
2362 [42]
2363 >>> sorted.lookup(1024)
2364 []
2365 >>> rdd2 = sc.parallelize([(('a', 'b'), 'c')]).groupByKey()
2366 >>> list(rdd2.lookup(('a', 'b'))[0])
2367 ['c']
2368 """
2369 values = self.filter(lambda kv: kv[0] == key).values()
2370
2371 if self.partitioner is not None:
2372 return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)])
2373
2374 return values.collect()
2375
2376 def _to_java_object_rdd(self):
2377 """ Return a JavaRDD of Object by unpickling
2378
2379 It will convert each Python object into a Java object with Pyrolite, whether
2380 the RDD is serialized in batch or not.
2381 """
2382 rdd = self._pickled()
2383 return self.ctx._jvm.SerDeUtil.pythonToJava(rdd._jrdd, True)
2384
2385 def countApprox(self, timeout, confidence=0.95):
2386 """
2387 Approximate version of count() that returns a potentially incomplete
2388 result within a timeout, even if not all tasks have finished.
2389
2390 >>> rdd = sc.parallelize(range(1000), 10)
2391 >>> rdd.countApprox(1000, 1.0)
2392 1000
2393 """
2394 drdd = self.mapPartitions(lambda it: [float(sum(1 for i in it))])
2395 return int(drdd.sumApprox(timeout, confidence))
2396
2397 def sumApprox(self, timeout, confidence=0.95):
2398 """
2399 Approximate operation to return the sum within a timeout
2400 or meet the confidence.
2401
2402 >>> rdd = sc.parallelize(range(1000), 10)
2403 >>> r = sum(range(1000))
2404 >>> abs(rdd.sumApprox(1000) - r) / r < 0.05
2405 True
2406 """
2407 jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd()
2408 jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd())
2409 r = jdrdd.sumApprox(timeout, confidence).getFinalValue()
2410 return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high())
2411
2412 def meanApprox(self, timeout, confidence=0.95):
2413 """
2414 Approximate operation to return the mean within a timeout
2415 or meet the confidence.
2416
2417 >>> rdd = sc.parallelize(range(1000), 10)
2418 >>> r = sum(range(1000)) / 1000.0
2419 >>> abs(rdd.meanApprox(1000) - r) / r < 0.05
2420 True
2421 """
2422 jrdd = self.map(float)._to_java_object_rdd()
2423 jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd())
2424 r = jdrdd.meanApprox(timeout, confidence).getFinalValue()
2425 return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high())
2426
2427 def countApproxDistinct(self, relativeSD=0.05):
2428 """
2429 Return approximate number of distinct elements in the RDD.
2430
2431 The algorithm used is based on streamlib's implementation of
2432 `"HyperLogLog in Practice: Algorithmic Engineering of a State
2433 of The Art Cardinality Estimation Algorithm", available here
2434 <https://doi.org/10.1145/2452376.2452456>`_.
2435
2436 :param relativeSD: Relative accuracy. Smaller values create
2437 counters that require more space.
2438 It must be greater than 0.000017.
2439
2440 >>> n = sc.parallelize(range(1000)).map(str).countApproxDistinct()
2441 >>> 900 < n < 1100
2442 True
2443 >>> n = sc.parallelize([i % 20 for i in range(1000)]).countApproxDistinct()
2444 >>> 16 < n < 24
2445 True
2446 """
2447 if relativeSD < 0.000017:
2448 raise ValueError("relativeSD should be greater than 0.000017")
2449
2450 hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF)
2451 return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD)
2452
2453 def toLocalIterator(self, prefetchPartitions=False):
2454 """
2455 Return an iterator that contains all of the elements in this RDD.
2456 The iterator will consume as much memory as the largest partition in this RDD.
2457 With prefetch it may consume up to the memory of the 2 largest partitions.
2458
2459 :param prefetchPartitions: If Spark should pre-fetch the next partition
2460 before it is needed.
2461
2462 >>> rdd = sc.parallelize(range(10))
2463 >>> [x for x in rdd.toLocalIterator()]
2464 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
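
Prefetching the next partition is only a scheduling hint and does not change
the result (a small sketch):

>>> [x for x in rdd.toLocalIterator(prefetchPartitions=True)]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]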
2465 """
2466 with SCCallSiteSync(self.context) as css:
2467 sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(
2468 self._jrdd.rdd(),
2469 prefetchPartitions)
2470 return _local_iterator_from_socket(sock_info, self._jrdd_deserializer)
2471
2472 def barrier(self):
2473 """
2474 .. note:: Experimental
2475
2476 Marks the current stage as a barrier stage, where Spark must launch all tasks together.
2477 In case of a task failure, instead of only restarting the failed task, Spark will abort the
2478 entire stage and relaunch all tasks for this stage.
2479 The barrier execution mode feature is experimental and it only handles limited scenarios.
2480 Please read the linked SPIP and design docs to understand the limitations and future plans.
2481
2482 :return: an :class:`RDDBarrier` instance that provides actions within a barrier stage.
2483
2484 .. seealso:: :class:`BarrierTaskContext`
2485 .. seealso:: `SPIP: Barrier Execution Mode
2486 <http://jira.apache.org/jira/browse/SPARK-24374>`_
2487 .. seealso:: `Design Doc <https://jira.apache.org/jira/browse/SPARK-24582>`_
2488
2489 .. versionadded:: 2.4.0
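
A sketch of typical usage (not executed here, since barrier mode needs a free
slot for every task in the stage):

>>> rdd = sc.parallelize(range(4), 4)
>>> rdd.barrier().mapPartitions(lambda it: [sum(it)]).collect()  # doctest: +SKIP
[0, 1, 2, 3]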
2490 """
2491 return RDDBarrier(self)
2492
2493 def _is_barrier(self):
2494 """
2495 Whether this RDD is in a barrier stage.
2496 """
2497 return self._jrdd.rdd().isBarrier()
2498
2499
2500 def _prepare_for_python_RDD(sc, command):
2501 # pickle the command with cloudpickle; large commands are shipped to executors via a broadcast variable
2502 ser = CloudPickleSerializer()
2503 pickled_command = ser.dumps(command)
2504 if len(pickled_command) > sc._jvm.PythonUtils.getBroadcastThreshold(sc._jsc):
2505 # send the much smaller pickled broadcast handle instead of the full command
2506 broadcast = sc.broadcast(pickled_command)
2507 pickled_command = ser.dumps(broadcast)
2508 broadcast_vars = [x._jbroadcast for x in sc._pickled_broadcast_vars]
2509 sc._pickled_broadcast_vars.clear()
2510 return pickled_command, broadcast_vars, sc.environment, sc._python_includes
2511
2512
2513 def _wrap_function(sc, func, deserializer, serializer, profiler=None):
2514 assert deserializer, "deserializer should not be empty"
2515 assert serializer, "serializer should not be empty"
2516 command = (func, profiler, deserializer, serializer)
2517 pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
2518 return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
2519 sc.pythonVer, broadcast_vars, sc._javaAccumulator)
2520
2521
2522 class RDDBarrier(object):
2523
2524 """
2525 .. note:: Experimental
2526
2527 Wraps an RDD in a barrier stage, which forces Spark to launch tasks of this stage together.
2528 :class:`RDDBarrier` instances are created by :func:`RDD.barrier`.
2529
2530 .. versionadded:: 2.4.0
2531 """
2532
2533 def __init__(self, rdd):
2534 self.rdd = rdd
2535
2536 def mapPartitions(self, f, preservesPartitioning=False):
2537 """
2538 .. note:: Experimental
2539
2540 Returns a new RDD by applying a function to each partition of the wrapped RDD,
2541 where tasks are launched together in a barrier stage.
2542 The interface is the same as :func:`RDD.mapPartitions`.
2543 Please see the API doc there.
2544
2545 .. versionadded:: 2.4.0
2546 """
2547 def func(s, iterator):
2548 return f(iterator)
2549 return PipelinedRDD(self.rdd, func, preservesPartitioning, isFromBarrier=True)
2550
2551 def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
2552 """
2553 .. note:: Experimental
2554
2555 Returns a new RDD by applying a function to each partition of the wrapped RDD, while
2556 tracking the index of the original partition; all tasks are launched together
2557 in a barrier stage.
2558 The interface is the same as :func:`RDD.mapPartitionsWithIndex`.
2559 Please see the API doc there.
2560
2561 .. versionadded:: 3.0.0
2562 """
2563 return PipelinedRDD(self.rdd, f, preservesPartitioning, isFromBarrier=True)
2564
2565
2566 class PipelinedRDD(RDD):
2567
2568 """
2569 Pipelined maps:
2570
2571 >>> rdd = sc.parallelize([1, 2, 3, 4])
2572 >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
2573 [4, 8, 12, 16]
2574 >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
2575 [4, 8, 12, 16]
2576
2577 Pipelined reduces:
2578 >>> from operator import add
2579 >>> rdd.map(lambda x: 2 * x).reduce(add)
2580 20
2581 >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
2582 20
2583 """
2584
2585 def __init__(self, prev, func, preservesPartitioning=False, isFromBarrier=False):
2586 if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
2587 # prev cannot be pipelined any further, so start a new pipeline from its JVM RDD
2588 self.func = func
2589 self.preservesPartitioning = preservesPartitioning
2590 self._prev_jrdd = prev._jrdd
2591 self._prev_jrdd_deserializer = prev._jrdd_deserializer
2592 else:
2593 prev_func = prev.func
2594
2595 def pipeline_func(split, iterator):
2596 return func(split, prev_func(split, iterator))
2597 self.func = pipeline_func
2598 self.preservesPartitioning = \
2599 prev.preservesPartitioning and preservesPartitioning
2600 self._prev_jrdd = prev._prev_jrdd
2601 self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
2602 self.is_cached = False
2603 self.is_checkpointed = False
2604 self.ctx = prev.ctx
2605 self.prev = prev
2606 self._jrdd_val = None
2607 self._id = None
2608 self._jrdd_deserializer = self.ctx.serializer
2609 self._bypass_serializer = False
2610 self.partitioner = prev.partitioner if self.preservesPartitioning else None
2611 self.is_barrier = isFromBarrier or prev._is_barrier()
2612
2613 def getNumPartitions(self):
2614 return self._prev_jrdd.partitions().size()
2615
2616 @property
2617 def _jrdd(self):
2618 if self._jrdd_val:
2619 return self._jrdd_val
2620 if self._bypass_serializer:
2621 self._jrdd_deserializer = NoOpSerializer()
2622
2623 if self.ctx.profiler_collector:
2624 profiler = self.ctx.profiler_collector.new_profiler(self.ctx)
2625 else:
2626 profiler = None
2627
2628 wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer,
2629 self._jrdd_deserializer, profiler)
2630 python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func,
2631 self.preservesPartitioning, self.is_barrier)
2632 self._jrdd_val = python_rdd.asJavaRDD()
2633
2634 if profiler:
2635 self._id = self._jrdd_val.id()
2636 self.ctx.profiler_collector.add_profiler(self._id, profiler)
2637 return self._jrdd_val
2638
2639 def id(self):
2640 if self._id is None:
2641 self._id = self._jrdd.id()
2642 return self._id
2643
2644 def _is_pipelinable(self):
2645 return not (self.is_cached or self.is_checkpointed)
2646
2647 def _is_barrier(self):
2648 return self.is_barrier
2649
2650
2651 def _test():
2652 import doctest
2653 from pyspark.context import SparkContext
2654 globs = globals().copy()
2655 # Run this module's doctests against a local 4-thread SparkContext, which the
2656 # examples above reference as 'sc'.
2657 globs['sc'] = SparkContext('local[4]', 'PythonTest')
2658 (failure_count, test_count) = doctest.testmod(
2659 globs=globs, optionflags=doctest.ELLIPSIS)
2660 globs['sc'].stop()
2661 if failure_count:
2662 sys.exit(-1)
2663
2664
2665 if __name__ == "__main__":
2666 _test()