0018 import os
0019 import shutil
0020 import signal
0021 import sys
0022 import threading
0023 import warnings
0024 from threading import RLock
0025 from tempfile import NamedTemporaryFile
0026
0027 from py4j.protocol import Py4JError
0028 from py4j.java_gateway import is_instance_of
0029
0030 from pyspark import accumulators
0031 from pyspark.accumulators import Accumulator
0032 from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
0033 from pyspark.conf import SparkConf
0034 from pyspark.files import SparkFiles
0035 from pyspark.java_gateway import launch_gateway, local_connect_and_auth
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
    PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream
0038 from pyspark.storagelevel import StorageLevel
0039 from pyspark.resource import ResourceInformation
0040 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
0041 from pyspark.traceback_utils import CallSite, first_spark_call
0042 from pyspark.status import StatusTracker
0043 from pyspark.profiler import ProfilerCollector, BasicProfiler
0044 from pyspark.util import _warn_pin_thread
0045
if sys.version > '3':
    xrange = range
0048
0049
0050 __all__ = ['SparkContext']
0051
0052
0053
0054
0055 DEFAULT_CONFIGS = {
0056 "spark.serializer.objectStreamReset": 100,
0057 "spark.rdd.compress": True,
0058 }
0059
0060
0061 class SparkContext(object):
0062
0063 """
0064 Main entry point for Spark functionality. A SparkContext represents the
0065 connection to a Spark cluster, and can be used to create :class:`RDD` and
0066 broadcast variables on that cluster.
0067
0068 .. note:: Only one :class:`SparkContext` should be active per JVM. You must `stop()`
0069 the active :class:`SparkContext` before creating a new one.
0070
0071 .. note:: :class:`SparkContext` instance is not supported to share across multiple
0072 processes out of the box, and PySpark does not guarantee multi-processing execution.
0073 Use threads instead for concurrent processing purpose.
0074 """
0075
0076 _gateway = None
0077 _jvm = None
0078 _next_accum_id = 0
0079 _active_spark_context = None
0080 _lock = RLock()
0081 _python_includes = None
0082
0083 PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar')
0084
0085 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
0086 environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
0087 gateway=None, jsc=None, profiler_cls=BasicProfiler):
0088 """
0089 Create a new SparkContext. At least the master and app name should be set,
0090 either through the named parameters here or through `conf`.
0091
0092 :param master: Cluster URL to connect to
0093 (e.g. mesos://host:port, spark://host:port, local[4]).
0094 :param appName: A name for your job, to display on the cluster web UI.
0095 :param sparkHome: Location where Spark is installed on cluster nodes.
0096 :param pyFiles: Collection of .zip or .py files to send to the cluster
0097 and add to PYTHONPATH. These can be paths on the local file
0098 system or HDFS, HTTP, HTTPS, or FTP URLs.
0099 :param environment: A dictionary of environment variables to set on
0100 worker nodes.
0101 :param batchSize: The number of Python objects represented as a single
0102 Java object. Set 1 to disable batching, 0 to automatically choose
0103 the batch size based on object sizes, or -1 to use an unlimited
0104 batch size
0105 :param serializer: The serializer for RDDs.
0106 :param conf: A :class:`SparkConf` object setting Spark properties.
0107 :param gateway: Use an existing gateway and JVM, otherwise a new JVM
0108 will be instantiated.
0109 :param jsc: The JavaSparkContext instance (optional).
0110 :param profiler_cls: A class of custom Profiler used to do profiling
0111 (default is pyspark.profiler.BasicProfiler).
0112
0113
0114 >>> from pyspark.context import SparkContext
0115 >>> sc = SparkContext('local', 'test')
0116
0117 >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
0118 Traceback (most recent call last):
0119 ...
0120 ValueError:...
0121 """
0122 self._callsite = first_spark_call() or CallSite(None, None, None)
0123 if gateway is not None and gateway.gateway_parameters.auth_token is None:
0124 raise ValueError(
0125 "You are trying to pass an insecure Py4j gateway to Spark. This"
0126 " is not allowed as it is a security risk.")
0127
0128 SparkContext._ensure_initialized(self, gateway=gateway, conf=conf)
0129 try:
0130 self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
0131 conf, jsc, profiler_cls)
        except:
            # If an error occurs, clean up in order to allow future SparkContext creation:
            self.stop()
            raise
0136
0137 def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
0138 conf, jsc, profiler_cls):
0139 self.environment = environment or {}
0140
0141 if conf is not None and conf._jconf is not None:
0142
0143
0144
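            # conf has already been initialized in the JVM (the gateway was launched
            # before this SparkConf was created), so use it directly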
0145 self._conf = conf
0146 else:
0147 self._conf = SparkConf(_jvm=SparkContext._jvm)
0148 if conf is not None:
0149 for k, v in conf.getAll():
0150 self._conf.set(k, v)
0151
0152 self._batchSize = batchSize
0153 self._unbatched_serializer = serializer
0154 if batchSize == 0:
0155 self.serializer = AutoBatchedSerializer(self._unbatched_serializer)
0156 else:
0157 self.serializer = BatchedSerializer(self._unbatched_serializer,
0158 batchSize)
0159
0160
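        # Set any parameters passed directly to us on the conf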
0161 if master:
0162 self._conf.setMaster(master)
0163 if appName:
0164 self._conf.setAppName(appName)
0165 if sparkHome:
0166 self._conf.setSparkHome(sparkHome)
0167 if environment:
0168 for key, value in environment.items():
0169 self._conf.setExecutorEnv(key, value)
0170 for key, value in DEFAULT_CONFIGS.items():
0171 self._conf.setIfMissing(key, value)
0172
0173
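        # Check that we have at least the required parameters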
0174 if not self._conf.contains("spark.master"):
0175 raise Exception("A master URL must be set in your configuration")
0176 if not self._conf.contains("spark.app.name"):
0177 raise Exception("An application name must be set in your configuration")
0178
0179
0180
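        # Read back our properties from the conf in case we loaded some of them from
        # the classpath or an external config file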
0181 self.master = self._conf.get("spark.master")
0182 self.appName = self._conf.get("spark.app.name")
0183 self.sparkHome = self._conf.get("spark.home", None)
0184
0185 for (k, v) in self._conf.getAll():
0186 if k.startswith("spark.executorEnv."):
0187 varName = k[len("spark.executorEnv."):]
0188 self.environment[varName] = v
0189
0190 self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
0191
0192
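        # Create the Java SparkContext through Py4J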
0193 self._jsc = jsc or self._initialize_context(self._conf._jconf)
0194
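        # Reset the SparkConf to the one actually used by the SparkContext in the JVM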
0195 self._conf = SparkConf(_jconf=self._jsc.sc().conf())
0196
0197
0198
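        # Create a single Accumulator in Java that we'll send all our updates through;
        # they will be passed back to us through a TCP server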
0199 auth_token = self._gateway.gateway_parameters.auth_token
0200 self._accumulatorServer = accumulators._start_update_server(auth_token)
0201 (host, port) = self._accumulatorServer.server_address
0202 self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
0203 self._jsc.sc().register(self._javaAccumulator)
0204
0205
0206
0207
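        # Whether data moved between Python and the JVM should go over an encrypted
        # socket instead of local temporary files (see _serialize_to_jvm)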
0208 self._encryption_enabled = self._jvm.PythonUtils.isEncryptionEnabled(self._jsc)
0209
0210 self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
0211 self.pythonVer = "%d.%d" % sys.version_info[:2]
0212
0213 if sys.version_info < (3, 6):
0214 with warnings.catch_warnings():
0215 warnings.simplefilter("once")
0216 warnings.warn(
0217 "Support for Python 2 and Python 3 prior to version 3.6 is deprecated as "
0218 "of Spark 3.0. See also the plan for dropping Python 2 support at "
0219 "https://spark.apache.org/news/plan-for-dropping-python-2-support.html.",
0220 DeprecationWarning)
0221
0222
0223
0224
0225
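        # Broadcast's __reduce__ method stores Broadcast instances here.
        # This allows other code to determine which Broadcast instances have
        # been pickled, so it can determine which Java broadcast objects to send.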
0226 self._pickled_broadcast_vars = BroadcastPickleRegistry()
0227
0228 SparkFiles._sc = self
0229 root_dir = SparkFiles.getRootDirectory()
0230 sys.path.insert(1, root_dir)
0231
0232
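        # Deploy any code dependencies specified in the constructor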
0233 self._python_includes = list()
0234 for path in (pyFiles or []):
0235 self.addPyFile(path)
0236
0237
0238
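        # Deploy code dependencies set by spark-submit; these will already have been added
        # with SparkContext.addFile, so we just need to add them to the PYTHONPATH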
0239 for path in self._conf.get("spark.submit.pyFiles", "").split(","):
0240 if path != "":
0241 (dirname, filename) = os.path.split(path)
0242 try:
0243 filepath = os.path.join(SparkFiles.getRootDirectory(), filename)
0244 if not os.path.exists(filepath):
0245
0246
0247
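                        # The file may not have been added via SparkContext.addFile (e.g. on
                        # YARN), so copy it into the SparkFiles root directory before adding
                        # it to the path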
0248 shutil.copyfile(path, filepath)
0249 if filename[-4:].lower() in self.PACKAGE_EXTENSIONS:
0250 self._python_includes.append(filename)
0251 sys.path.insert(1, filepath)
                except Exception:
                    warnings.warn(
                        "Failed to add file [%s] specified in 'spark.submit.pyFiles' to "
                        "Python path:\n %s" % (path, "\n ".join(sys.path)),
                        RuntimeWarning)
0257
0258
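        # Create a temporary directory inside spark.local.dir; it is used by
        # _serialize_to_jvm when handing local collections over to the JVM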
0259 local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
0260 self._temp_dir = \
0261 self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir, "pyspark") \
0262 .getAbsolutePath()
0263
0264
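        # Collect profiling stats for Python workers if spark.python.profile is enabled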
0265 if self._conf.get("spark.python.profile", "false") == "true":
0266 dump_path = self._conf.get("spark.python.profile.dump", None)
0267 self.profiler_collector = ProfilerCollector(profiler_cls, dump_path)
0268 else:
0269 self.profiler_collector = None
0270
0271
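        # Create a signal handler that cancels all running jobs on SIGINT (e.g. Ctrl-C)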
0272 def signal_handler(signal, frame):
0273 self.cancelAllJobs()
0274 raise KeyboardInterrupt()
0275
0276
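        # Signal handlers can only be registered from the main thread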
0277 if isinstance(threading.current_thread(), threading._MainThread):
0278 signal.signal(signal.SIGINT, signal_handler)
0279
0280 def __repr__(self):
0281 return "<SparkContext master={master} appName={appName}>".format(
0282 master=self.master,
0283 appName=self.appName,
0284 )
0285
0286 def _repr_html_(self):
0287 return """
0288 <div>
0289 <p><b>SparkContext</b></p>
0290
0291 <p><a href="{sc.uiWebUrl}">Spark UI</a></p>
0292
0293 <dl>
0294 <dt>Version</dt>
0295 <dd><code>v{sc.version}</code></dd>
0296 <dt>Master</dt>
0297 <dd><code>{sc.master}</code></dd>
0298 <dt>AppName</dt>
0299 <dd><code>{sc.appName}</code></dd>
0300 </dl>
0301 </div>
0302 """.format(
0303 sc=self
0304 )
0305
0306 def _initialize_context(self, jconf):
0307 """
        Initialize SparkContext in a function to allow subclass-specific initialization
0309 """
0310 return self._jvm.JavaSparkContext(jconf)
0311
0312 @classmethod
0313 def _ensure_initialized(cls, instance=None, gateway=None, conf=None):
0314 """
0315 Checks whether a SparkContext is initialized or not.
0316 Throws error if a SparkContext is already running.
0317 """
0318 with SparkContext._lock:
0319 if not SparkContext._gateway:
0320 SparkContext._gateway = gateway or launch_gateway(conf)
0321 SparkContext._jvm = SparkContext._gateway.jvm
0322
0323 if instance:
0324 if (SparkContext._active_spark_context and
0325 SparkContext._active_spark_context != instance):
0326 currentMaster = SparkContext._active_spark_context.master
0327 currentAppName = SparkContext._active_spark_context.appName
0328 callsite = SparkContext._active_spark_context._callsite
0329
0330
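                    # Raise an error if there is already a different running SparkContext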
0331 raise ValueError(
0332 "Cannot run multiple SparkContexts at once; "
0333 "existing SparkContext(app=%s, master=%s)"
0334 " created by %s at %s:%s "
0335 % (currentAppName, currentMaster,
0336 callsite.function, callsite.file, callsite.linenum))
0337 else:
0338 SparkContext._active_spark_context = instance
0339
    def __getnewargs__(self):
        # This method is called when attempting to pickle SparkContext, which is always an error:
        raise Exception(
            "It appears that you are attempting to reference SparkContext from a broadcast "
            "variable, action, or transformation. SparkContext can only be used on the driver, "
            "not in code that is run on workers. For more information, see SPARK-5063."
        )
0347
0348 def __enter__(self):
0349 """
0350 Enable 'with SparkContext(...) as sc: app(sc)' syntax.
0351 """
0352 return self
0353
0354 def __exit__(self, type, value, trace):
0355 """
0356 Enable 'with SparkContext(...) as sc: app' syntax.
0357
0358 Specifically stop the context on exit of the with block.
0359 """
0360 self.stop()
0361
0362 @classmethod
0363 def getOrCreate(cls, conf=None):
0364 """
0365 Get or instantiate a SparkContext and register it as a singleton object.
0366
0367 :param conf: SparkConf (optional)
0368 """
0369 with SparkContext._lock:
0370 if SparkContext._active_spark_context is None:
0371 SparkContext(conf=conf or SparkConf())
0372 return SparkContext._active_spark_context
0373
0374 def setLogLevel(self, logLevel):
0375 """
0376 Control our logLevel. This overrides any user-defined log settings.
0377 Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN
0378 """
0379 self._jsc.setLogLevel(logLevel)
0380
    @classmethod
    def setSystemProperty(cls, key, value):
        """
        Set a Java system property, such as spark.executor.memory. This must
        be invoked before instantiating SparkContext.
        """
        SparkContext._ensure_initialized()
        SparkContext._jvm.java.lang.System.setProperty(key, value)
0389
0390 @property
0391 def version(self):
0392 """
0393 The version of Spark on which this application is running.
0394 """
0395 return self._jsc.version()
0396
0397 @property
0398 @ignore_unicode_prefix
0399 def applicationId(self):
0400 """
0401 A unique identifier for the Spark application.
0402 Its format depends on the scheduler implementation.
0403
0404 * in case of local spark app something like 'local-1433865536131'
0405 * in case of YARN something like 'application_1433865536131_34483'
0406
0407 >>> sc.applicationId # doctest: +ELLIPSIS
0408 u'local-...'
0409 """
0410 return self._jsc.sc().applicationId()
0411
0412 @property
0413 def uiWebUrl(self):
0414 """Return the URL of the SparkUI instance started by this SparkContext"""
0415 return self._jsc.sc().uiWebUrl().get()
0416
0417 @property
0418 def startTime(self):
0419 """Return the epoch time when the Spark Context was started."""
0420 return self._jsc.startTime()
0421
0422 @property
0423 def defaultParallelism(self):
0424 """
0425 Default level of parallelism to use when not given by user (e.g. for
0426 reduce tasks)
0427 """
0428 return self._jsc.sc().defaultParallelism()
0429
0430 @property
0431 def defaultMinPartitions(self):
0432 """
0433 Default min number of partitions for Hadoop RDDs when not given by user
0434 """
0435 return self._jsc.sc().defaultMinPartitions()
0436
0437 def stop(self):
0438 """
0439 Shut down the SparkContext.
0440 """
0441 if getattr(self, "_jsc", None):
0442 try:
0443 self._jsc.stop()
0444 except Py4JError:
0445
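                # The JVM may already have crashed, been killed, or become unreachable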
                warnings.warn(
                    'Unable to cleanly shut down the Spark JVM process.'
                    ' It is possible that the process has crashed,'
                    ' been killed, or is in a zombie state.',
                    RuntimeWarning
                )
0452 finally:
0453 self._jsc = None
0454 if getattr(self, "_accumulatorServer", None):
0455 self._accumulatorServer.shutdown()
0456 self._accumulatorServer = None
0457 with SparkContext._lock:
0458 SparkContext._active_spark_context = None
0459
0460 def emptyRDD(self):
0461 """
0462 Create an RDD that has no partitions or elements.
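
        >>> sc.emptyRDD().isEmpty()
        True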
0463 """
0464 return RDD(self._jsc.emptyRDD(), self, NoOpSerializer())
0465
0466 def range(self, start, end=None, step=1, numSlices=None):
0467 """
0468 Create a new RDD of int containing elements from `start` to `end`
0469 (exclusive), increased by `step` every element. Can be called the same
0470 way as python's built-in range() function. If called with a single argument,
0471 the argument is interpreted as `end`, and `start` is set to 0.
0472
0473 :param start: the start value
0474 :param end: the end value (exclusive)
0475 :param step: the incremental step (default: 1)
0476 :param numSlices: the number of partitions of the new RDD
0477 :return: An RDD of int
0478
0479 >>> sc.range(5).collect()
0480 [0, 1, 2, 3, 4]
0481 >>> sc.range(2, 4).collect()
0482 [2, 3]
0483 >>> sc.range(1, 7, 2).collect()
0484 [1, 3, 5]
0485 """
0486 if end is None:
0487 end = start
0488 start = 0
0489
0490 return self.parallelize(xrange(start, end, step), numSlices)
0491
0492 def parallelize(self, c, numSlices=None):
0493 """
0494 Distribute a local Python collection to form an RDD. Using xrange
0495 is recommended if the input represents a range for performance.
0496
0497 >>> sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect()
0498 [[0], [2], [3], [4], [6]]
0499 >>> sc.parallelize(xrange(0, 6, 2), 5).glom().collect()
0500 [[], [0], [], [2], [4]]
0501 """
0502 numSlices = int(numSlices) if numSlices is not None else self.defaultParallelism
0503 if isinstance(c, xrange):
0504 size = len(c)
0505 if size == 0:
0506 return self.parallelize([], numSlices)
0507 step = c[1] - c[0] if size > 1 else 1
0508 start0 = c[0]
0509
0510 def getStart(split):
0511 return start0 + int((split * size / numSlices)) * step
0512
0513 def f(split, iterator):
0514
0515
0516
0517
0518
0519
0520
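                # The source partition is empty; consume its iterator and produce the
                # desired range for this split instead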
0521 assert len(list(iterator)) == 0
0522 return xrange(getStart(split), getStart(split + 1), step)
0523
0524 return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
0525
0526
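        # Make sure we distribute data evenly if it's smaller than self.batchSize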
0527 if "__len__" not in dir(c):
0528 c = list(c)
0529 batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
0530 serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
0531
0532 def reader_func(temp_filename):
0533 return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
0534
0535 def createRDDServer():
0536 return self._jvm.PythonParallelizeServer(self._jsc.sc(), numSlices)
0537
0538 jrdd = self._serialize_to_jvm(c, serializer, reader_func, createRDDServer)
0539 return RDD(jrdd, self, serializer)
0540
0541 def _serialize_to_jvm(self, data, serializer, reader_func, createRDDServer):
0542 """
0543 Using py4j to send a large dataset to the jvm is really slow, so we use either a file
0544 or a socket if we have encryption enabled.
0545 :param data:
0546 :param serializer:
0547 :param reader_func: A function which takes a filename and reads in the data in the jvm and
0548 returns a JavaRDD. Only used when encryption is disabled.
0549 :param createRDDServer: A function which creates a PythonRDDServer in the jvm to
0550 accept the serialized data, for use when encryption is enabled.
0551 :return:
0552 """
0553 if self._encryption_enabled:
0554
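            # Encryption is enabled: stream the serialized data to the JVM over an
            # authenticated socket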
0555 server = createRDDServer()
0556 (sock_file, _) = local_connect_and_auth(server.port(), server.secret())
0557 chunked_out = ChunkedStream(sock_file, 8192)
0558 serializer.dump_stream(data, chunked_out)
0559 chunked_out.close()
0560
0561
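            # This call blocks until the server has read and processed all of the data
            # (or raises an exception)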
0562 r = server.getResult()
0563 return r
0564 else:
0565
0566
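            # Encryption is disabled: write the data to a temporary file and let the
            # JVM read it back via reader_func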
0567 tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
0568 try:
0569 try:
0570 serializer.dump_stream(data, tempFile)
0571 finally:
0572 tempFile.close()
0573 return reader_func(tempFile.name)
0574 finally:
0575
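                # The file is read eagerly by reader_func, so it can be deleted right away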
0576 os.unlink(tempFile.name)
0577
0578 def pickleFile(self, name, minPartitions=None):
0579 """
0580 Load an RDD previously saved using :meth:`RDD.saveAsPickleFile` method.
0581
0582 >>> tmpFile = NamedTemporaryFile(delete=True)
0583 >>> tmpFile.close()
0584 >>> sc.parallelize(range(10)).saveAsPickleFile(tmpFile.name, 5)
0585 >>> sorted(sc.pickleFile(tmpFile.name, 3).collect())
0586 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
0587 """
0588 minPartitions = minPartitions or self.defaultMinPartitions
0589 return RDD(self._jsc.objectFile(name, minPartitions), self)
0590
0591 @ignore_unicode_prefix
0592 def textFile(self, name, minPartitions=None, use_unicode=True):
0593 """
0594 Read a text file from HDFS, a local file system (available on all
0595 nodes), or any Hadoop-supported file system URI, and return it as an
0596 RDD of Strings.
0597 The text files must be encoded as UTF-8.
0598
0599 If use_unicode is False, the strings will be kept as `str` (encoding
0600 as `utf-8`), which is faster and smaller than unicode. (Added in
0601 Spark 1.2)
0602
0603 >>> path = os.path.join(tempdir, "sample-text.txt")
0604 >>> with open(path, "w") as testFile:
0605 ... _ = testFile.write("Hello world!")
0606 >>> textFile = sc.textFile(path)
0607 >>> textFile.collect()
0608 [u'Hello world!']
0609 """
0610 minPartitions = minPartitions or min(self.defaultParallelism, 2)
0611 return RDD(self._jsc.textFile(name, minPartitions), self,
0612 UTF8Deserializer(use_unicode))
0613
0614 @ignore_unicode_prefix
0615 def wholeTextFiles(self, path, minPartitions=None, use_unicode=True):
0616 """
0617 Read a directory of text files from HDFS, a local file system
0618 (available on all nodes), or any Hadoop-supported file system
0619 URI. Each file is read as a single record and returned in a
0620 key-value pair, where the key is the path of each file, the
0621 value is the content of each file.
0622 The text files must be encoded as UTF-8.
0623
0624 If use_unicode is False, the strings will be kept as `str` (encoding
0625 as `utf-8`), which is faster and smaller than unicode. (Added in
0626 Spark 1.2)
0627
0628 For example, if you have the following files:
0629
0630 .. code-block:: text
0631
0632 hdfs://a-hdfs-path/part-00000
0633 hdfs://a-hdfs-path/part-00001
0634 ...
0635 hdfs://a-hdfs-path/part-nnnnn
0636
0637 Do ``rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")``,
0638 then ``rdd`` contains:
0639
0640 .. code-block:: text
0641
0642 (a-hdfs-path/part-00000, its content)
0643 (a-hdfs-path/part-00001, its content)
0644 ...
0645 (a-hdfs-path/part-nnnnn, its content)
0646
0647 .. note:: Small files are preferred, as each file will be loaded
0648 fully in memory.
0649
0650 >>> dirPath = os.path.join(tempdir, "files")
0651 >>> os.mkdir(dirPath)
0652 >>> with open(os.path.join(dirPath, "1.txt"), "w") as file1:
0653 ... _ = file1.write("1")
0654 >>> with open(os.path.join(dirPath, "2.txt"), "w") as file2:
0655 ... _ = file2.write("2")
0656 >>> textFiles = sc.wholeTextFiles(dirPath)
0657 >>> sorted(textFiles.collect())
0658 [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')]
0659 """
0660 minPartitions = minPartitions or self.defaultMinPartitions
0661 return RDD(self._jsc.wholeTextFiles(path, minPartitions), self,
0662 PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode)))
0663
0664 def binaryFiles(self, path, minPartitions=None):
0665 """
0666 Read a directory of binary files from HDFS, a local file system
0667 (available on all nodes), or any Hadoop-supported file system URI
0668 as a byte array. Each file is read as a single record and returned
0669 in a key-value pair, where the key is the path of each file, the
0670 value is the content of each file.
0671
        .. note:: Small files are preferred; large files are also allowed, but
            may cause poor performance.
0674 """
0675 minPartitions = minPartitions or self.defaultMinPartitions
0676 return RDD(self._jsc.binaryFiles(path, minPartitions), self,
0677 PairDeserializer(UTF8Deserializer(), NoOpSerializer()))
0678
0679 def binaryRecords(self, path, recordLength):
0680 """
0681 Load data from a flat binary file, assuming each record is a set of numbers
0682 with the specified numerical format (see ByteBuffer), and the number of
0683 bytes per record is constant.
0684
0685 :param path: Directory to the input data files
0686 :param recordLength: The length at which to split the records
0687 """
0688 return RDD(self._jsc.binaryRecords(path, recordLength), self, NoOpSerializer())
0689
0690 def _dictToJavaMap(self, d):
0691 jm = self._jvm.java.util.HashMap()
0692 if not d:
0693 d = {}
0694 for k, v in d.items():
0695 jm[k] = v
0696 return jm
0697
0698 def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None,
0699 valueConverter=None, minSplits=None, batchSize=0):
0700 """
0701 Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS,
0702 a local file system (available on all nodes), or any Hadoop-supported file system URI.
0703 The mechanism is as follows:
0704
0705 1. A Java RDD is created from the SequenceFile or other InputFormat, and the key
0706 and value Writable classes
0707 2. Serialization is attempted via Pyrolite pickling
0708 3. If this fails, the fallback is to call 'toString' on each key and value
0709 4. :class:`PickleSerializer` is used to deserialize pickled objects on the Python side
0710
        :param path: path to the SequenceFile
0712 :param keyClass: fully qualified classname of key Writable class
0713 (e.g. "org.apache.hadoop.io.Text")
0714 :param valueClass: fully qualified classname of value Writable class
0715 (e.g. "org.apache.hadoop.io.LongWritable")
0716 :param keyConverter:
0717 :param valueConverter:
0718 :param minSplits: minimum splits in dataset
0719 (default min(2, sc.defaultParallelism))
0720 :param batchSize: The number of Python objects represented as a single
0721 Java object. (default 0, choose batchSize automatically)
0722 """
0723 minSplits = minSplits or min(self.defaultParallelism, 2)
0724 jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass,
0725 keyConverter, valueConverter, minSplits, batchSize)
0726 return RDD(jrdd, self)
0727
0728 def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
0729 valueConverter=None, conf=None, batchSize=0):
0730 """
0731 Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS,
0732 a local file system (available on all nodes), or any Hadoop-supported file system URI.
0733 The mechanism is the same as for sc.sequenceFile.
0734
0735 A Hadoop configuration can be passed in as a Python dict. This will be converted into a
0736 Configuration in Java
0737
0738 :param path: path to Hadoop file
0739 :param inputFormatClass: fully qualified classname of Hadoop InputFormat
0740 (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat")
0741 :param keyClass: fully qualified classname of key Writable class
0742 (e.g. "org.apache.hadoop.io.Text")
0743 :param valueClass: fully qualified classname of value Writable class
0744 (e.g. "org.apache.hadoop.io.LongWritable")
0745 :param keyConverter: (None by default)
0746 :param valueConverter: (None by default)
0747 :param conf: Hadoop configuration, passed in as a dict
0748 (None by default)
0749 :param batchSize: The number of Python objects represented as a single
0750 Java object. (default 0, choose batchSize automatically)
0751 """
0752 jconf = self._dictToJavaMap(conf)
0753 jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass,
0754 valueClass, keyConverter, valueConverter,
0755 jconf, batchSize)
0756 return RDD(jrdd, self)
0757
0758 def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
0759 valueConverter=None, conf=None, batchSize=0):
0760 """
0761 Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary
0762 Hadoop configuration, which is passed in as a Python dict.
0763 This will be converted into a Configuration in Java.
0764 The mechanism is the same as for sc.sequenceFile.
0765
0766 :param inputFormatClass: fully qualified classname of Hadoop InputFormat
0767 (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat")
0768 :param keyClass: fully qualified classname of key Writable class
0769 (e.g. "org.apache.hadoop.io.Text")
0770 :param valueClass: fully qualified classname of value Writable class
0771 (e.g. "org.apache.hadoop.io.LongWritable")
0772 :param keyConverter: (None by default)
0773 :param valueConverter: (None by default)
0774 :param conf: Hadoop configuration, passed in as a dict
0775 (None by default)
0776 :param batchSize: The number of Python objects represented as a single
0777 Java object. (default 0, choose batchSize automatically)
0778 """
0779 jconf = self._dictToJavaMap(conf)
0780 jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass,
0781 valueClass, keyConverter, valueConverter,
0782 jconf, batchSize)
0783 return RDD(jrdd, self)
0784
0785 def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
0786 valueConverter=None, conf=None, batchSize=0):
0787 """
0788 Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS,
0789 a local file system (available on all nodes), or any Hadoop-supported file system URI.
0790 The mechanism is the same as for sc.sequenceFile.
0791
0792 A Hadoop configuration can be passed in as a Python dict. This will be converted into a
0793 Configuration in Java.
0794
0795 :param path: path to Hadoop file
0796 :param inputFormatClass: fully qualified classname of Hadoop InputFormat
0797 (e.g. "org.apache.hadoop.mapred.TextInputFormat")
0798 :param keyClass: fully qualified classname of key Writable class
0799 (e.g. "org.apache.hadoop.io.Text")
0800 :param valueClass: fully qualified classname of value Writable class
0801 (e.g. "org.apache.hadoop.io.LongWritable")
0802 :param keyConverter: (None by default)
0803 :param valueConverter: (None by default)
0804 :param conf: Hadoop configuration, passed in as a dict
0805 (None by default)
0806 :param batchSize: The number of Python objects represented as a single
0807 Java object. (default 0, choose batchSize automatically)
0808 """
0809 jconf = self._dictToJavaMap(conf)
0810 jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass,
0811 valueClass, keyConverter, valueConverter,
0812 jconf, batchSize)
0813 return RDD(jrdd, self)
0814
0815 def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
0816 valueConverter=None, conf=None, batchSize=0):
0817 """
0818 Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary
0819 Hadoop configuration, which is passed in as a Python dict.
0820 This will be converted into a Configuration in Java.
0821 The mechanism is the same as for sc.sequenceFile.
0822
0823 :param inputFormatClass: fully qualified classname of Hadoop InputFormat
0824 (e.g. "org.apache.hadoop.mapred.TextInputFormat")
0825 :param keyClass: fully qualified classname of key Writable class
0826 (e.g. "org.apache.hadoop.io.Text")
0827 :param valueClass: fully qualified classname of value Writable class
0828 (e.g. "org.apache.hadoop.io.LongWritable")
0829 :param keyConverter: (None by default)
0830 :param valueConverter: (None by default)
0831 :param conf: Hadoop configuration, passed in as a dict
0832 (None by default)
0833 :param batchSize: The number of Python objects represented as a single
0834 Java object. (default 0, choose batchSize automatically)
0835 """
0836 jconf = self._dictToJavaMap(conf)
0837 jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass,
0838 valueClass, keyConverter, valueConverter,
0839 jconf, batchSize)
0840 return RDD(jrdd, self)
0841
0842 def _checkpointFile(self, name, input_deserializer):
0843 jrdd = self._jsc.checkpointFile(name)
0844 return RDD(jrdd, self, input_deserializer)
0845
0846 @ignore_unicode_prefix
0847 def union(self, rdds):
0848 """
0849 Build the union of a list of RDDs.
0850
        This supports unions of RDDs with different serialized formats,
        although this forces them to be reserialized using the default
        serializer:
0854
0855 >>> path = os.path.join(tempdir, "union-text.txt")
0856 >>> with open(path, "w") as testFile:
0857 ... _ = testFile.write("Hello")
0858 >>> textFile = sc.textFile(path)
0859 >>> textFile.collect()
0860 [u'Hello']
0861 >>> parallelized = sc.parallelize(["World!"])
0862 >>> sorted(sc.union([textFile, parallelized]).collect())
0863 [u'Hello', 'World!']
0864 """
0865 first_jrdd_deserializer = rdds[0]._jrdd_deserializer
0866 if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
0867 rdds = [x._reserialize() for x in rdds]
0868 gw = SparkContext._gateway
0869 jvm = SparkContext._jvm
0870 jrdd_cls = jvm.org.apache.spark.api.java.JavaRDD
0871 jpair_rdd_cls = jvm.org.apache.spark.api.java.JavaPairRDD
0872 jdouble_rdd_cls = jvm.org.apache.spark.api.java.JavaDoubleRDD
0873 if is_instance_of(gw, rdds[0]._jrdd, jrdd_cls):
0874 cls = jrdd_cls
0875 elif is_instance_of(gw, rdds[0]._jrdd, jpair_rdd_cls):
0876 cls = jpair_rdd_cls
0877 elif is_instance_of(gw, rdds[0]._jrdd, jdouble_rdd_cls):
0878 cls = jdouble_rdd_cls
0879 else:
0880 cls_name = rdds[0]._jrdd.getClass().getCanonicalName()
0881 raise TypeError("Unsupported Java RDD class %s" % cls_name)
0882 jrdds = gw.new_array(cls, len(rdds))
0883 for i in range(0, len(rdds)):
0884 jrdds[i] = rdds[i]._jrdd
0885 return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)
0886
0887 def broadcast(self, value):
0888 """
        Broadcast a read-only variable to the cluster, returning a :class:`Broadcast`
        object for reading it in distributed functions. The variable will
        be sent to each executor only once.
0892 """
0893 return Broadcast(self, value, self._pickled_broadcast_vars)
0894
0895 def accumulator(self, value, accum_param=None):
0896 """
0897 Create an :class:`Accumulator` with the given initial value, using a given
0898 :class:`AccumulatorParam` helper object to define how to add values of the
0899 data type if provided. Default AccumulatorParams are used for integers
0900 and floating-point numbers if you do not provide one. For other types,
0901 a custom AccumulatorParam can be used.
0902 """
0903 if accum_param is None:
0904 if isinstance(value, int):
0905 accum_param = accumulators.INT_ACCUMULATOR_PARAM
0906 elif isinstance(value, float):
0907 accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
0908 elif isinstance(value, complex):
0909 accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
0910 else:
0911 raise TypeError("No default accumulator param for type %s" % type(value))
0912 SparkContext._next_accum_id += 1
0913 return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
0914
0915 def addFile(self, path, recursive=False):
0916 """
0917 Add a file to be downloaded with this Spark job on every node.
0918 The `path` passed can be either a local file, a file in HDFS
0919 (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
0920 FTP URI.
0921
0922 To access the file in Spark jobs, use :meth:`SparkFiles.get` with the
0923 filename to find its download location.
0924
0925 A directory can be given if the recursive option is set to True.
0926 Currently directories are only supported for Hadoop-supported filesystems.
0927
0928 .. note:: A path can be added only once. Subsequent additions of the same path are ignored.
0929
0930 >>> from pyspark import SparkFiles
0931 >>> path = os.path.join(tempdir, "test.txt")
0932 >>> with open(path, "w") as testFile:
0933 ... _ = testFile.write("100")
0934 >>> sc.addFile(path)
0935 >>> def func(iterator):
0936 ... with open(SparkFiles.get("test.txt")) as testFile:
0937 ... fileVal = int(testFile.readline())
0938 ... return [x * fileVal for x in iterator]
0939 >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
0940 [100, 200, 300, 400]
0941 """
0942 self._jsc.sc().addFile(path, recursive)
0943
0944 def addPyFile(self, path):
0945 """
0946 Add a .py or .zip dependency for all tasks to be executed on this
0947 SparkContext in the future. The `path` passed can be either a local
0948 file, a file in HDFS (or other Hadoop-supported filesystems), or an
0949 HTTP, HTTPS or FTP URI.
0950
0951 .. note:: A path can be added only once. Subsequent additions of the same path are ignored.
0952 """
0953 self.addFile(path)
0954 (dirname, filename) = os.path.split(path)
0955 if filename[-4:].lower() in self.PACKAGE_EXTENSIONS:
0956 self._python_includes.append(filename)
0957
0958 sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
0959 if sys.version > '3':
0960 import importlib
0961 importlib.invalidate_caches()
0962
0963 def setCheckpointDir(self, dirName):
0964 """
0965 Set the directory under which RDDs are going to be checkpointed. The
0966 directory must be an HDFS path if running on a cluster.
0967 """
0968 self._jsc.sc().setCheckpointDir(dirName)
0969
0970 def _getJavaStorageLevel(self, storageLevel):
0971 """
0972 Returns a Java StorageLevel based on a pyspark.StorageLevel.
0973 """
0974 if not isinstance(storageLevel, StorageLevel):
0975 raise Exception("storageLevel must be of type pyspark.StorageLevel")
0976
0977 newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
0978 return newStorageLevel(storageLevel.useDisk,
0979 storageLevel.useMemory,
0980 storageLevel.useOffHeap,
0981 storageLevel.deserialized,
0982 storageLevel.replication)
0983
0984 def setJobGroup(self, groupId, description, interruptOnCancel=False):
0985 """
0986 Assigns a group ID to all the jobs started by this thread until the group ID is set to a
0987 different value or cleared.
0988
0989 Often, a unit of execution in an application consists of multiple Spark actions or jobs.
0990 Application programmers can use this method to group all those jobs together and give a
0991 group description. Once set, the Spark web UI will associate such jobs with this group.
0992
0993 The application can use :meth:`SparkContext.cancelJobGroup` to cancel all
0994 running jobs in this group.
0995
0996 >>> import threading
0997 >>> from time import sleep
0998 >>> result = "Not Set"
0999 >>> lock = threading.Lock()
1000 >>> def map_func(x):
1001 ... sleep(100)
1002 ... raise Exception("Task should have been cancelled")
1003 >>> def start_job(x):
1004 ... global result
1005 ... try:
1006 ... sc.setJobGroup("job_to_cancel", "some description")
1007 ... result = sc.parallelize(range(x)).map(map_func).collect()
1008 ... except Exception as e:
1009 ... result = "Cancelled"
1010 ... lock.release()
1011 >>> def stop_job():
1012 ... sleep(5)
1013 ... sc.cancelJobGroup("job_to_cancel")
1014 >>> suppress = lock.acquire()
1015 >>> suppress = threading.Thread(target=start_job, args=(10,)).start()
1016 >>> suppress = threading.Thread(target=stop_job).start()
1017 >>> suppress = lock.acquire()
1018 >>> print(result)
1019 Cancelled
1020
1021 If interruptOnCancel is set to true for the job group, then job cancellation will result
1022 in Thread.interrupt() being called on the job's executor threads. This is useful to help
1023 ensure that the tasks are actually stopped in a timely manner, but is off by default due
1024 to HDFS-1208, where HDFS may respond to Thread.interrupt() by marking nodes as dead.
1025
1026 .. note:: Currently, setting a group ID (set to local properties) with multiple threads
1027 does not properly work. Internally threads on PVM and JVM are not synced, and JVM
1028 thread can be reused for multiple threads on PVM, which fails to isolate local
1029 properties for each thread on PVM.
1030
1031 To work around this, you can set `PYSPARK_PIN_THREAD` to
1032 `'true'` (see SPARK-22340). However, note that it cannot inherit the local properties
1033 from the parent thread although it isolates each thread on PVM and JVM with its own
1034 local properties.
1035
1036 To work around this, you should manually copy and set the local
1037 properties from the parent thread to the child thread when you create another thread.
1038 """
1039 _warn_pin_thread("setJobGroup")
1040 self._jsc.setJobGroup(groupId, description, interruptOnCancel)
1041
1042 def setLocalProperty(self, key, value):
1043 """
1044 Set a local property that affects jobs submitted from this thread, such as the
1045 Spark fair scheduler pool.
1046
1047 .. note:: Currently, setting a local property with multiple threads does not properly work.
1048 Internally threads on PVM and JVM are not synced, and JVM thread
1049 can be reused for multiple threads on PVM, which fails to isolate local properties
1050 for each thread on PVM.
1051
1052 To work around this, you can set `PYSPARK_PIN_THREAD` to
1053 `'true'` (see SPARK-22340). However, note that it cannot inherit the local properties
1054 from the parent thread although it isolates each thread on PVM and JVM with its own
1055 local properties.
1056
1057 To work around this, you should manually copy and set the local
1058 properties from the parent thread to the child thread when you create another thread.
1059 """
1060 _warn_pin_thread("setLocalProperty")
1061 self._jsc.setLocalProperty(key, value)
1062
1063 def getLocalProperty(self, key):
1064 """
        Get a local property set in this thread, or None if it is missing. See
        :meth:`setLocalProperty`.
1067 """
1068 return self._jsc.getLocalProperty(key)
1069
1070 def setJobDescription(self, value):
1071 """
1072 Set a human readable description of the current job.
1073
1074 .. note:: Currently, setting a job description (set to local properties) with multiple
1075 threads does not properly work. Internally threads on PVM and JVM are not synced,
1076 and JVM thread can be reused for multiple threads on PVM, which fails to isolate
1077 local properties for each thread on PVM.
1078
1079 To work around this, you can set `PYSPARK_PIN_THREAD` to
1080 `'true'` (see SPARK-22340). However, note that it cannot inherit the local properties
1081 from the parent thread although it isolates each thread on PVM and JVM with its own
1082 local properties.
1083
1084 To work around this, you should manually copy and set the local
1085 properties from the parent thread to the child thread when you create another thread.
1086 """
1087 _warn_pin_thread("setJobDescription")
1088 self._jsc.setJobDescription(value)
1089
1090 def sparkUser(self):
1091 """
1092 Get SPARK_USER for user who is running SparkContext.
1093 """
1094 return self._jsc.sc().sparkUser()
1095
1096 def cancelJobGroup(self, groupId):
1097 """
        Cancel active jobs for the specified group. See :meth:`SparkContext.setJobGroup`
        for more information.
1100 """
1101 self._jsc.sc().cancelJobGroup(groupId)
1102
1103 def cancelAllJobs(self):
1104 """
1105 Cancel all jobs that have been scheduled or are running.
1106 """
1107 self._jsc.sc().cancelAllJobs()
1108
1109 def statusTracker(self):
1110 """
1111 Return :class:`StatusTracker` object
1112 """
1113 return StatusTracker(self._jsc.statusTracker())
1114
1115 def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
1116 """
1117 Executes the given partitionFunc on the specified set of partitions,
1118 returning the result as an array of elements.
1119
1120 If 'partitions' is not specified, this will run over all partitions.
1121
1122 >>> myRDD = sc.parallelize(range(6), 3)
1123 >>> sc.runJob(myRDD, lambda part: [x * x for x in part])
1124 [0, 1, 4, 9, 16, 25]
1125
1126 >>> myRDD = sc.parallelize(range(6), 3)
1127 >>> sc.runJob(myRDD, lambda part: [x * x for x in part], [0, 2], True)
1128 [0, 1, 16, 25]
1129 """
1130 if partitions is None:
1131 partitions = range(rdd._jrdd.partitions().size())
1132
1133
1134
1135
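        # Implemented as a mapPartitions over the RDD, followed by collecting only the
        # requested partitions on the JVM side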
1136 mappedRDD = rdd.mapPartitions(partitionFunc)
1137 sock_info = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions)
1138 return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer))
1139
1140 def show_profiles(self):
1141 """ Print the profile stats to stdout """
1142 if self.profiler_collector is not None:
1143 self.profiler_collector.show_profiles()
1144 else:
1145 raise RuntimeError("'spark.python.profile' configuration must be set "
1146 "to 'true' to enable Python profile.")
1147
1148 def dump_profiles(self, path):
1149 """ Dump the profile stats into directory `path`
1150 """
1151 if self.profiler_collector is not None:
1152 self.profiler_collector.dump_profiles(path)
1153 else:
1154 raise RuntimeError("'spark.python.profile' configuration must be set "
1155 "to 'true' to enable Python profile.")
1156
1157 def getConf(self):
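        """
        Return a copy of this SparkContext's configuration as a :class:`SparkConf`.
        """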
1158 conf = SparkConf()
1159 conf.setAll(self._conf.getAll())
1160 return conf
1161
1162 @property
1163 def resources(self):
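        """
        Return the resource information (e.g. GPU addresses) available to this
        SparkContext, as a dict mapping resource name to :class:`ResourceInformation`.
        """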
1164 resources = {}
1165 jresources = self._jsc.resources()
1166 for x in jresources:
1167 name = jresources[x].name()
1168 jaddresses = jresources[x].addresses()
1169 addrs = [addr for addr in jaddresses]
1170 resources[name] = ResourceInformation(name, addrs)
1171 return resources
1172
1173
1174 def _test():
1175 import atexit
1176 import doctest
1177 import tempfile
1178 globs = globals().copy()
1179 globs['sc'] = SparkContext('local[4]', 'PythonTest')
1180 globs['tempdir'] = tempfile.mkdtemp()
1181 atexit.register(lambda: shutil.rmtree(globs['tempdir']))
1182 (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
1183 globs['sc'].stop()
1184 if failure_count:
1185 sys.exit(-1)
1186
1187
1188 if __name__ == "__main__":
1189 _test()