Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 import os
0019 import sys
0020 import decimal
0021 import time
0022 import datetime
0023 import calendar
0024 import json
0025 import re
0026 import base64
0027 from array import array
0028 import ctypes
0029 import warnings
0030 
# Python 3 removed the ``long``/``basestring``/``unicode`` builtins; alias them
# so the rest of this module can keep using the Python 2 names.
# ``sys.version_info`` is compared numerically; comparing the free-form
# ``sys.version`` string lexicographically against "3" is fragile.
if sys.version_info[0] >= 3:
    long = int
    basestring = unicode = str
0034 
0035 from py4j.protocol import register_input_converter
0036 from py4j.java_gateway import JavaClass
0037 
0038 from pyspark import SparkContext
0039 from pyspark.serializers import CloudPickleSerializer
0040 
# Names exported by a wildcard import of this module.
__all__ = [
    "DataType", "NullType", "StringType", "BinaryType", "BooleanType",
    "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType",
    "ByteType", "IntegerType", "LongType", "ShortType",
    "ArrayType", "MapType", "StructField", "StructType",
]
0045 
0046 
class DataType(object):
    """Base class for data types.

    By default a type's Python value and its internal SQL value are the same
    object, so ``toInternal``/``fromInternal`` are identity functions and
    ``needConversion`` reports False.
    """

    def __repr__(self):
        return type(self).__name__

    def __hash__(self):
        # Hash the textual form so equal types hash equally.
        return hash(str(self))

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    @classmethod
    def typeName(cls):
        # Class name without its trailing 4 characters (the "Type" suffix), lower-cased.
        return cls.__name__[:-4].lower()

    def simpleString(self):
        return self.typeName()

    def jsonValue(self):
        return self.typeName()

    def json(self):
        """Compact JSON encoding of :meth:`jsonValue` with sorted keys."""
        compact = (',', ':')
        return json.dumps(self.jsonValue(), separators=compact, sort_keys=True)

    def needConversion(self):
        """
        Does this type needs conversion between Python object and internal SQL object.

        This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
        """
        return False

    def toInternal(self, obj):
        """
        Converts a Python object into an internal SQL object.
        """
        return obj

    def fromInternal(self, obj):
        """
        Converts an internal SQL object into a native Python object.
        """
        return obj
0096 
0097 
# NOTE: the singleton guarantee does not survive pickling; unpickling yields a
# new object instead of the cached instance.
class DataTypeSingleton(type):
    """Metaclass for DataType that hands out one cached instance per class."""

    _instances = {}

    def __call__(cls):
        try:
            return cls._instances[cls]
        except KeyError:
            created = super(DataTypeSingleton, cls).__call__()
            cls._instances[cls] = created
            return created
0109 
0110 
class NullType(DataType):
    """Null type.

    Represents ``None``; used for values whose type cannot be inferred.
    """

    # Python 2 metaclass hook. Python 3 does not honor a ``__metaclass__``
    # class attribute, so there the singleton caching is not applied.
    __metaclass__ = DataTypeSingleton
0118 
0119 
class AtomicType(DataType):
    """Internal base for every type that is not null, a UDT, an array,
    a struct, or a map."""
0123 
0124 
class NumericType(AtomicType):
    """Base class for numeric data types."""
0128 
0129 
class IntegralType(NumericType):
    """Base class for integral (whole-number) data types."""

    # Python 2 metaclass hook (ignored as a class attribute on Python 3).
    __metaclass__ = DataTypeSingleton
0135 
0136 
class FractionalType(NumericType):
    """Base class for fractional data types."""
0140 
0141 
class StringType(AtomicType):
    """Data type for strings."""

    # Python 2 metaclass hook (ignored as a class attribute on Python 3).
    __metaclass__ = DataTypeSingleton
0147 
0148 
class BinaryType(AtomicType):
    """Data type for binary values (byte arrays)."""

    # Python 2 metaclass hook (ignored as a class attribute on Python 3).
    __metaclass__ = DataTypeSingleton
0154 
0155 
class BooleanType(AtomicType):
    """Data type for boolean values."""

    # Python 2 metaclass hook (ignored as a class attribute on Python 3).
    __metaclass__ = DataTypeSingleton
0161 
0162 
class DateType(AtomicType):
    """Date (datetime.date) data type.

    Internally a date is stored as the number of days since 1970-01-01.
    """

    # Python 2 metaclass hook (ignored as a class attribute on Python 3).
    __metaclass__ = DataTypeSingleton

    # Proleptic Gregorian ordinal of 1970-01-01, used as the day-count origin.
    EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()

    def needConversion(self):
        return True

    def toInternal(self, d):
        """Convert a date to days since the epoch; ``None`` stays ``None``."""
        if d is None:
            return None
        return d.toordinal() - self.EPOCH_ORDINAL

    def fromInternal(self, v):
        """Convert days since the epoch back to a ``datetime.date``; ``None`` stays ``None``."""
        if v is None:
            return None
        return datetime.date.fromordinal(v + self.EPOCH_ORDINAL)
0181 
0182 
class TimestampType(AtomicType):
    """Timestamp (datetime.datetime) data type.

    Internally a timestamp is stored as an integer count of microseconds
    since the Unix epoch.
    """

    # Python 2 metaclass hook (ignored as a class attribute on Python 3).
    __metaclass__ = DataTypeSingleton

    def needConversion(self):
        return True

    def toInternal(self, dt):
        """Convert a datetime to epoch microseconds; ``None`` stays ``None``."""
        if dt is None:
            return None
        if dt.tzinfo:
            # Aware datetime: use its UTC time tuple.
            seconds = calendar.timegm(dt.utctimetuple())
        else:
            # Naive datetime: time.mktime interprets it in the local time zone.
            seconds = time.mktime(dt.timetuple())
        return int(seconds) * 1000000 + dt.microsecond

    def fromInternal(self, ts):
        """Convert epoch microseconds to a local naive datetime; ``None`` stays ``None``."""
        if ts is None:
            return None
        # Split with integer arithmetic to avoid float precision loss.
        whole_seconds, micros = divmod(ts, 1000000)
        return datetime.datetime.fromtimestamp(whole_seconds).replace(microsecond=micros)
0202 
0203 
class DecimalType(FractionalType):
    """Decimal (decimal.Decimal) data type.

    A DecimalType has a fixed precision (the maximum total number of digits)
    and scale (the number of digits to the right of the dot). For example,
    (5, 2) can hold values from -999.99 to 999.99.

    The precision can be up to 38; the scale must be less than or equal to the
    precision.

    A DecimalType created without arguments is (10, 0). When a schema is
    inferred from decimal.Decimal objects, the type is DecimalType(38, 18).

    :param precision: the maximum (i.e. total) number of digits (default: 10)
    :param scale: the number of digits on right side of dot. (default: 0)
    """

    def __init__(self, precision=10, scale=0):
        self.precision, self.scale = precision, scale
        self.hasPrecisionInfo = True  # this is a public API

    def simpleString(self):
        pair = (self.precision, self.scale)
        return "decimal(%d,%d)" % pair

    def jsonValue(self):
        pair = (self.precision, self.scale)
        return "decimal(%d,%d)" % pair

    def __repr__(self):
        pair = (self.precision, self.scale)
        return "DecimalType(%d,%d)" % pair
0233 
0234 
class DoubleType(FractionalType):
    """Data type for double-precision floating point numbers."""

    # Python 2 metaclass hook (ignored as a class attribute on Python 3).
    __metaclass__ = DataTypeSingleton
0240 
0241 
class FloatType(FractionalType):
    """Data type for single-precision floating point numbers."""

    # Python 2 metaclass hook (ignored as a class attribute on Python 3).
    __metaclass__ = DataTypeSingleton
0247 
0248 
class ByteType(IntegralType):
    """Data type for a signed integer stored in a single byte."""

    def simpleString(self):
        """SQL-style name of this type."""
        return 'tinyint'
