0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import time
0019 from datetime import datetime
0020 import traceback
0021 import sys
0022
0023 from py4j.java_gateway import is_instance_of
0024
0025 from pyspark import SparkContext, RDD
0026
0027
class TransformFunction(object):
    """
    Py4J callback wrapper around a Python ``RDD[X] -> RDD[Y]`` function that
    was handed to ``DStream.transform()``.

    The JVM invokes :meth:`call` with a batch time and a sequence of JavaRDD
    handles; the wrapped function's result is handed back to the JVM as a
    single JavaRDD handle. Errors raised while computing the batch are not
    propagated: they are recorded as a formatted traceback and exposed via
    :meth:`getLastFailure`.
    """
    _emptyRDD = None

    def __init__(self, ctx, func, *deserializers):
        """
        :param ctx: SparkContext used to wrap incoming JavaRDDs; when None,
            the currently active SparkContext is looked up on the first call.
        :param func: callable invoked as ``func(batch_time, *rdds)``.
        :param deserializers: one deserializer per incoming RDD; if fewer
            are supplied than RDDs arrive, the first one is reused for the rest.
        """
        self.ctx = ctx
        self.func = func
        self.deserializers = deserializers
        # Default JavaRDD -> Python RDD wrapper; replaceable via rdd_wrapper().
        self.rdd_wrap_func = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
        self.failure = None

    def rdd_wrapper(self, func):
        """Install ``func(jrdd, ctx, ser)`` as the RDD wrapper; returns ``self``."""
        self.rdd_wrap_func = func
        return self

    def call(self, milliseconds, jrdds):
        """
        Run the wrapped function for one batch.

        :param milliseconds: batch time in milliseconds since the epoch.
        :param jrdds: sequence of JavaRDD handles; falsy entries are passed
            to the wrapped function as None.
        :return: a JavaRDD handle, or None when the SparkContext is stopped,
            the wrapped function returned a falsy value, or an error occurred
            (in which case :meth:`getLastFailure` returns the traceback).
        """
        # Forget any failure recorded for a previous batch.
        self.failure = None
        try:
            if self.ctx is None:
                self.ctx = SparkContext._active_spark_context
            if not (self.ctx and self.ctx._jsc):
                # No usable SparkContext (missing or stopped): nothing to do.
                return

            # Reuse the first deserializer for any RDD beyond those supplied.
            serializers = self.deserializers
            shortfall = len(jrdds) - len(serializers)
            if shortfall > 0:
                serializers = serializers + (serializers[0],) * shortfall

            wrapped = []
            for jrdd, ser in zip(jrdds, serializers):
                wrapped.append(self.rdd_wrap_func(jrdd, self.ctx, ser) if jrdd else None)

            batch_time = datetime.fromtimestamp(milliseconds / 1000.0)
            result = self.func(batch_time, *wrapped)
            if not result:
                return

            # Hand back the underlying handle directly only when it is a
            # JavaRDD; otherwise re-wrap through an identity map (presumably so
            # the JVM always receives a JavaRDD, e.g. when the result is backed
            # by a JavaPairRDD -- see SPARK-17756).
            gateway = self.ctx._gateway
            if is_instance_of(gateway, result._jrdd, "org.apache.spark.api.java.JavaRDD"):
                return result._jrdd
            return result.map(lambda x: x)._jrdd
        except:
            self.failure = traceback.format_exc()

    def getLastFailure(self):
        """Return the traceback text of the last failed call, or None."""
        return self.failure

    def __repr__(self):
        return "TransformFunction(%s)" % self.func

    class Java:
        # Java interface this Python object implements for Py4J callbacks.
        implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction']
0089
0090
class TransformFunctionSerializer(object):
    """
    Py4J callback that serializes and deserializes PythonTransformFunction
    objects on behalf of the JVM.

    A PythonTransformFunction seen from Java is only a Py4J reference to a
    Python object, so Java cannot serialize it by itself. Java instead calls
    :meth:`dumps` / :meth:`loads` here and receives the serialized function
    as a byte array. Errors are recorded rather than raised and can be
    fetched with :meth:`getLastFailure`.
    """

    def __init__(self, ctx, serializer, gateway=None):
        """
        :param ctx: SparkContext attached to deserialized TransformFunctions;
            its ``_gateway`` is used when ``gateway`` is not given.
        :param serializer: object with ``dumps``/``loads`` used for the
            ``(func, rdd_wrap_func, deserializers)`` triple.
        :param gateway: optional Py4J gateway; falls back to ``ctx._gateway``.
        """
        self.ctx = ctx
        self.serializer = serializer
        self.gateway = gateway if gateway else self.ctx._gateway
        # Register this object with the JVM-side PythonDStream.
        self.gateway.jvm.PythonDStream.registerSerializer(self)
        self.failure = None

    def dumps(self, id):
        """
        Serialize the object stored in the gateway's object pool under ``id``
        (presumably a TransformFunction -- its func, rdd_wrap_func and
        deserializers are read).

        :return: a ``bytearray``, or None on failure (see getLastFailure).
        """
        self.failure = None
        try:
            transform = self.gateway.gateway_property.pool[id]
            payload = (transform.func, transform.rdd_wrap_func, transform.deserializers)
            return bytearray(self.serializer.dumps(payload))
        except:
            self.failure = traceback.format_exc()

    def loads(self, data):
        """
        Rebuild a TransformFunction from bytes produced by :meth:`dumps`.

        :return: a new TransformFunction bound to ``self.ctx``, or None on
            failure (see getLastFailure).
        """
        self.failure = None
        try:
            func, wrap_func, deserializers = self.serializer.loads(bytes(data))
            restored = TransformFunction(self.ctx, func, *deserializers)
            return restored.rdd_wrapper(wrap_func)
        except:
            self.failure = traceback.format_exc()

    def getLastFailure(self):
        """Return the traceback text of the last failed dumps/loads, or None."""
        return self.failure

    def __repr__(self):
        return "TransformFunctionSerializer(%s)" % self.serializer

    class Java:
        # Java interface this Python object implements for Py4J callbacks.
        implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer']
0136
0137
def rddToFileName(prefix, suffix, timestamp):
    """
    Return the string ``prefix-time`` or, when ``suffix`` is given,
    ``prefix-time.suffix``.

    If ``timestamp`` is a :class:`datetime.datetime` it is first converted to
    whole milliseconds since the epoch: naive datetimes are interpreted in the
    local time zone (as ``time.mktime`` does), while timezone-aware datetimes
    are converted using their own UTC offset. Any other value is rendered
    with ``str()`` unchanged.

    >>> rddToFileName("spark", None, 12345678910)
    'spark-12345678910'
    >>> rddToFileName("spark", "tmp", 12345678910)
    'spark-12345678910.tmp'
    """
    if isinstance(timestamp, datetime):
        if timestamp.utcoffset() is not None:
            # time.mktime() would read the wall-clock fields of an aware
            # datetime as local time and ignore its tzinfo; timegm() on the
            # UTC time tuple honors the datetime's own offset.
            import calendar
            seconds = calendar.timegm(timestamp.utctimetuple())
        else:
            seconds = time.mktime(timestamp.timetuple())
        timestamp = int(seconds * 1000) + timestamp.microsecond // 1000
    if suffix is None:
        return prefix + "-" + str(timestamp)
    else:
        return prefix + "-" + str(timestamp) + "." + suffix
0154
0155
if __name__ == "__main__":
    # Run this module's doctests; exit with a non-zero status on any failure.
    import doctest
    failures, _ = doctest.testmod()
    if failures:
        sys.exit(-1)