0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import sys
# Python 2/3 compatibility: on Python 3 the `long` and `unicode` builtins are
# gone, so alias them to their Python 3 equivalents for the isinstance checks
# below.  Use `sys.version_info` (a comparable tuple) rather than the fragile
# lexicographic string comparison `sys.version >= '3'`.
if sys.version_info >= (3,):
    long = int
    unicode = str
0022
0023 import py4j.protocol
0024 from py4j.protocol import Py4JJavaError
0025 from py4j.java_gateway import JavaObject
0026 from py4j.java_collections import JavaArray, JavaList
0027
0028 from pyspark import RDD, SparkContext
0029 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
0030 from pyspark.sql import DataFrame, SQLContext
0031
0032
# Keep a reference to py4j's original smart_decode so the patched version
# installed below can delegate every non-float value to it.
_old_smart_decode = py4j.protocol.smart_decode

# Python's str() spellings of special floats -> the spellings Java/Scala
# expect ("NaN", "Infinity", "-Infinity") when sent over Py4J.
_float_str_mapping = {
    'nan': 'NaN',
    'inf': 'Infinity',
    '-inf': '-Infinity',
}
0040
0041
def _new_smart_decode(obj):
    """Decode *obj* for Py4J, translating special float spellings.

    Floats are rendered with Python's ``str`` and then mapped through
    ``_float_str_mapping`` so ``nan``/``inf``/``-inf`` arrive on the JVM as
    ``NaN``/``Infinity``/``-Infinity``; everything else is handed to py4j's
    original ``smart_decode``.
    """
    if not isinstance(obj, float):
        return _old_smart_decode(obj)
    text = str(obj)
    return _float_str_mapping.get(text, text)

# Install the float-aware decoder in place of py4j's default.
py4j.protocol.smart_decode = _new_smart_decode
0049
0050
# Simple names of JVM-side classes that SerDe.dumps can pickle directly;
# _java2py consults this list to decide when to round-trip a Java object
# through pickling instead of returning the raw Py4J handle.
_picklable_classes = [
    'LinkedList',
    'SparseVector',
    'DenseVector',
    'DenseMatrix',
    'Rating',
    'LabeledPoint',
]
0059
0060
0061
def _to_java_object_rdd(rdd):
    """Return a JavaRDD of Object built by unpickling this RDD.

    Each Python element is converted to a Java object via Pyrolite,
    regardless of whether the RDD was batch-serialized.
    """
    reserialized = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
    serde = reserialized.ctx._jvm.org.apache.spark.mllib.api.python.SerDe
    return serde.pythonToJava(reserialized._jrdd, True)
0070
0071
def _py2java(sc, obj):
    """Convert a Python object into its Java counterpart for a Py4J call.

    RDDs, DataFrames and the SparkContext are unwrapped to their JVM
    handles; lists are converted element-wise; Py4J handles and primitives
    pass through untouched; anything else is pickled and reconstructed on
    the JVM side via SerDe.loads.
    """
    if isinstance(obj, RDD):
        return _to_java_object_rdd(obj)
    if isinstance(obj, DataFrame):
        return obj._jdf
    if isinstance(obj, SparkContext):
        return obj._jsc
    if isinstance(obj, list):
        return [_py2java(sc, elem) for elem in obj]
    if isinstance(obj, (JavaObject, int, long, float, bool, bytes, unicode)):
        # Already a JVM handle, or a primitive Py4J converts natively.
        return obj
    # Fallback: pickle on the Python side, unpickle on the JVM side.
    pickled = bytearray(PickleSerializer().dumps(obj))
    return sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(pickled)
0090
0091
def _java2py(sc, r, encoding="bytes"):
    """Convert a value returned over Py4J back into a Python object.

    JVM RDDs become pyspark RDDs, Datasets become DataFrames, known
    picklable model classes (see ``_picklable_classes``) and Java
    arrays/lists are round-tripped through SerDe pickling; any pickled
    bytes are finally deserialized with ``encoding``.  Values that match
    none of these cases are returned unchanged.
    """
    if isinstance(r, JavaObject):
        cls_name = r.getClass().getSimpleName()

        # Normalize any typed *RDD wrapper to a plain JavaRDD first.
        if cls_name != 'JavaRDD' and cls_name.endswith("RDD"):
            r = r.toJavaRDD()
            cls_name = 'JavaRDD'

        if cls_name == 'JavaRDD':
            jrdd = sc._jvm.org.apache.spark.mllib.api.python.SerDe.javaToPython(r)
            return RDD(jrdd, sc)

        if cls_name == 'Dataset':
            return DataFrame(r, SQLContext.getOrCreate(sc))

        if cls_name in _picklable_classes:
            r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
        elif isinstance(r, (JavaArray, JavaList)):
            try:
                r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
            except Py4JJavaError:
                pass  # not picklable on the JVM side; fall through as-is

    if isinstance(r, (bytearray, bytes)):
        r = PickleSerializer().loads(bytes(r), encoding=encoding)
    return r
0118
0119
def callJavaFunc(sc, func, *args):
    """Call the Java function *func*, converting arguments and result.

    Each argument is translated with ``_py2java`` and the return value is
    translated back with ``_java2py``.
    """
    java_args = [_py2java(sc, arg) for arg in args]
    result = func(*java_args)
    return _java2py(sc, result)
0124
0125
def callMLlibFunc(name, *args):
    """Call the method *name* on the JVM-side PythonMLLibAPI.

    Arguments and the result are converted via ``callJavaFunc``.
    """
    sc = SparkContext.getOrCreate()
    api = getattr(sc._jvm.PythonMLLibAPI(), name)
    return callJavaFunc(sc, api, *args)
0131
0132
class JavaModelWrapper(object):
    """Thin Python wrapper around an MLlib model living in the JVM.

    Holds the Py4J handle to the Java model and forwards method calls to
    it, converting arguments and results with ``callJavaFunc``.
    """

    def __init__(self, java_model):
        # Keep the active context so calls and cleanup can reach the JVM.
        self._sc = SparkContext.getOrCreate()
        self._java_model = java_model

    def __del__(self):
        # Release the Py4J reference so the JVM object can be collected.
        self._sc._gateway.detach(self._java_model)

    def call(self, name, *a):
        """Invoke the method *name* on the wrapped Java model."""
        method = getattr(self._java_model, name)
        return callJavaFunc(self._sc, method, *a)
0147
0148
def inherit_doc(cls):
    """Class decorator that fills in missing docstrings from base classes.

    For every public attribute of *cls* lacking a docstring, copy the
    docstring of the first base class (in ``__bases__`` order) that
    defines the same attribute with a docstring.  Returns *cls*.
    """
    for attr_name, member in vars(cls).items():
        # Private/dunder members keep whatever docs they have.
        if attr_name.startswith("_") or member.__doc__:
            continue
        for base in cls.__bases__:
            inherited = getattr(base, attr_name, None)
            if inherited and getattr(inherited, "__doc__", None):
                member.__doc__ = inherited.__doc__
                break
    return cls