0254 
0255 
class IntegerType(IntegralType):
    """Data type for a signed 32-bit integer."""

    def simpleString(self):
        """SQL-style name of this type."""
        return 'int'
0261 
0262 
class LongType(IntegralType):
    """Data type for a signed 64-bit integer.

    For values outside [-9223372036854775808, 9223372036854775807], use
    :class:`DecimalType` instead.
    """

    def simpleString(self):
        """SQL-style name of this type."""
        return 'bigint'
0271 
0272 
class ShortType(IntegralType):
    """Data type for a signed 16-bit integer."""

    def simpleString(self):
        """SQL-style name of this type."""
        return 'smallint'
0278 
0279 
class ArrayType(DataType):
    """Array data type.

    :param elementType: :class:`DataType` of each element in the array.
    :param containsNull: boolean, whether the array can contain null (None) values.
    """

    def __init__(self, elementType, containsNull=True):
        """
        >>> ArrayType(StringType()) == ArrayType(StringType(), True)
        True
        >>> ArrayType(StringType(), False) == ArrayType(StringType())
        False
        """
        assert isinstance(elementType, DataType),\
            "elementType %s should be an instance of %s" % (elementType, DataType)
        self.elementType = elementType
        self.containsNull = containsNull

    def simpleString(self):
        inner = self.elementType.simpleString()
        return 'array<%s>' % inner

    def __repr__(self):
        nullable_text = str(self.containsNull).lower()
        return "ArrayType(%s,%s)" % (self.elementType, nullable_text)

    def jsonValue(self):
        return dict(type=self.typeName(),
                    elementType=self.elementType.jsonValue(),
                    containsNull=self.containsNull)

    @classmethod
    def fromJson(cls, json):
        element = _parse_datatype_json_value(json["elementType"])
        return ArrayType(element, json["containsNull"])

    def needConversion(self):
        # Arrays need conversion exactly when their elements do.
        return self.elementType.needConversion()

    def toInternal(self, obj):
        """Convert each element to its internal form; falsy input is returned as-is."""
        if not self.needConversion() or not obj:
            return obj
        convert = self.elementType.toInternal
        return [convert(item) for item in obj]

    def fromInternal(self, obj):
        """Convert each element back to Python form; falsy input is returned as-is."""
        if not self.needConversion() or not obj:
            return obj
        convert = self.elementType.fromInternal
        return [convert(item) for item in obj]
0328 
0329 
class MapType(DataType):
    """Map data type.

    :param keyType: :class:`DataType` of the keys in the map.
    :param valueType: :class:`DataType` of the values in the map.
    :param valueContainsNull: indicates whether values can contain null (None) values.

    Keys in a map data type are not allowed to be null (None).
    """

    def __init__(self, keyType, valueType, valueContainsNull=True):
        """
        >>> (MapType(StringType(), IntegerType())
        ...        == MapType(StringType(), IntegerType(), True))
        True
        >>> (MapType(StringType(), IntegerType(), False)
        ...        == MapType(StringType(), FloatType()))
        False
        """
        assert isinstance(keyType, DataType),\
            "keyType %s should be an instance of %s" % (keyType, DataType)
        assert isinstance(valueType, DataType),\
            "valueType %s should be an instance of %s" % (valueType, DataType)
        self.keyType = keyType
        self.valueType = valueType
        self.valueContainsNull = valueContainsNull

    def simpleString(self):
        key_text = self.keyType.simpleString()
        value_text = self.valueType.simpleString()
        return 'map<%s,%s>' % (key_text, value_text)

    def __repr__(self):
        nullable_text = str(self.valueContainsNull).lower()
        return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, nullable_text)

    def jsonValue(self):
        return dict(type=self.typeName(),
                    keyType=self.keyType.jsonValue(),
                    valueType=self.valueType.jsonValue(),
                    valueContainsNull=self.valueContainsNull)

    @classmethod
    def fromJson(cls, json):
        key = _parse_datatype_json_value(json["keyType"])
        value = _parse_datatype_json_value(json["valueType"])
        return MapType(key, value, json["valueContainsNull"])

    def needConversion(self):
        # A map needs conversion if either its keys or its values do.
        return self.keyType.needConversion() or self.valueType.needConversion()

    def toInternal(self, obj):
        """Convert keys and values to internal form; falsy input is returned as-is."""
        if not self.needConversion() or not obj:
            return obj
        convert_key = self.keyType.toInternal
        convert_value = self.valueType.toInternal
        return dict((convert_key(key), convert_value(value)) for key, value in obj.items())

    def fromInternal(self, obj):
        """Convert keys and values back to Python form; falsy input is returned as-is."""
        if not self.needConversion() or not obj:
            return obj
        convert_key = self.keyType.fromInternal
        convert_value = self.valueType.fromInternal
        return dict((convert_key(key), convert_value(value)) for key, value in obj.items())
0390 
0391 
class StructField(DataType):
    """A field in :class:`StructType`.

    :param name: string, name of the field.
    :param dataType: :class:`DataType` of the field.
    :param nullable: boolean, whether the field can be null (None) or not.
    :param metadata: a dict from string to simple type that can be serialized to
        JSON automatically
    """

    def __init__(self, name, dataType, nullable=True, metadata=None):
        """
        >>> (StructField("f1", StringType(), True)
        ...      == StructField("f1", StringType(), True))
        True
        >>> (StructField("f1", StringType(), True)
        ...      == StructField("f2", StringType(), True))
        False
        """
        assert isinstance(dataType, DataType),\
            "dataType %s should be an instance of %s" % (dataType, DataType)
        # ``name`` is wrapped in a 1-tuple: a bare ``% name`` raises TypeError
        # (instead of this AssertionError) when ``name`` is itself a tuple.
        assert isinstance(name, basestring), "field name %s should be string" % (name,)
        if not isinstance(name, str):
            # Only reachable on Python 2 (unicode names): store UTF-8 encoded str.
            name = name.encode('utf-8')
        self.name = name
        self.dataType = dataType
        self.nullable = nullable
        # A missing/empty metadata argument becomes a fresh dict for this field.
        self.metadata = metadata or {}

    def simpleString(self):
        return '%s:%s' % (self.name, self.dataType.simpleString())

    def __repr__(self):
        return "StructField(%s,%s,%s)" % (self.name, self.dataType,
                                          str(self.nullable).lower())

    def jsonValue(self):
        return {"name": self.name,
                "type": self.dataType.jsonValue(),
                "nullable": self.nullable,
                "metadata": self.metadata}

    @classmethod
    def fromJson(cls, json):
        return StructField(json["name"],
                           _parse_datatype_json_value(json["type"]),
                           json["nullable"],
                           json["metadata"])

    def needConversion(self):
        # A field delegates all conversion decisions to its data type.
        return self.dataType.needConversion()

    def toInternal(self, obj):
        return self.dataType.toInternal(obj)

    def fromInternal(self, obj):
        return self.dataType.fromInternal(obj)

    def typeName(self):
        # Overrides the inherited classmethod: a field is not itself a type.
        raise TypeError(
            "StructField does not have typeName. "
            "Use typeName on its type explicitly instead.")
