0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
import sys

# Python 3 removed the Python 2 `long` and `unicode` types; alias them to
# their Python 3 equivalents so the isinstance checks below work on both
# major versions.  Compare `sys.version_info` (a tuple) rather than the
# `sys.version` string: a lexicographic string compare like
# `sys.version >= '3'` would misclassify a hypothetical Python 10
# ('10...' < '3' as strings).
if sys.version_info >= (3,):
    long = int
    unicode = str
0022
0023 import py4j.protocol
0024 from py4j.protocol import Py4JJavaError
0025 from py4j.java_gateway import JavaObject
0026 from py4j.java_collections import JavaArray, JavaList
0027
0028 from pyspark import RDD, SparkContext
0029 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
0030 from pyspark.sql import DataFrame, SQLContext
0031
0032
# Keep a reference to py4j's original smart_decode so the float-aware
# replacement installed below can delegate to it for non-float values.
_old_smart_decode = py4j.protocol.smart_decode

# Python's str() spellings of special float values, mapped to the
# spellings the Java side expects (e.g. str(float('nan')) == 'nan',
# but Java parses 'NaN').
_float_str_mapping = {
    'nan': 'NaN',
    'inf': 'Infinity',
    '-inf': '-Infinity',
}
0040
0041
def _new_smart_decode(obj):
    """
    Decode *obj* for py4j, translating Python's special-float spellings
    ('nan', 'inf', '-inf') into their Java equivalents; everything that is
    not a float is delegated to py4j's original smart_decode.
    """
    if not isinstance(obj, float):
        return _old_smart_decode(obj)
    text = str(obj)
    return _float_str_mapping.get(text, text)

# Install the float-aware decoder in place of py4j's default one.
py4j.protocol.smart_decode = _new_smart_decode
0049
0050
# Simple names of JVM-side classes that MLSerDe knows how to pickle;
# _java2py serializes instances of these via MLSerDe.dumps so they can be
# unpickled on the Python side.
_picklable_classes = [
    'SparseVector',
    'DenseVector',
    'SparseMatrix',
    'DenseMatrix',
]
0057
0058
0059
def _to_java_object_rdd(rdd):
    """
    Return a JavaRDD of Object by unpickling.

    Each Python element is converted into a Java object via Pyrolite,
    regardless of whether the RDD uses batched serialization.
    """
    pickled = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
    return pickled.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(
        pickled._jrdd, True)
0068
0069
def _py2java(sc, obj):
    """
    Convert a Python object into its Java counterpart.

    RDDs, DataFrames, and the SparkContext are mapped to their underlying
    Java handles; lists are converted element-wise; Java objects and
    primitives pass through untouched; anything else is pickled and handed
    to MLSerDe.loads on the JVM side.
    """
    if isinstance(obj, RDD):
        return _to_java_object_rdd(obj)
    if isinstance(obj, DataFrame):
        return obj._jdf
    if isinstance(obj, SparkContext):
        return obj._jsc
    if isinstance(obj, list):
        return [_py2java(sc, elem) for elem in obj]
    if isinstance(obj, (JavaObject, int, long, float, bool, bytes, unicode)):
        # Already a Java handle, or a primitive py4j converts natively.
        return obj
    # Fall back to pickling the object and deserializing it on the JVM.
    payload = bytearray(PickleSerializer().dumps(obj))
    return sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(payload)
0088
0089
def _java2py(sc, r, encoding="bytes"):
    """
    Convert a Java object returned through py4j into a Python object.

    RDD-like results come back as Python RDDs, Datasets as DataFrames,
    and known picklable classes / Java collections are round-tripped
    through MLSerDe pickling.  Anything unrecognized is returned as-is.
    """
    if isinstance(r, JavaObject):
        clsName = r.getClass().getSimpleName()

        # Normalize RDD subclasses (but not JavaRDD itself) to JavaRDD
        # so the branch below handles them uniformly.
        if clsName != 'JavaRDD' and clsName.endswith("RDD"):
            r = r.toJavaRDD()
            clsName = 'JavaRDD'

        if clsName == 'JavaRDD':
            jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r)
            return RDD(jrdd, sc)

        if clsName == 'Dataset':
            return DataFrame(r, SQLContext.getOrCreate(sc))

        if clsName in _picklable_classes:
            # Serialize on the JVM side; unpickled by the final block below.
            r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
        elif isinstance(r, (JavaArray, JavaList)):
            try:
                r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
            except Py4JJavaError:
                pass  # not picklable; return the Java collection unchanged

        # Anything that was dumped above (or arrived as raw bytes) is
        # unpickled back into a Python object here.
    if isinstance(r, (bytearray, bytes)):
        r = PickleSerializer().loads(bytes(r), encoding=encoding)
    return r
0116
0117
def callJavaFunc(sc, func, *args):
    """
    Call a Java function, converting each argument to its Java form and
    the result back to a Python object.
    """
    jargs = [_py2java(sc, arg) for arg in args]
    result = func(*jargs)
    return _java2py(sc, result)
0122
0123
def inherit_doc(cls):
    """
    A decorator that makes a class inherit documentation from its parents.

    For every public attribute of *cls* that lacks a docstring, copy the
    docstring of the first base class that provides one for the same name.
    """
    for attr_name, attr in vars(cls).items():
        # Skip private names and anything that already has documentation.
        if attr_name.startswith("_") or attr.__doc__:
            continue
        for base in cls.__bases__:
            inherited = getattr(base, attr_name, None)
            if inherited and getattr(inherited, "__doc__", None):
                attr.__doc__ = inherited.__doc__
                break
    return cls