Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 import sys
if sys.version_info[0] >= 3:
    # Python 3 merged ``long`` into ``int`` and ``unicode`` into ``str``.
    # Re-create the Python 2 names so the isinstance checks in this module
    # work on both major versions. ``sys.version_info`` is compared
    # numerically; comparing the ``sys.version`` string would be
    # lexicographic (e.g. "10.0" < "3").
    long = int
    unicode = str
0022 
0023 import py4j.protocol
0024 from py4j.protocol import Py4JJavaError
0025 from py4j.java_gateway import JavaObject
0026 from py4j.java_collections import JavaArray, JavaList
0027 
0028 from pyspark import RDD, SparkContext
0029 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
0030 from pyspark.sql import DataFrame, SQLContext
0031 
# Hack to support float('inf') / float('nan') in Py4J.
# Python renders these as 'inf', '-inf' and 'nan', whereas Java spells them
# 'Infinity', '-Infinity' and 'NaN' (presumably what the JVM side expects when
# it parses the value back -- confirm against py4j's protocol handling).
_old_smart_decode = py4j.protocol.smart_decode

_float_str_mapping = {
    'nan': 'NaN',
    'inf': 'Infinity',
    '-inf': '-Infinity',
}


def _new_smart_decode(obj):
    """Replacement for py4j's ``smart_decode`` that knows about non-finite floats.

    Non-float values are delegated unchanged to the original py4j
    implementation. Floats are rendered with ``str``, and the three
    non-finite spellings are translated to their Java equivalents.
    """
    if not isinstance(obj, float):
        return _old_smart_decode(obj)
    text = str(obj)
    # Finite floats are absent from the mapping and keep their own text.
    return _float_str_mapping.get(text, text)


py4j.protocol.smart_decode = _new_smart_decode
0049 
0050 
# Simple JVM class names (as reported by getClass().getSimpleName()) that
# _java2py always serializes through MLSerDe.dumps before unpickling in Python.
_picklable_classes = ['SparseVector', 'DenseVector', 'SparseMatrix', 'DenseMatrix']
0057 
0058 
# Uses the ML package's MLSerDe (org.apache.spark.ml.python), i.e. the ML
# version of pythonToJava().
def _to_java_object_rdd(rdd):
    """Return a JavaRDD of Object by unpickling.

    Each Python object is converted into a Java object by Pyrolite, whether
    or not the RDD is serialized in batches.
    """
    # Re-serialize the RDD with an auto-batched pickle serializer before
    # handing its underlying JavaRDD to the JVM.
    pickled = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
    ml_serde = pickled.ctx._jvm.org.apache.spark.ml.python.MLSerDe
    # NOTE(review): the meaning of the boolean flag is defined by
    # MLSerDe.pythonToJava on the JVM side (presumably "batched") -- confirm.
    return ml_serde.pythonToJava(pickled._jrdd, True)
0068 
0069 
def _py2java(sc, obj):
    """Convert a Python object into Java.

    - RDD: converted with _to_java_object_rdd.
    - DataFrame / SparkContext: replaced by their JVM handles (_jdf / _jsc).
    - list: each element is converted recursively; the result is a Python list.
    - JavaObject and int/long/float/bool/bytes/unicode: returned unchanged.
    - anything else: pickled and rebuilt on the JVM with MLSerDe.loads.
    """
    if isinstance(obj, RDD):
        return _to_java_object_rdd(obj)
    if isinstance(obj, DataFrame):
        return obj._jdf
    if isinstance(obj, SparkContext):
        return obj._jsc
    if isinstance(obj, list):
        return [_py2java(sc, item) for item in obj]
    if isinstance(obj, (JavaObject, int, long, float, bool, bytes, unicode)):
        # Already something py4j can pass through as-is.
        return obj
    pickled = bytearray(PickleSerializer().dumps(obj))
    return sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(pickled)
0088 
0089 
def _java2py(sc, r, encoding="bytes"):
    """Convert a value returned from the JVM into a Python value.

    - A JavaObject whose simple class name ends in "RDD" is turned into a
      JavaRDD (via toJavaRDD() unless it already is one) and wrapped as a
      Python RDD over MLSerDe.javaToPython.
    - A JavaObject of class ``Dataset`` is wrapped as a DataFrame.
    - A JavaObject whose simple class name is in ``_picklable_classes`` is
      pickled on the JVM with MLSerDe.dumps. Java arrays/lists are pickled
      the same way when possible and otherwise left as they are.
    - Any resulting ``bytes``/``bytearray`` value is unpickled with
      PickleSerializer using ``encoding``. Other values are returned unchanged.
    """
    if isinstance(r, JavaObject):
        simple_name = r.getClass().getSimpleName()

        # Normalize every non-JavaRDD RDD flavour to a JavaRDD.
        if simple_name.endswith("RDD") and simple_name != 'JavaRDD':
            r = r.toJavaRDD()
            simple_name = 'JavaRDD'

        if simple_name == 'JavaRDD':
            return RDD(sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r), sc)

        if simple_name == 'Dataset':
            return DataFrame(r, SQLContext.getOrCreate(sc))

        if simple_name in _picklable_classes:
            r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
        elif isinstance(r, (JavaArray, JavaList)):
            try:
                r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
            except Py4JJavaError:
                pass  # not picklable: hand back the Java collection as-is

    if isinstance(r, (bytearray, bytes)):
        return PickleSerializer().loads(bytes(r), encoding=encoding)
    return r
0116 
0117 
def callJavaFunc(sc, func, *args):
    """Call the Java function ``func`` with converted arguments.

    Each positional argument is converted with _py2java before the call, and
    the value returned by ``func`` is converted back with _java2py.
    """
    java_args = [_py2java(sc, arg) for arg in args]
    result = func(*java_args)
    return _java2py(sc, result)
0122 
0123 
def inherit_doc(cls):
    """
    A decorator that makes a class inherit documentation from its parents.

    For every public (non-underscore) attribute defined directly on ``cls``
    whose ``__doc__`` is empty, the docstring of the same-named attribute is
    copied from the first class in ``cls.__bases__`` that provides a non-empty
    one. Attributes whose ``__doc__`` cannot be assigned are left unchanged,
    rather than aborting the class definition.

    :param cls: the class to decorate; its members are updated in place.
    :return: ``cls`` itself.
    """
    for name, member in vars(cls).items():
        # only inherit docstring for public members
        if name.startswith("_"):
            continue
        if getattr(member, "__doc__", None):
            continue
        for parent in cls.__bases__:
            parent_member = getattr(parent, name, None)
            if parent_member and getattr(parent_member, "__doc__", None):
                try:
                    member.__doc__ = parent_member.__doc__
                except AttributeError:
                    # Some objects (e.g. instances of classes that use
                    # __slots__) have a read-only __doc__. Inheriting a
                    # docstring is cosmetic, so skip instead of failing.
                    pass
                break
    return cls