0453 
0454 
class StructType(DataType):
    """Struct type, consisting of a list of :class:`StructField`.

    This is the data type representing a :class:`Row`.

    Iterating a :class:`StructType` will iterate over its :class:`StructField`\\s.
    A contained :class:`StructField` can be accessed by its name or position.

    >>> struct1 = StructType([StructField("f1", StringType(), True)])
    >>> struct1["f1"]
    StructField(f1,StringType,true)
    >>> struct1[0]
    StructField(f1,StringType,true)
    """
    def __init__(self, fields=None):
        """
        >>> struct1 = StructType([StructField("f1", StringType(), True)])
        >>> struct2 = StructType([StructField("f1", StringType(), True)])
        >>> struct1 == struct2
        True
        >>> struct1 = StructType([StructField("f1", StringType(), True)])
        >>> struct2 = StructType([StructField("f1", StringType(), True),
        ...     StructField("f2", IntegerType(), False)])
        >>> struct1 == struct2
        False
        """
        # Keep a private list. ``add`` appends to ``self.fields`` in place, so
        # storing the caller's list would make ``add`` mutate it as well, and a
        # one-shot iterable would be exhausted by the first pass below.
        fields = list(fields) if fields else []
        # Validate before reading ``f.name`` so a bad element reports this
        # message rather than an AttributeError.
        assert all(isinstance(f, StructField) for f in fields),\
            "fields should be a list of StructField"
        self.fields = fields
        self.names = [f.name for f in fields]
        # Precalculated list of fields that need conversion with fromInternal/toInternal functions
        self._needConversion = [f.needConversion() for f in self]
        self._needSerializeAnyField = any(self._needConversion)

    def add(self, field, data_type=None, nullable=True, metadata=None):
        """
        Construct a StructType by adding new elements to it, to define the schema.
        The method accepts either:

            a) A single parameter which is a StructField object.
            b) Between 2 and 4 parameters as (name, data_type, nullable (optional),
               metadata(optional). The data_type parameter may be either a String or a
               DataType object.

        >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
        >>> struct2 = StructType([StructField("f1", StringType(), True), \\
        ...     StructField("f2", StringType(), True, None)])
        >>> struct1 == struct2
        True
        >>> struct1 = StructType().add(StructField("f1", StringType(), True))
        >>> struct2 = StructType([StructField("f1", StringType(), True)])
        >>> struct1 == struct2
        True
        >>> struct1 = StructType().add("f1", "string", True)
        >>> struct2 = StructType([StructField("f1", StringType(), True)])
        >>> struct1 == struct2
        True

        :param field: Either the name of the field or a StructField object
        :param data_type: If present, the DataType of the StructField to create
        :param nullable: Whether the field to add should be nullable (default True)
        :param metadata: Any additional metadata (default None)
        :return: this StructType, updated in place (allows call chaining)
        """
        if isinstance(field, StructField):
            new_field = field
        else:
            if isinstance(field, basestring) and data_type is None:
                raise ValueError("Must specify DataType if passing name of struct_field to create.")

            if isinstance(data_type, basestring):
                data_type_f = _parse_datatype_json_value(data_type)
            else:
                data_type_f = data_type
            new_field = StructField(field, data_type_f, nullable, metadata)
        self.fields.append(new_field)
        # Record the name as stored on the field (StructField may re-encode it),
        # so ``names`` always agrees with ``fields``.
        self.names.append(new_field.name)
        # Precalculated list of fields that need conversion with fromInternal/toInternal functions
        self._needConversion = [f.needConversion() for f in self]
        self._needSerializeAnyField = any(self._needConversion)
        return self

    def __iter__(self):
        """Iterate the fields"""
        return iter(self.fields)

    def __len__(self):
        """Return the number of fields."""
        return len(self.fields)

    def __getitem__(self, key):
        """Access fields by name, integer position, or slice (returns a new StructType)."""
        if isinstance(key, basestring):
            for field in self:
                if field.name == key:
                    return field
            raise KeyError('No StructField named {0}'.format(key))
        elif isinstance(key, int):
            try:
                return self.fields[key]
            except IndexError:
                raise IndexError('StructType index out of range')
        elif isinstance(key, slice):
            return StructType(self.fields[key])
        else:
            raise TypeError('StructType keys should be strings, integers or slices')

    def simpleString(self):
        return 'struct<%s>' % (','.join(f.simpleString() for f in self))

    def __repr__(self):
        return ("StructType(List(%s))" %
                ",".join(str(field) for field in self))

    def jsonValue(self):
        return {"type": self.typeName(),
                "fields": [f.jsonValue() for f in self]}

    @classmethod
    def fromJson(cls, json):
        return StructType([StructField.fromJson(f) for f in json["fields"]])

    def fieldNames(self):
        """
        Returns all field names in a list.

        >>> struct = StructType([StructField("f1", StringType(), True)])
        >>> struct.fieldNames()
        ['f1']
        """
        return list(self.names)

    def needConversion(self):
        # We need convert Row()/namedtuple into tuple()
        return True

    def toInternal(self, obj):
        """Convert a dict, Row, tuple/list, or object with ``__dict__`` into a
        tuple ordered like this schema's fields. ``None`` is returned as None.

        :raises ValueError: for any other kind of object.
        """
        if obj is None:
            return

        if self._needSerializeAnyField:
            # Only calling toInternal function for fields that need conversion
            if isinstance(obj, dict):
                return tuple(f.toInternal(obj.get(n)) if c else obj.get(n)
                             for n, f, c in zip(self.names, self.fields, self._needConversion))
            elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
                # Read such Rows by field name, as the non-converting branch
                # below does, instead of zipping them positionally.
                return tuple(f.toInternal(obj[n]) if c else obj[n]
                             for n, f, c in zip(self.names, self.fields, self._needConversion))
            elif isinstance(obj, (tuple, list)):
                return tuple(f.toInternal(v) if c else v
                             for f, v, c in zip(self.fields, obj, self._needConversion))
            elif hasattr(obj, "__dict__"):
                d = obj.__dict__
                return tuple(f.toInternal(d.get(n)) if c else d.get(n)
                             for n, f, c in zip(self.names, self.fields, self._needConversion))
            else:
                raise ValueError("Unexpected tuple %r with StructType" % obj)
        else:
            if isinstance(obj, dict):
                return tuple(obj.get(n) for n in self.names)
            elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
                return tuple(obj[n] for n in self.names)
            elif isinstance(obj, (list, tuple)):
                return tuple(obj)
            elif hasattr(obj, "__dict__"):
                d = obj.__dict__
                return tuple(d.get(n) for n in self.names)
            else:
                raise ValueError("Unexpected tuple %r with StructType" % obj)

    def fromInternal(self, obj):
        """Build a Row from an internal tuple; ``None`` and existing Rows pass through."""
        if obj is None:
            return
        if isinstance(obj, Row):
            # it's already converted by pickler
            return obj
        if self._needSerializeAnyField:
            # Only calling fromInternal function for fields that need conversion
            values = [f.fromInternal(v) if c else v
                      for f, v, c in zip(self.fields, obj, self._needConversion)]
        else:
            values = obj
        return _create_row(self.names, values)
0639 
0640 
class UserDefinedType(DataType):
    """User-defined type (UDT).

    .. note:: WARN: Spark Internal Use Only
    """

    @classmethod
    def typeName(cls):
        return cls.__name__.lower()

    @classmethod
    def sqlType(cls):
        """
        Underlying SQL storage type for this UDT.
        """
        raise NotImplementedError("UDT must implement sqlType().")

    @classmethod
    def module(cls):
        """
        The Python module of the UDT.
        """
        raise NotImplementedError("UDT must implement module().")

    @classmethod
    def scalaUDT(cls):
        """
        The class name of the paired Scala UDT (could be '', if there
        is no corresponding one).
        """
        return ''

    def needConversion(self):
        return True

    @classmethod
    def _cachedSqlType(cls):
        """
        Cache the sqlType() into class, because it's heavily used in `toInternal`.
        """
        # Look in this class's own namespace only. ``hasattr`` would also find a
        # value cached on a parent UDT, so a subclass overriding ``sqlType()``
        # would silently reuse its parent's type.
        if "_cached_sql_type" not in cls.__dict__:
            cls._cached_sql_type = cls.sqlType()
        return cls._cached_sql_type

    def toInternal(self, obj):
        """Serialize ``obj`` and convert it with the SQL type; ``None`` gives None."""
        if obj is not None:
            return self._cachedSqlType().toInternal(self.serialize(obj))

    def fromInternal(self, obj):
        """Convert with the SQL type, then deserialize; a None result gives None."""
        v = self._cachedSqlType().fromInternal(obj)
        if v is not None:
            return self.deserialize(v)

    def serialize(self, obj):
        """
        Converts a user-type object into a SQL datum.
        """
        raise NotImplementedError("UDT must implement serialize().")

    def deserialize(self, datum):
        """
        Converts a SQL datum into a user-type object.
        """
        raise NotImplementedError("UDT must implement deserialize().")

    def simpleString(self):
        return 'udt'

    def json(self):
        return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)

    def jsonValue(self):
        if self.scalaUDT():
            assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT'
            schema = {
                "type": "udt",
                "class": self.scalaUDT(),
                "pyClass": "%s.%s" % (self.module(), type(self).__name__),
                "sqlType": self.sqlType().jsonValue()
            }
        else:
            # No Scala counterpart: embed the pickled Python class so fromJson
            # can rebuild it when the module cannot be imported.
            ser = CloudPickleSerializer()
            b = ser.dumps(type(self))
            schema = {
                "type": "udt",
                "pyClass": "%s.%s" % (self.module(), type(self).__name__),
                "serializedClass": base64.b64encode(b).decode('utf8'),
                "sqlType": self.sqlType().jsonValue()
            }
        return schema

    @classmethod
    def fromJson(cls, json):
        pyUDT = str(json["pyClass"])  # convert unicode to str
        split = pyUDT.rfind(".")
        pyModule = pyUDT[:split]
        pyClass = pyUDT[split+1:]
        m = __import__(pyModule, globals(), locals(), [pyClass])
        if not hasattr(m, pyClass):
            # SECURITY NOTE(review): this unpickles a class embedded in the JSON;
            # only feed schema JSON from a trusted source.
            s = base64.b64decode(json['serializedClass'].encode('utf-8'))
            UDT = CloudPickleSerializer().loads(s)
        else:
            UDT = getattr(m, pyClass)
        return UDT()

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        # On Python 3, a class that defines __eq__ without __hash__ becomes
        # unhashable. Hash on the type to stay consistent with __eq__.
        return hash(type(self))
