#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
0017 
0018 import time
0019 from datetime import datetime
0020 import traceback
0021 import sys
0022 
0023 from py4j.java_gateway import is_instance_of
0024 
0025 from pyspark import SparkContext, RDD
0026 
0027 
class TransformFunction(object):
    """
    This class wraps a function RDD[X] -> RDD[Y] that was passed to
    DStream.transform(), allowing it to be called from Java via Py4J's
    callback server.

    Java calls this function with a sequence of JavaRDDs and this function
    returns a single JavaRDD pointer back to Java.
    """
    # Shared placeholder for an empty RDD; never populated in this block.
    _emptyRDD = None

    def __init__(self, ctx, func, *deserializers):
        self.ctx = ctx
        self.func = func
        self.deserializers = deserializers
        # Default wrapper converts a JavaRDD into a plain pyspark RDD; it can
        # be replaced via rdd_wrapper() to build a different Python-side type.
        self.rdd_wrap_func = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
        # Traceback of the most recent failure, queried by Java through
        # getLastFailure(); None means the last call succeeded.
        self.failure = None

    def rdd_wrapper(self, func):
        """Replace the JavaRDD -> RDD wrapper function; returns self for chaining."""
        self.rdd_wrap_func = func
        return self

    def call(self, milliseconds, jrdds):
        """
        Entry point invoked from Java with the batch time (epoch milliseconds)
        and the batch's JavaRDDs; returns the resulting JavaRDD, or None when
        the context is stopped, the function returned nothing, or it failed.

        Exceptions must not propagate into Py4J's callback server, so any
        error is recorded in self.failure for Java to retrieve instead.
        """
        # Clear the failure from any previous invocation
        self.failure = None
        try:
            if self.ctx is None:
                self.ctx = SparkContext._active_spark_context
            if not self.ctx or not self.ctx._jsc:
                # SparkContext has been stopped; nothing to do
                return

            # If fewer deserializers than RDDs were supplied, reuse the first
            # deserializer for the remaining RDDs
            sers = self.deserializers
            if len(sers) < len(jrdds):
                sers += (sers[0],) * (len(jrdds) - len(sers))

            rdds = [self.rdd_wrap_func(jrdd, self.ctx, ser) if jrdd else None
                    for jrdd, ser in zip(jrdds, sers)]
            t = datetime.fromtimestamp(milliseconds / 1000.0)
            r = self.func(t, *rdds)
            if r:
                # Here, we work around to ensure `_jrdd` is `JavaRDD` by wrapping it by `map`.
                # org.apache.spark.streaming.api.python.PythonTransformFunction requires to return
                # `JavaRDD`; however, this could be `JavaPairRDD` by some APIs, for example, `zip`.
                # See SPARK-17756.
                if is_instance_of(self.ctx._gateway, r._jrdd, "org.apache.spark.api.java.JavaRDD"):
                    return r._jrdd
                else:
                    return r.map(lambda x: x)._jrdd
        except BaseException:
            # Explicit catch-all (instead of a bare `except:`, which is the
            # same set of exceptions but hides the intent): every failure must
            # be reported back to the JVM via getLastFailure(), never raised
            # into the Py4J callback server.
            self.failure = traceback.format_exc()

    def getLastFailure(self):
        """Return the formatted traceback of the last failed call, or None."""
        return self.failure

    def __repr__(self):
        return "TransformFunction(%s)" % self.func

    class Java:
        # Declares to Py4J which Java interface this Python object implements.
        implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction']
0089 
0090 
class TransformFunctionSerializer(object):
    """
    This class implements a serializer for PythonTransformFunction Java
    objects.

    This is necessary because the Java PythonTransformFunction objects are
    actually Py4J references to Python objects and thus are not directly
    serializable. When Java needs to serialize a PythonTransformFunction,
    it uses this class to invoke Python, which returns the serialized function
    as a byte array.
    """
    def __init__(self, ctx, serializer, gateway=None):
        self.ctx = ctx
        self.serializer = serializer
        self.gateway = gateway or self.ctx._gateway
        # Register this object with the JVM so Java can call dumps()/loads().
        self.gateway.jvm.PythonDStream.registerSerializer(self)
        # Traceback of the most recent failure, queried by Java through
        # getLastFailure(); None means the last operation succeeded.
        self.failure = None

    def dumps(self, id):
        """
        Serialize the TransformFunction registered in the Py4J object pool
        under `id`; returns a bytearray for Java, or None on failure.
        """
        # Clear the failure from any previous invocation
        self.failure = None
        try:
            func = self.gateway.gateway_property.pool[id]
            return bytearray(self.serializer.dumps((
                func.func, func.rdd_wrap_func, func.deserializers)))
        except BaseException:
            # Explicit catch-all: errors must be reported via getLastFailure(),
            # never raised into the Py4J callback server.
            self.failure = traceback.format_exc()

    def loads(self, data):
        """
        Deserialize bytes produced by dumps() back into a TransformFunction;
        returns None on failure (traceback kept for getLastFailure()).
        """
        # Clear the failure from any previous invocation
        self.failure = None
        try:
            f, wrap_func, deserializers = self.serializer.loads(bytes(data))
            return TransformFunction(self.ctx, f, *deserializers).rdd_wrapper(wrap_func)
        except BaseException:
            # Same reporting contract as dumps().
            self.failure = traceback.format_exc()

    def getLastFailure(self):
        """Return the formatted traceback of the last failure, or None."""
        return self.failure

    def __repr__(self):
        return "TransformFunctionSerializer(%s)" % self.serializer

    class Java:
        # Declares to Py4J which Java interface this Python object implements.
        implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer']
0136 
0137 
def rddToFileName(prefix, suffix, timestamp):
    """
    Return string prefix-time(.suffix)

    >>> rddToFileName("spark", None, 12345678910)
    'spark-12345678910'
    >>> rddToFileName("spark", "tmp", 12345678910)
    'spark-12345678910.tmp'
    """
    # A datetime is first converted to epoch milliseconds (local time,
    # as produced by time.mktime).
    if isinstance(timestamp, datetime):
        seconds = time.mktime(timestamp.timetuple())
        timestamp = int(seconds * 1000) + timestamp.microsecond // 1000
    name = "%s-%s" % (prefix, timestamp)
    return name if suffix is None else "%s.%s" % (name, suffix)
0154 
0155 
if __name__ == "__main__":
    # When run directly, execute this module's doctests (e.g. the examples
    # in rddToFileName) and exit non-zero if any of them fail.
    import doctest
    (failure_count, test_count) = doctest.testmod()
    if failure_count:
        sys.exit(-1)