from __future__ import print_function

from py4j.java_gateway import java_import, is_instance_of

from pyspark import RDD, SparkConf
from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer
from pyspark.context import SparkContext
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.dstream import DStream
from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer

__all__ = ["StreamingContext"]


class StreamingContext(object):
    """
    Main entry point for Spark Streaming functionality. A StreamingContext
    represents the connection to a Spark cluster, and can be used to create
    :class:`DStream` objects from various input sources. It can be created
    from an existing :class:`SparkContext`.
    After creating and transforming DStreams, the streaming computation can
    be started and stopped using `context.start()` and `context.stop()`,
    respectively. `context.awaitTermination()` allows the current thread
    to wait for the termination of the context by `stop()` or by an exception.
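
    A minimal usage sketch (illustrative only; assumes a local Spark
    installation, that no other SparkContext is running, and that a data
    source is listening on localhost:9999)::

        from pyspark import SparkContext
        from pyspark.streaming import StreamingContext

        sc = SparkContext("local[2]", "NetworkWordCount")
        ssc = StreamingContext(sc, batchDuration=1)  # 1-second batches

        lines = ssc.socketTextStream("localhost", 9999)
        lines.flatMap(lambda line: line.split(" ")).pprint()

        ssc.start()
        ssc.awaitTermination()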
    """
    _transformerSerializer = None

    # Reference to the currently active StreamingContext, if any
    _activeContext = None

    def __init__(self, sparkContext, batchDuration=None, jssc=None):
        """
        Create a new StreamingContext.

        :param sparkContext: :class:`SparkContext` object.
        :param batchDuration: the time interval (in seconds) at which streaming
                              data will be divided into batches
        """

        self._sc = sparkContext
        self._jvm = self._sc._jvm
        self._jssc = jssc or self._initialize_context(self._sc, batchDuration)

    def _initialize_context(self, sc, duration):
        self._ensure_initialized()
        return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))

    def _jduration(self, seconds):
        """
        Create a Duration object given a number of seconds.
        """
        return self._jvm.Duration(int(seconds * 1000))

    @classmethod
    def _ensure_initialized(cls):
        SparkContext._ensure_initialized()
        gw = SparkContext._gateway

        java_import(gw.jvm, "org.apache.spark.streaming.*")
        java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
        java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")

        from pyspark.java_gateway import ensure_callback_server_started
        ensure_callback_server_started(gw)

        # Register the serializer for TransformFunction; this has to happen
        # before a SparkContext is created when recovering from a checkpoint.
        cls._transformerSerializer = TransformFunctionSerializer(
            SparkContext._active_spark_context, CloudPickleSerializer(), gw)

    @classmethod
    def getOrCreate(cls, checkpointPath, setupFunc):
        """
        Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
        If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
        recreated from the checkpoint data. If the data does not exist, then the provided setupFunc
        will be used to create a new context.

        :param checkpointPath: Checkpoint directory used in an earlier streaming program
        :param setupFunc: Function to create a new context and set up DStreams
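
        A sketch of the typical recovery pattern (illustrative; the checkpoint
        path and the body of ``setupFunc`` are placeholders)::

            def setupFunc():
                sc = SparkContext(appName="RecoverableApp")
                ssc = StreamingContext(sc, 5)
                # define DStreams and the computation here
                return ssc

            ssc = StreamingContext.getOrCreate("hdfs://.../checkpoint", setupFunc)
            ssc.start()
            ssc.awaitTermination()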
        """
        cls._ensure_initialized()
        gw = SparkContext._gateway

        # Check whether valid checkpoint information exists in the given path
        ssc_option = gw.jvm.StreamingContextPythonHelper().tryRecoverFromCheckpoint(checkpointPath)
        if ssc_option.isEmpty():
            ssc = setupFunc()
            ssc.checkpoint(checkpointPath)
            return ssc

        jssc = gw.jvm.JavaStreamingContext(ssc_option.get())

        # If there is already an active Python SparkContext, reuse it;
        # otherwise create one backed by the recovered Java SparkContext.
        if not SparkContext._active_spark_context:
            jsc = jssc.sparkContext()
            conf = SparkConf(_jconf=jsc.getConf())
            SparkContext(conf=conf, gateway=gw, jsc=jsc)

        sc = SparkContext._active_spark_context

        # Update the SparkContext referenced by the transform-function serializer
        cls._transformerSerializer.ctx = sc
        return StreamingContext(sc, None, jssc)

    @classmethod
    def getActive(cls):
        """
        Return either the currently active StreamingContext (i.e., if there is a context started
        but not stopped) or None.
        """
        activePythonContext = cls._activeContext
        if activePythonContext is not None:
            # Verify that the currently running Java StreamingContext is active and is the
            # same one backing this Python StreamingContext.
            activePythonContextJavaId = activePythonContext._jssc.ssc().hashCode()
            activeJvmContextOption = activePythonContext._jvm.StreamingContext.getActive()

            if activeJvmContextOption.isEmpty():
                cls._activeContext = None
            elif activeJvmContextOption.get().hashCode() != activePythonContextJavaId:
                cls._activeContext = None
                raise Exception("JVM's active JavaStreamingContext is not the JavaStreamingContext "
                                "backing the active Python StreamingContext. This is unexpected.")
        return cls._activeContext

    @classmethod
    def getActiveOrCreate(cls, checkpointPath, setupFunc):
        """
        Either return the active StreamingContext (i.e. currently started but not stopped),
        or recreate a StreamingContext from checkpoint data, or create a new StreamingContext
        using the provided setupFunc function. If the checkpointPath is None or does not contain
        valid checkpoint data, then setupFunc will be called to create a new context and set up
        DStreams.

        :param checkpointPath: Checkpoint directory used in an earlier streaming program. Can be
                               None if the intention is to always create a new context when there
                               is no active context.
        :param setupFunc: Function to create a new JavaStreamingContext and set up DStreams
        """

        if setupFunc is None:
            raise Exception("setupFunc cannot be None")
        activeContext = cls.getActive()
        if activeContext is not None:
            return activeContext
        elif checkpointPath is not None:
            return cls.getOrCreate(checkpointPath, setupFunc)
        else:
            return setupFunc()

    @property
    def sparkContext(self):
        """
        Return the SparkContext associated with this StreamingContext.
        """
        return self._sc

    def start(self):
        """
        Start the execution of the streams.
        """
        self._jssc.start()
        StreamingContext._activeContext = self

    def awaitTermination(self, timeout=None):
        """
        Wait for the execution to stop.

        :param timeout: time to wait in seconds
        """
        if timeout is None:
            self._jssc.awaitTermination()
        else:
            self._jssc.awaitTerminationOrTimeout(int(timeout * 1000))

    def awaitTerminationOrTimeout(self, timeout):
        """
        Wait for the execution to stop. Return ``True`` if it has stopped, raise
        the error reported during the execution, or return ``False`` if the
        waiting time elapsed before the execution stopped.

        :param timeout: time to wait in seconds
        """
        return self._jssc.awaitTerminationOrTimeout(int(timeout * 1000))

    def stop(self, stopSparkContext=True, stopGraceFully=False):
        """
        Stop the execution of the streams, with the option of ensuring that all
        received data has been processed.

        :param stopSparkContext: Stop the associated SparkContext or not
        :param stopGraceFully: Stop gracefully by waiting for the processing
                               of all received data to be completed
        """
        self._jssc.stop(stopSparkContext, stopGraceFully)
        StreamingContext._activeContext = None
        if stopSparkContext:
            self._sc.stop()

    def remember(self, duration):
        """
        Set each DStream in this context to remember the RDDs it generated
        in the last given duration. DStreams remember RDDs only for a
        limited duration of time and release them for garbage collection.
        This method allows the developer to specify how long to remember
        the RDDs (if the developer wishes to query old data outside the
        DStream computation).

        :param duration: Minimum duration (in seconds) that each DStream
                         should remember its RDDs
        """
        self._jssc.remember(self._jduration(duration))

    def checkpoint(self, directory):
        """
        Sets the context to periodically checkpoint the DStream operations for driver
        fault-tolerance. The graph will be checkpointed every batch interval.

        :param directory: HDFS-compatible directory where the checkpoint data
                          will be reliably stored
        """
        self._jssc.checkpoint(directory)

    def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_2):
        """
        Create an input stream from a TCP source hostname:port. Data is received
        using a TCP socket and the received bytes are interpreted as UTF8-encoded,
        ``\\n``-delimited lines.

        :param hostname: Hostname to connect to for receiving data
        :param port: Port to connect to for receiving data
        :param storageLevel: Storage level to use for storing the received objects
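
        For example (an illustrative sketch; the host and port are placeholders,
        and the data source is assumed to be started separately, e.g. with
        ``nc -lk 9999``)::

            lines = ssc.socketTextStream("localhost", 9999)
            counts = (lines.flatMap(lambda line: line.split(" "))
                           .map(lambda word: (word, 1))
                           .reduceByKey(lambda a, b: a + b))
            counts.pprint()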
        """
        jlevel = self._sc._getJavaStorageLevel(storageLevel)
        return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self,
                       UTF8Deserializer())

    def textFileStream(self, directory):
        """
        Create an input stream that monitors a Hadoop-compatible file system
        for new files and reads them as text files. Files must be written to the
        monitored directory by "moving" them from another location within the same
        file system. File names starting with . are ignored.
        The text files must be encoded as UTF-8.
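
        For example (a sketch; the directory is a placeholder and ``ssc`` is an
        existing :class:`StreamingContext`)::

            logs = ssc.textFileStream("hdfs://namenode:8020/logs/incoming")
            logs.count().pprint()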
        """
        return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())

    def binaryRecordsStream(self, directory, recordLength):
        """
        Create an input stream that monitors a Hadoop-compatible file system
        for new files and reads them as flat binary files with records of
        fixed length. Files must be written to the monitored directory by "moving"
        them from another location within the same file system.
        File names starting with . are ignored.

        :param directory: Directory to load data from
        :param recordLength: Length of each record in bytes
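
        For example, to read a stream of fixed-size 100-byte records (a sketch;
        the directory is a placeholder)::

            records = ssc.binaryRecordsStream("hdfs://namenode:8020/data", 100)
            records.map(len).pprint()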
        """
        return DStream(self._jssc.binaryRecordsStream(directory, recordLength), self,
                       NoOpSerializer())

    def _check_serializers(self, rdds):
        # Make sure all RDDs use the same serializer
        if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
            for i in range(len(rdds)):
                # Reset them to the default serializer (sc.serializer)
                rdds[i] = rdds[i]._reserialize()

    def queueStream(self, rdds, oneAtATime=True, default=None):
        """
        Create an input stream from a queue of RDDs or lists. In each batch,
        it will process either one or all of the RDDs returned by the queue.

        .. note:: Changes to the queue after the stream is created will not be recognized.

        :param rdds: Queue of RDDs
        :param oneAtATime: pick one RDD each time, or pick all of them once
        :param default: The default RDD if no more RDDs are left in ``rdds``
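
        For example (a sketch; the queued RDDs here are small in-memory lists)::

            rdd_queue = [ssc.sparkContext.parallelize(range(i, i + 5))
                         for i in range(3)]
            stream = ssc.queueStream(rdd_queue)
            stream.pprint()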
        """
        if default and not isinstance(default, RDD):
            default = self._sc.parallelize(default)

        if not rdds and default:
            rdds = [default]

        if rdds and not isinstance(rdds[0], RDD):
            rdds = [self._sc.parallelize(input) for input in rdds]
            self._check_serializers(rdds)

        queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds])
        if default:
            default = default._reserialize(rdds[0]._jrdd_deserializer)
            jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
        else:
            jdstream = self._jssc.queueStream(queue, oneAtATime)
        return DStream(jdstream, self, rdds[0]._jrdd_deserializer)

    def transform(self, dstreams, transformFunc):
        """
        Create a new DStream in which each RDD is generated by applying
        a function on RDDs of the DStreams. The order of the JavaRDDs in
        the transform function parameter will be the same as the order
        of corresponding DStreams in the list.
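
        A sketch of joining two DStreams (illustrative; ``stream1`` and
        ``stream2`` are assumed to be existing DStreams of key-value pairs)::

            joined = ssc.transform(
                [stream1, stream2],
                lambda rdds: rdds[0].join(rdds[1]))
            joined.pprint()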
0321 """
0322 jdstreams = [d._jdstream for d in dstreams]
0323
0324 func = TransformFunction(self._sc,
0325 lambda t, *rdds: transformFunc(rdds),
0326 *[d._jrdd_deserializer for d in dstreams])
0327 jfunc = self._jvm.TransformFunction(func)
0328 jdstream = self._jssc.transform(jdstreams, jfunc)
0329 return DStream(jdstream, self, self._sc.serializer)
0330
    def union(self, *dstreams):
        """
        Create a unified DStream from multiple DStreams of the same
        type and same slide duration.
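
        For example (a sketch; ``stream1`` and ``stream2`` are assumed to be
        DStreams of the same type and slide duration)::

            combined = ssc.union(stream1, stream2)
            combined.pprint()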
        """
        if not dstreams:
            raise ValueError("should have at least one DStream to union")
        if len(dstreams) == 1:
            return dstreams[0]
        if len(set(s._jrdd_deserializer for s in dstreams)) > 1:
            raise ValueError("All DStreams should have same serializer")
        if len(set(s._slideDuration for s in dstreams)) > 1:
            raise ValueError("All DStreams should have same slide duration")
        jdstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream
        jpair_dstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaPairDStream
        gw = SparkContext._gateway
        if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls):
            cls = jdstream_cls
        elif is_instance_of(gw, dstreams[0]._jdstream, jpair_dstream_cls):
            cls = jpair_dstream_cls
        else:
            cls_name = dstreams[0]._jdstream.getClass().getCanonicalName()
            raise TypeError("Unsupported Java DStream class %s" % cls_name)
        jdstreams = gw.new_array(cls, len(dstreams))
        for i in range(0, len(dstreams)):
            jdstreams[i] = dstreams[i]._jdstream
        return DStream(self._jssc.union(jdstreams), self, dstreams[0]._jrdd_deserializer)

    def addStreamingListener(self, streamingListener):
        """
        Add a :class:`~pyspark.streaming.listener.StreamingListener` object for
        receiving system events related to streaming.
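
        A sketch of a listener that logs completed batches (illustrative; it
        subclasses :class:`pyspark.streaming.listener.StreamingListener`)::

            from pyspark.streaming.listener import StreamingListener

            class BatchLogger(StreamingListener):
                def onBatchCompleted(self, batchCompleted):
                    # batchCompleted is the event object delivered by the JVM side
                    print("Finished batch:", batchCompleted.batchInfo().batchTime())

            ssc.addStreamingListener(BatchLogger())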
0363 """
0364 self._jssc.addStreamingListener(self._jvm.JavaStreamingListenerWrapper(
0365 self._jvm.PythonStreamingListenerWrapper(streamingListener)))