0748 
0749 
# Every atomic DataType class recognized by name when parsing type strings/JSON.
_atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType,
                 ByteType, ShortType, IntegerType, LongType, DateType, TimestampType, NullType]
# typeName() (e.g. "string", "decimal") -> atomic DataType class.
_all_atomic_types = {atomic.typeName(): atomic for atomic in _atomic_types}
# typeName() -> container DataType class.
_all_complex_types = {complex_t.typeName(): complex_t
                      for complex_t in (ArrayType, MapType, StructType)}


# Matches fixed-precision decimal strings such as "decimal(10,2)" (whitespace
# around the numbers allowed, scale may be negative); groups: precision, scale.
_FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
0758 
0759 
def _parse_datatype_string(s):
    """
    Parses the given data type string to a :class:`DataType`. The data type string format equals
    :class:`DataType.simpleString`, except that the top level struct type can omit
    the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead
    of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name
    for :class:`IntegerType`. Since Spark 2.3, this also supports a schema in a DDL-formatted
    string and case-insensitive strings.

    >>> _parse_datatype_string("int ")
    IntegerType
    >>> _parse_datatype_string("INT ")
    IntegerType
    >>> _parse_datatype_string("a: byte, b: decimal(  16 , 8   ) ")
    StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
    >>> _parse_datatype_string("a DOUBLE, b STRING")
    StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true)))
    >>> _parse_datatype_string("a: array< short>")
    StructType(List(StructField(a,ArrayType(ShortType,true),true)))
    >>> _parse_datatype_string(" map<string , string > ")
    MapType(StringType,StringType,true)

    >>> # Error cases
    >>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ParseException:...
    >>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ParseException:...
    >>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ParseException:...
    >>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ParseException:...
    """
    # NOTE(review): with no active SparkContext, ``sc`` is None and every
    # attempt below fails with AttributeError; the first one is re-raised.
    sc = SparkContext._active_spark_context

    def from_ddl_schema(type_str):
        return _parse_datatype_json_string(
            sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json())

    def from_ddl_datatype(type_str):
        return _parse_datatype_json_string(
            sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json())

    # Three formats are tried in order. If all fail, the error from the first
    # (DDL schema) attempt is re-raised. The fallbacks catch ``Exception`` rather
    # than using a bare ``except:``, so KeyboardInterrupt/SystemExit propagate
    # instead of being replaced by ``e``.
    try:
        # DDL format, "fieldname datatype, fieldname datatype".
        return from_ddl_schema(s)
    except Exception as e:
        try:
            # For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc.
            return from_ddl_datatype(s)
        except Exception:
            try:
                # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
                return from_ddl_datatype("struct<%s>" % s.strip())
            except Exception:
                raise e
0823 
0824 
def _parse_datatype_json_string(json_string):
    """Parses the given data type JSON string.
    >>> import pickle
    >>> def check_datatype(datatype):
    ...     pickled = pickle.loads(pickle.dumps(datatype))
    ...     assert datatype == pickled
    ...     scala_datatype = spark._jsparkSession.parseDataType(datatype.json())
    ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
    ...     assert datatype == python_datatype
    >>> for cls in _all_atomic_types.values():
    ...     check_datatype(cls())

    >>> # Simple ArrayType.
    >>> simple_arraytype = ArrayType(StringType(), True)
    >>> check_datatype(simple_arraytype)

    >>> # Simple MapType.
    >>> simple_maptype = MapType(StringType(), LongType())
    >>> check_datatype(simple_maptype)

    >>> # Simple StructType.
    >>> simple_structtype = StructType([
    ...     StructField("a", DecimalType(), False),
    ...     StructField("b", BooleanType(), True),
    ...     StructField("c", LongType(), True),
    ...     StructField("d", BinaryType(), False)])
    >>> check_datatype(simple_structtype)

    >>> # Complex StructType.
    >>> complex_structtype = StructType([
    ...     StructField("simpleArray", simple_arraytype, True),
    ...     StructField("simpleMap", simple_maptype, True),
    ...     StructField("simpleStruct", simple_structtype, True),
    ...     StructField("boolean", BooleanType(), False),
    ...     StructField("withMeta", DoubleType(), False, {"name": "age"})])
    >>> check_datatype(complex_structtype)

    >>> # Complex ArrayType.
    >>> complex_arraytype = ArrayType(complex_structtype, True)
    >>> check_datatype(complex_arraytype)

    >>> # Complex MapType.
    >>> complex_maptype = MapType(complex_structtype,
    ...                           complex_arraytype, False)
    >>> check_datatype(complex_maptype)
    """
    # Decode the JSON text first, then build the DataType from the decoded value.
    decoded = json.loads(json_string)
    return _parse_datatype_json_value(decoded)
0872 
0873 
def _parse_datatype_json_value(json_value):
    """Build a DataType from an already-decoded JSON value.

    A dict is treated as a complex type (or a UDT) and dispatched on its ``"type"`` entry.
    Any other value is treated as the name of an atomic type. This covers the plain
    ``'decimal'`` name and strings matching ``_FIXED_DECIMAL``, whose two groups become the
    DecimalType arguments.

    :raises ValueError: if the value names no known type.
    """
    if isinstance(json_value, dict):
        type_name = json_value["type"]
        if type_name in _all_complex_types:
            return _all_complex_types[type_name].fromJson(json_value)
        if type_name == 'udt':
            return UserDefinedType.fromJson(json_value)
        raise ValueError("not supported type: %s" % type_name)

    if json_value in _all_atomic_types.keys():
        return _all_atomic_types[json_value]()
    if json_value == 'decimal':
        return DecimalType()
    # Match once and reuse the groups instead of running the regex twice.
    fixed = _FIXED_DECIMAL.match(json_value)
    if fixed:
        return DecimalType(int(fixed.group(1)), int(fixed.group(2)))
    raise ValueError("Could not parse datatype: %s" % json_value)
0893 
0894 
# Mapping from Python types to Spark SQL DataType classes, consulted by _infer_type.
# Lookups are by the exact type(obj), so subclasses of these Python types do not match.
_type_mappings = dict([
    (type(None), NullType),
    (bool, BooleanType),
    (int, LongType),
    (float, DoubleType),
    (str, StringType),
    (bytearray, BinaryType),
    (decimal.Decimal, DecimalType),
    (datetime.date, DateType),
    (datetime.datetime, TimestampType),
    (datetime.time, TimestampType),
])

if sys.version < "3":
    # Python 2: ``unicode`` text and ``long`` integers are separate builtin types.
    _type_mappings[unicode] = StringType
    _type_mappings[long] = LongType
else:
    # Python 3: ``bytes`` is distinct from ``str``.
    _type_mappings[bytes] = BinaryType
0919 
0920 # Mapping Python array types to Spark SQL DataType
# We should be careful here. The sizes of these types in Python depend on the C
# implementation, and we need to make sure that this conversion does not lose any
# precision. Also, the JVM only supports signed types, so when converting unsigned
# types, keep in mind that they require 1 more bit when stored as signed types.
0925 #
0926 # Reference for C integer size, see:
0927 # ISO/IEC 9899:201x specification, chapter 5.2.4.2.1 Sizes of integer types <limits.h>.
0928 # Reference for python array typecode, see:
0929 # https://docs.python.org/2/library/array.html
0930 # https://docs.python.org/3.6/library/array.html
0931 # Reference for JVM's supported integral types:
0932 # http://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.3.1
0933 
# ``array.array`` signed integer typecode -> ctypes type with the same C width.
_array_signed_int_typecode_ctype_mappings = dict([
    ('b', ctypes.c_byte),
    ('h', ctypes.c_short),
    ('i', ctypes.c_int),
    ('l', ctypes.c_long),
])

# ``array.array`` unsigned integer typecode -> ctypes type with the same C width.
_array_unsigned_int_typecode_ctype_mappings = dict([
    ('B', ctypes.c_ubyte),
    ('H', ctypes.c_ushort),
    ('I', ctypes.c_uint),
    ('L', ctypes.c_ulong),
])
0947 
0948 
def _int_size_to_type(size):
    """
    Return the smallest Catalyst integral DataType class whose width is at least ``size``
    bits (8/16/32/64 -> Byte/Short/Integer/LongType), or None if ``size`` exceeds 64.
    """
    for max_bits, int_type in ((8, ByteType), (16, ShortType), (32, IntegerType), (64, LongType)):
        if size <= max_bits:
            return int_type
    return None
0961 
# Mapping from ``array.array`` typecode to the Spark SQL DataType class used as the
# element type when _infer_type sees an ``array`` object. The float typecodes are
# fixed here; the integer typecodes are added below from the platform's C type sizes.
_array_type_mappings = {
    # Warning: Actual properties for float and double in C is not specified in C.
    # On almost every system supported by both python and JVM, they are IEEE 754
    # single-precision binary floating-point format and IEEE 754 double-precision
    # binary floating-point format. And we do assume the same thing here for now.
    'f': FloatType,
    'd': DoubleType
}

# compute array typecode mappings for signed integer types
# (a typecode is only mapped when _int_size_to_type finds a type, i.e. its width is <= 64 bits).
# NOTE(review): these are module-level loops, so _typecode, size and dt remain bound as
# module globals afterwards.
for _typecode in _array_signed_int_typecode_ctype_mappings.keys():
    size = ctypes.sizeof(_array_signed_int_typecode_ctype_mappings[_typecode]) * 8
    dt = _int_size_to_type(size)
    if dt is not None:
        _array_type_mappings[_typecode] = dt

# compute array typecode mappings for unsigned integer types
for _typecode in _array_unsigned_int_typecode_ctype_mappings.keys():
    # The JVM has no unsigned types, so use a signed type that is at least 1 bit
    # larger to store the value (e.g. a 64-bit unsigned long needs 65 bits and is
    # therefore left unmapped).
    size = ctypes.sizeof(_array_unsigned_int_typecode_ctype_mappings[_typecode]) * 8 + 1
    dt = _int_size_to_type(size)
    if dt is not None:
        _array_type_mappings[_typecode] = dt

# Type code 'u' in Python's array has been deprecated since version 3.3 and will be
# removed in version 4.0. See: https://docs.python.org/3/library/array.html
if sys.version_info[0] < 4:
    _array_type_mappings['u'] = StringType

# Type code 'c' is only available in Python 2
if sys.version_info[0] < 3:
    _array_type_mappings['c'] = StringType

# SPARK-21465:
# In Python 2, arrays with typecode 'L' were mistakenly, and only partially, supported.
# To avoid breaking user code, that partial support is kept: on CPython 2, if 'L' was
# not mapped above (its C unsigned long plus a sign bit does not fit in 64 bits), it is
# mapped to LongType and its ctype entry is replaced with c_uint. This is a deliberate
# hack to preserve the old behavior and keep the unit test passing.
# NOTE(review): the c_uint override presumably affects code elsewhere that reads
# _array_unsigned_int_typecode_ctype_mappings; confirm with its callers.
import platform
if sys.version_info[0] < 3 and platform.python_implementation() != 'PyPy':
    if 'L' not in _array_type_mappings.keys():
        _array_type_mappings['L'] = LongType
        _array_unsigned_int_typecode_ctype_mappings['L'] = ctypes.c_uint
1006 
1007 
def _infer_type(obj):
    """Infer the DataType from obj.

    - ``None`` -> NullType; objects with a ``__UDT__`` attribute -> that UDT.
    - Python scalars are looked up by exact type in ``_type_mappings``. Decimals always
      become ``DecimalType(38, 18)`` because their precision/scale may vary between rows.
    - dict -> MapType inferred from the first entry whose key and value are both non-None.
    - list -> ArrayType inferred from the first non-None element.
    - ``array.array`` -> ArrayType of the type mapped from its typecode (elements
      non-nullable).
    - anything else is inferred as a struct via ``_infer_schema``.

    :raises TypeError: if the object's type is not supported.
    """
    if obj is None:
        return NullType()

    if hasattr(obj, '__UDT__'):
        return obj.__UDT__

    dataType = _type_mappings.get(type(obj))
    if dataType is DecimalType:
        # the precision and scale of `obj` may be different from row to row.
        return DecimalType(38, 18)
    elif dataType is not None:
        return dataType()

    if isinstance(obj, dict):
        for key, value in obj.items():
            if key is not None and value is not None:
                return MapType(_infer_type(key), _infer_type(value), True)
        return MapType(NullType(), NullType(), True)
    elif isinstance(obj, list):
        for v in obj:
            if v is not None:
                # Infer from the first non-None element itself. Using obj[0] here would
                # yield NullType for lists such as [None, 1].
                return ArrayType(_infer_type(v), True)
        return ArrayType(NullType(), True)
    elif isinstance(obj, array):
        if obj.typecode in _array_type_mappings:
            return ArrayType(_array_type_mappings[obj.typecode](), False)
        else:
            raise TypeError("not supported type: array(%s)" % obj.typecode)
    else:
        try:
            return _infer_schema(obj)
        except TypeError:
            raise TypeError("not supported type: %s" % type(obj))
1044 
1045 
def _infer_schema(row, names=None):
    """Infer the schema from dict/namedtuple/object.

    :param row: a dict (fields sorted by key), a Row or namedtuple (fields taken from
        ``__fields__`` / ``_fields``), a plain tuple/list (fields named from ``names``), or
        an object with a ``__dict__`` (fields sorted by attribute name).
    :param names: optional column names for a plain tuple/list row. If there are fewer
        names than values, ``_<position>`` names are generated for the rest; extra names
        are ignored. The caller's sequence is never modified, and tuples are accepted.
    :return: a StructType with one nullable field per item, typed via ``_infer_type``.
    :raises TypeError: if the row's type is not supported.
    """
    if isinstance(row, dict):
        items = sorted(row.items())

    elif isinstance(row, (tuple, list)):
        if hasattr(row, "__fields__"):  # Row
            items = zip(row.__fields__, tuple(row))
        elif hasattr(row, "_fields"):  # namedtuple
            items = zip(row._fields, tuple(row))
        else:
            if names is None:
                names = ['_%d' % i for i in range(1, len(row) + 1)]
            elif len(names) < len(row):
                # Build a new list instead of extending in place. In-place extension mutated
                # the caller's sequence and failed on tuples, which have no extend().
                names = list(names) + ['_%d' % i for i in range(len(names) + 1, len(row) + 1)]
            items = zip(names, row)

    elif hasattr(row, "__dict__"):  # object
        items = sorted(row.__dict__.items())

    else:
        raise TypeError("Can not infer schema for type: %s" % type(row))

    fields = [StructField(k, _infer_type(v), True) for k, v in items]
    return StructType(fields)
1071 
1072 
def _has_nulltype(dt):
    """Return True if a NullType occurs anywhere in ``dt``.

    Struct fields, array elements and map keys/values are searched recursively.
    """
    if isinstance(dt, StructType):
        children = [f.dataType for f in dt.fields]
    elif isinstance(dt, ArrayType):
        children = [dt.elementType]
    elif isinstance(dt, MapType):
        children = [dt.keyType, dt.valueType]
    else:
        return isinstance(dt, NullType)
    return any(_has_nulltype(child) for child in children)
1083 
1084 
def _merge_type(a, b, name=None):
    """Merge two inferred DataTypes into a single compatible one.

    NullType yields to the other side. Otherwise both sides must have the same class.
    Structs are merged field by field, and fields present only in ``b`` are appended.
    Arrays and maps are merged element-, key- and value-wise, with a nullable result.
    Any other type returns ``a``.

    :param name: optional label used to prefix error messages and to name nested merges.
    :raises TypeError: if the two types have different classes.
    """
    def new_msg(msg):
        if name is None:
            return msg
        return "%s: %s" % (name, msg)

    def new_name(field_name):
        if name is None:
            return "field %s" % field_name
        return "field %s in %s" % (field_name, name)

    if isinstance(a, NullType):
        return b
    if isinstance(b, NullType):
        return a
    if type(a) is not type(b):
        # TODO: type cast (such as int -> long)
        raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b))))

    if isinstance(a, StructType):
        b_types = dict((f.name, f.dataType) for f in b.fields)
        merged_fields = []
        for field in a.fields:
            other = b_types.get(field.name, NullType())
            merged_fields.append(
                StructField(field.name, _merge_type(field.dataType, other,
                                                    name=new_name(field.name))))
        seen = set(f.name for f in merged_fields)
        merged_fields.extend(StructField(n, b_types[n]) for n in b_types if n not in seen)
        return StructType(merged_fields)

    if isinstance(a, ArrayType):
        element = _merge_type(a.elementType, b.elementType,
                              name='element in array %s' % name)
        return ArrayType(element, True)

    if isinstance(a, MapType):
        key = _merge_type(a.keyType, b.keyType, name='key of map %s' % name)
        value = _merge_type(a.valueType, b.valueType, name='value of map %s' % name)
        return MapType(key, value, True)

    return a
1123 
1124 
def _need_converter(dataType):
    """Return True if values of ``dataType`` must pass through ``_create_converter``.

    That is the case when the type is, or contains, a StructType or a NullType.
    """
    if isinstance(dataType, (StructType, NullType)):
        return True
    if isinstance(dataType, ArrayType):
        return _need_converter(dataType.elementType)
    if isinstance(dataType, MapType):
        return _need_converter(dataType.keyType) or _need_converter(dataType.valueType)
    return False
1136 
1137 
def _create_converter(dataType):
    """Build a function that converts a value of ``dataType`` into positional form.

    Struct values are reordered like ``dataType.fields`` and become plain tuples, dropping
    field names. They may be given as tuples/lists, dicts, or objects with a ``__dict__``;
    missing dict keys become None. Arrays and maps are converted element-wise, and NullType
    values become None. If nothing in ``dataType`` needs conversion, the identity function
    is returned.
    """
    if not _need_converter(dataType):
        return lambda x: x

    if isinstance(dataType, ArrayType):
        element_conv = _create_converter(dataType.elementType)

        def convert_array(values):
            return [element_conv(v) for v in values]
        return convert_array

    if isinstance(dataType, MapType):
        key_conv = _create_converter(dataType.keyType)
        value_conv = _create_converter(dataType.valueType)

        def convert_map(mapping):
            return dict((key_conv(k), value_conv(v)) for k, v in mapping.items())
        return convert_map

    if isinstance(dataType, NullType):
        return lambda x: None

    if not isinstance(dataType, StructType):
        return lambda x: x

    # From here on dataType is a StructType.
    field_names = [f.name for f in dataType.fields]
    field_convs = [_create_converter(f.dataType) for f in dataType.fields]
    any_nested = any(_need_converter(f.dataType) for f in dataType.fields)

    def convert_struct(obj):
        if obj is None:
            return None

        if isinstance(obj, (tuple, list)):
            if not any_nested:
                return tuple(obj)
            return tuple(conv(v) for v, conv in zip(obj, field_convs))

        if isinstance(obj, dict):
            source = obj
        elif hasattr(obj, "__dict__"):  # object
            source = obj.__dict__
        else:
            raise TypeError("Unexpected obj type: %s" % type(obj))

        if not any_nested:
            return tuple([source.get(n) for n in field_names])
        return tuple([conv(source.get(n)) for n, conv in zip(field_names, field_convs)])

    return convert_struct
1186 
1187 
# Python types that _make_type_verifier accepts for each Spark SQL DataType class.
_acceptable_types = {
    BooleanType: (bool,),
    FloatType: (float,),
    DoubleType: (float,),
    DecimalType: (decimal.Decimal,),
    StringType: (str, unicode),
    BinaryType: (bytearray, bytes),
    DateType: (datetime.date, datetime.datetime),
    TimestampType: (datetime.datetime,),
    ArrayType: (list, tuple, array),
    MapType: (dict,),
    StructType: (tuple, list, dict),
}
# Every integral Spark type accepts Python int (and long on Python 2).
_acceptable_types.update(
    (_integral_type, (int, long))
    for _integral_type in (ByteType, ShortType, IntegerType, LongType))
1205 
1206 
def _make_type_verifier(dataType, nullable=True, name=None):
    """
    Make a verifier that checks the type of obj against dataType and raises a TypeError if they do
    not match.

    This verifier also checks the value of obj against datatype and raises a ValueError if it's not
    within the allowed range, e.g. using 128 as ByteType will overflow. Note that, Python float is
    not checked, so it will become infinity when cast to Java float, if it overflows.

    :param dataType: the DataType instance that values are verified against.
    :param nullable: whether ``None`` is an accepted value; if False, ``None`` raises a
        ValueError.
    :param name: optional label for the value being verified. When given, it prefixes every
        error message. The verifiers built for nested struct fields, array elements and map
        keys/values derive their own labels from it.
    :return: a one-argument function that returns None when the value is acceptable and
        raises otherwise.

    >>> _make_type_verifier(StructType([]))(None)
    >>> _make_type_verifier(StringType())("")
    >>> _make_type_verifier(LongType())(0)
    >>> _make_type_verifier(LongType())(1 << 64) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _make_type_verifier(ArrayType(ShortType()))(list(range(3)))
    >>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    TypeError:...
    >>> _make_type_verifier(MapType(StringType(), IntegerType()))({})
    >>> _make_type_verifier(StructType([]))(())
    >>> _make_type_verifier(StructType([]))([])
    >>> _make_type_verifier(StructType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> # Check if numeric values are within the allowed range.
    >>> _make_type_verifier(ByteType())(12)
    >>> _make_type_verifier(ByteType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _make_type_verifier(ByteType(), False)(None) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _make_type_verifier(
    ...     ArrayType(ShortType(), False))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _make_type_verifier(MapType(StringType(), IntegerType()))({None: 1})
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False)
    >>> _make_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    """

    # new_msg prefixes error messages with ``name``; new_name builds the label passed to
    # the verifier of a nested struct field.
    if name is None:
        new_msg = lambda msg: msg
        new_name = lambda n: "field %s" % n
    else:
        new_msg = lambda msg: "%s: %s" % (name, msg)
        new_name = lambda n: "field %s in %s" % (n, name)

    def verify_nullability(obj):
        # Returns True when obj is None and None is allowed (nothing left to check),
        # raises ValueError when obj is None but not allowed, and returns False otherwise.
        if obj is None:
            if nullable:
                return True
            else:
                raise ValueError(new_msg("This field is not nullable, but got None"))
        else:
            return False

    # Exact class of dataType; used as the key into _acceptable_types.
    _type = type(dataType)

    def assert_acceptable_types(obj):
        # NOTE(review): this is an ``assert``, so the check is skipped when Python runs
        # with -O.
        assert _type in _acceptable_types, \
            new_msg("unknown datatype: %s for object %r" % (dataType, obj))

    def verify_acceptable_types(obj):
        # subclass of them can not be fromInternal in JVM
        # (hence the exact ``type(obj)`` comparison instead of isinstance)
        if type(obj) not in _acceptable_types[_type]:
            raise TypeError(new_msg("%s can not accept object %r in type %s"
                                    % (dataType, obj, type(obj))))

    if isinstance(dataType, StringType):
        # StringType can work with any types
        verify_value = lambda _: _

    elif isinstance(dataType, UserDefinedType):
        # Verify the UDT object itself, then its internal representation against sqlType().
        verifier = _make_type_verifier(dataType.sqlType(), name=name)

        # NOTE(review): named "verify_udf", but it verifies a UserDefinedType value.
        def verify_udf(obj):
            if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
                raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType)))
            verifier(dataType.toInternal(obj))

        verify_value = verify_udf

    # Integral types: exact-type check, then a signed range check for the type's width
    # (8, 16, 32 and 64 bits respectively).
    elif isinstance(dataType, ByteType):
        def verify_byte(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            if obj < -128 or obj > 127:
                raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj))

        verify_value = verify_byte

    elif isinstance(dataType, ShortType):
        def verify_short(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            if obj < -32768 or obj > 32767:
                raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj))

        verify_value = verify_short

    elif isinstance(dataType, IntegerType):
        def verify_integer(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            if obj < -2147483648 or obj > 2147483647:
                raise ValueError(
                    new_msg("object of IntegerType out of range, got: %s" % obj))

        verify_value = verify_integer

    elif isinstance(dataType, LongType):
        def verify_long(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            if obj < -9223372036854775808 or obj > 9223372036854775807:
                raise ValueError(
                    new_msg("object of LongType out of range, got: %s" % obj))

        verify_value = verify_long

    elif isinstance(dataType, ArrayType):
        # Elements may be None only when dataType.containsNull is set.
        element_verifier = _make_type_verifier(
            dataType.elementType, dataType.containsNull, name="element in array %s" % name)

        def verify_array(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            for i in obj:
                element_verifier(i)

        verify_value = verify_array

    elif isinstance(dataType, MapType):
        # Map keys are never nullable; values follow dataType.valueContainsNull.
        key_verifier = _make_type_verifier(dataType.keyType, False, name="key of map %s" % name)
        value_verifier = _make_type_verifier(
            dataType.valueType, dataType.valueContainsNull, name="value of map %s" % name)

        def verify_map(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            for k, v in obj.items():
                key_verifier(k)
                value_verifier(v)

        verify_value = verify_map

    elif isinstance(dataType, StructType):
        # One (field name, verifier) pair per field, in schema order.
        verifiers = []
        for f in dataType.fields:
            verifier = _make_type_verifier(f.dataType, f.nullable, name=new_name(f.name))
            verifiers.append((f.name, verifier))

        def verify_struct(obj):
            # No exact-type check here (only assert_acceptable_types), so tuple/list
            # subclasses such as Row or namedtuples are accepted.
            assert_acceptable_types(obj)

            if isinstance(obj, dict):
                # Fields looked up by name; missing keys are verified as None.
                for f, verifier in verifiers:
                    verifier(obj.get(f))
            elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
                # the order in obj could be different than dataType.fields
                for f, verifier in verifiers:
                    verifier(obj[f])
            elif isinstance(obj, (tuple, list)):
                # Positional: the number of values must match the number of fields exactly.
                if len(obj) != len(verifiers):
                    raise ValueError(
                        new_msg("Length of object (%d) does not match with "
                                "length of fields (%d)" % (len(obj), len(verifiers))))
                for v, (_, verifier) in zip(obj, verifiers):
                    verifier(v)
            elif hasattr(obj, "__dict__"):
                # Arbitrary object: fields read from its attribute dict.
                d = obj.__dict__
                for f, verifier in verifiers:
                    verifier(d.get(f))
            else:
                raise TypeError(new_msg("StructType can not accept object %r in type %s"
                                        % (obj, type(obj))))
        verify_value = verify_struct

    else:
        # Remaining types (e.g. BooleanType, FloatType, DoubleType, DecimalType, BinaryType,
        # DateType, TimestampType): exact-type check only, no range check.
        def verify_default(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)

        verify_value = verify_default

    def verify(obj):
        # Value checks run only for non-None objects; None is handled by verify_nullability.
        if not verify_nullability(obj):
            verify_value(obj)

    return verify
1411 
1412 
# This is used to unpickle a Row from JVM
def _create_row_inbound_converter(dataType):
    """Return a callable that takes a row's values as positional arguments and passes
    them, as one tuple, to ``dataType.fromInternal``."""
    def _convert(*values):
        return dataType.fromInternal(values)
    return _convert
1416 
1417 
def _create_row(fields, values):
    """Build a Row holding ``values`` and attach ``fields`` as its field names."""
    new_row = Row(*values)
    new_row.__fields__ = fields
    return new_row
1422 
1423 
class Row(tuple):

    """
    A row in :class:`DataFrame`.
    Its fields can be read:

    * as attributes (``row.key``)
    * as dictionary-style items (``row[key]``)

    ``key in row`` checks whether ``key`` is one of the row's field names.

    A Row object can be built from named arguments. A missing value cannot be
    expressed by leaving an argument out; pass ``None`` explicitly instead.

    NOTE: Since Spark 3.0.0, Rows created from named arguments keep their fields
    in the order the arguments were given, instead of sorting them
    alphabetically. Setting the environment variable
    "PYSPARK_ROW_FIELD_SORTING_ENABLED" to "true" restores the Spark 2.x
    sorting. This option is deprecated and will be removed in a future Spark
    version. On Python < 3.6 the order of named arguments is not guaranteed to
    match the call site (see https://www.python.org/dev/peps/pep-0468). In that
    case a warning is issued and the field names are sorted automatically.

    NOTE: The examples below are run with "PYSPARK_ROW_FIELD_SORTING_ENABLED"
    set to "true", so their output shows sorted fields.

    >>> row = Row(name="Alice", age=11)
    >>> row
    Row(age=11, name='Alice')
    >>> row['name'], row['age']
    ('Alice', 11)
    >>> row.name, row.age
    ('Alice', 11)
    >>> 'name' in row
    True
    >>> 'wrong_key' in row
    False

    A Row can also act as a "Row class": build one from field names, then call
    it to create Row objects with those fields, such as

    >>> Person = Row("name", "age")
    >>> Person
    <Row('name', 'age')>
    >>> 'name' in Person
    True
    >>> 'wrong_key' in Person
    False
    >>> Person("Alice", 11)
    Row(name='Alice', age=11)

    The same positional form also creates rows as plain tuple values, i.e. with
    unnamed fields. Such Row objects compare by value only:

    >>> row1 = Row("Alice", 11)
    >>> row2 = Row(name="Alice", age=11)
    >>> row1 == row2
    False
    >>> row3 = Row(a="Alice", b=11)
    >>> row1 == row3
    True
    """

    # Remove after Python < 3.6 dropped, see SPARK-29748
    _row_field_sorting_enabled = (
        os.environ.get('PYSPARK_ROW_FIELD_SORTING_ENABLED', 'false').lower() == 'true')

    if _row_field_sorting_enabled:
        warnings.warn("The environment variable 'PYSPARK_ROW_FIELD_SORTING_ENABLED' "
                      "is deprecated and will be removed in future versions of Spark")

    def __new__(cls, *args, **kwargs):
        if args and kwargs:
            raise ValueError("Can not use both args "
                             "and kwargs to create Row")
        if not kwargs:
            # Positional form: a "Row class" of field names, or a row of unnamed values.
            return tuple.__new__(cls, args)

        if not Row._row_field_sorting_enabled and sys.version_info[:2] < (3, 6):
            # Keyword order is not preserved before Python 3.6 (PEP 468), so sorting is
            # switched on at class level for every Row created afterwards.
            warnings.warn("To use named arguments for Python version < 3.6, Row fields will be "
                          "automatically sorted. This warning can be skipped by setting the "
                          "environment variable 'PYSPARK_ROW_FIELD_SORTING_ENABLED' to 'true'.")
            Row._row_field_sorting_enabled = True

        # Remove the sorted variant after Python < 3.6 dropped, see SPARK-29748
        sort_fields = Row._row_field_sorting_enabled
        field_names = sorted(kwargs.keys()) if sort_fields else list(kwargs.keys())
        new_row = tuple.__new__(cls, [kwargs[n] for n in field_names])
        new_row.__fields__ = field_names
        if sort_fields:
            new_row.__from_dict__ = True
        return new_row

    def asDict(self, recursive=False):
        """
        Return the row as a dict that maps field names to values.

        :param recursive: if True, nested Rows, including Rows inside lists and dict
            values, are converted to dicts as well (default: False).

        .. note:: If a row contains duplicate field names, e.g., the rows of a join
            between two :class:`DataFrame` that both have the fields of same names,
            one of the duplicate fields will be selected by ``asDict``. ``__getitem__``
            will also return one of the duplicate fields, however returned value might
            be different to ``asDict``.

        >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}
        True
        >>> row = Row(key=1, value=Row(name='a', age=2))
        >>> row.asDict() == {'key': 1, 'value': Row(age=2, name='a')}
        True
        >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
        True
        """
        if not hasattr(self, "__fields__"):
            raise TypeError("Cannot convert a Row class into dict")

        if not recursive:
            return dict(zip(self.__fields__, self))

        def _to_plain(value):
            if isinstance(value, Row):
                return value.asDict(True)
            if isinstance(value, list):
                return [_to_plain(item) for item in value]
            if isinstance(value, dict):
                return dict((k, _to_plain(v)) for k, v in value.items())
            return value

        return dict(zip(self.__fields__, [_to_plain(v) for v in self]))

    def __contains__(self, item):
        if not hasattr(self, "__fields__"):
            return super(Row, self).__contains__(item)
        return item in self.__fields__

    # let object acts like class
    def __call__(self, *args):
        """create new Row object"""
        if len(args) <= len(self):
            return _create_row(self, args)
        raise ValueError("Can not create Row with fields %s, expected %d values "
                         "but got %s" % (self, len(self), args))

    def __getitem__(self, item):
        if isinstance(item, (int, slice)):
            return super(Row, self).__getitem__(item)
        # Name lookup is a linear scan over the field names; it is slow for wide rows,
        # but it is not used in normal cases.
        try:
            position = self.__fields__.index(item)
        except ValueError:
            raise ValueError(item)
        try:
            return super(Row, self).__getitem__(position)
        except IndexError:
            raise KeyError(item)

    def __getattr__(self, item):
        if item.startswith("__"):
            raise AttributeError(item)
        # Linear scan over the field names, as in __getitem__.
        try:
            return self[self.__fields__.index(item)]
        except (IndexError, ValueError):
            raise AttributeError(item)

    def __setattr__(self, key, value):
        if key not in ('__fields__', "__from_dict__"):
            raise Exception("Row is read-only")
        self.__dict__[key] = value

    def __reduce__(self):
        """Returns a tuple so Python knows how to pickle Row."""
        if not hasattr(self, "__fields__"):
            return tuple.__reduce__(self)
        return (_create_row, (self.__fields__, tuple(self)))

    def __repr__(self):
        """Printable representation of Row used in Python REPL."""
        if not hasattr(self, "__fields__"):
            return "<Row(%s)>" % ", ".join("%r" % field for field in self)
        pairs = ["%s=%r" % (k, v) for k, v in zip(self.__fields__, tuple(self))]
        return "Row(%s)" % ", ".join(pairs)
1621 
1622 
class DateConverter(object):
    """py4j input converter that turns Python ``datetime.date`` values into
    ``java.sql.Date`` objects."""

    def can_convert(self, obj):
        """Return True for ``datetime.date`` instances.

        This also matches ``datetime.datetime``, which subclasses ``date``.
        """
        return isinstance(obj, datetime.date)

    @staticmethod
    def _to_date_string(obj):
        """Format *obj* as the ``yyyy-mm-dd`` text passed to ``Date.valueOf``.

        ``strftime`` is deliberately avoided. Python 2's ``strftime`` rejects
        years before 1900. On some platforms ``%Y`` is not zero-padded to four
        digits for years before 1000, which produces a string that the Java
        side cannot parse.
        """
        return "%04d-%02d-%02d" % (obj.year, obj.month, obj.day)

    def convert(self, obj, gateway_client):
        """Build a ``java.sql.Date`` for *obj* through *gateway_client*.

        Only year, month and day are transferred.
        """
        Date = JavaClass("java.sql.Date", gateway_client)
        return Date.valueOf(self._to_date_string(obj))
1630 
1631 
class DatetimeConverter(object):
    """py4j input converter that turns Python ``datetime.datetime`` values into
    ``java.sql.Timestamp`` objects."""

    def can_convert(self, obj):
        """Return True for ``datetime.datetime`` instances."""
        return isinstance(obj, datetime.datetime)

    @staticmethod
    def _epoch_seconds(obj):
        """Return the seconds since the Unix epoch for *obj*, ignoring microseconds.

        Aware datetimes are converted through UTC. Naive ones are interpreted
        in the local time zone via ``time.mktime``.

        Awareness is decided by ``utcoffset()``, not by the mere presence of
        ``tzinfo``. A datetime whose tzinfo returns ``None`` from
        ``utcoffset()`` is naive, and ``utctimetuple()`` would otherwise treat
        its wall-clock time as UTC.
        """
        if obj.utcoffset() is not None:
            return calendar.timegm(obj.utctimetuple())
        return time.mktime(obj.timetuple())

    def convert(self, obj, gateway_client):
        """Build a ``java.sql.Timestamp`` for *obj* through *gateway_client*."""
        Timestamp = JavaClass("java.sql.Timestamp", gateway_client)
        # The constructor receives whole seconds scaled to milliseconds.
        # The sub-second part is then set separately through setNanos, so
        # microsecond precision is kept.
        t = Timestamp(int(self._epoch_seconds(obj)) * 1000)
        t.setNanos(obj.microsecond * 1000)
        return t
1643 
# Register the py4j input converters for Python date/time values.
# DatetimeConverter must be registered before DateConverter. datetime.datetime
# is a subclass of datetime.date, so DateConverter.can_convert also accepts
# datetime values, and DateConverter only transfers year/month/day.
# NOTE(review): this relies on py4j trying input converters in registration
# order, as the original comment implies; confirm against py4j.
register_input_converter(DatetimeConverter())
register_input_converter(DateConverter())
1647 
1648 
def _test():
    """Run this module's doctests against a local SparkContext and SparkSession.

    The doctests get ``sc`` and ``spark`` in their globals. The process exits
    with status -1 if any doctest fails.
    """
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import SparkSession
    # Work on a copy so the doctest fixtures (`sc`, `spark`) are not injected
    # into this module's real namespace.
    globs = globals().copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['sc'] = sc
    try:
        globs['spark'] = SparkSession.builder.getOrCreate()
        failure_count, _ = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Stop the context even if session creation or the doctest run raises.
        sc.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()