#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import sys
import decimal
import time
import datetime
import calendar
import json
import re
import base64
from array import array
import ctypes
import warnings

if sys.version >= "3":
    long = int
    basestring = unicode = str

from py4j.protocol import register_input_converter
from py4j.java_gateway import JavaClass

from pyspark import SparkContext
from pyspark.serializers import CloudPickleSerializer

__all__ = [
    "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
    "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
    "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"]


class DataType(object):
    """Base class for data types."""

    def __repr__(self):
        return self.__class__.__name__

    def __hash__(self):
        return hash(str(self))

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    @classmethod
    def typeName(cls):
        return cls.__name__[:-4].lower()

    def simpleString(self):
        return self.typeName()

    def jsonValue(self):
        return self.typeName()

    def json(self):
        return json.dumps(self.jsonValue(),
                          separators=(',', ':'),
                          sort_keys=True)

    def needConversion(self):
        """
        Does this type need conversion between Python objects and internal SQL objects?

        This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
        """
        return False

    def toInternal(self, obj):
        """
        Converts a Python object into an internal SQL object.
        """
        return obj

    def fromInternal(self, obj):
        """
        Converts an internal SQL object into a native Python object.
        """
        return obj


class DataTypeSingleton(type):
    """Metaclass for DataType"""

    _instances = {}

    def __call__(cls):
        if cls not in cls._instances:
            cls._instances[cls] = super(DataTypeSingleton, cls).__call__()
        return cls._instances[cls]


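# Clarifying note (added; not in the original source): the classes below opt in
# to DataTypeSingleton via the `__metaclass__` class attribute, which only has
# an effect on Python 2. Python 3 ignores that attribute, so there each call
# such as StringType() builds a fresh instance; instances still compare equal
# because DataType.__eq__ compares the class and __dict__:
#
#   >>> StringType() == StringType()
#   True

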
class NullType(DataType):
    """Null type.

    The data type representing None, used for the types that cannot be inferred.
    """

    __metaclass__ = DataTypeSingleton


class AtomicType(DataType):
    """An internal type used to represent everything that is not
    null, UDTs, arrays, structs, and maps."""


class NumericType(AtomicType):
    """Numeric data types.
    """


class IntegralType(NumericType):
    """Integral data types.
    """

    __metaclass__ = DataTypeSingleton


class FractionalType(NumericType):
    """Fractional data types.
    """


class StringType(AtomicType):
    """String data type.
    """

    __metaclass__ = DataTypeSingleton


class BinaryType(AtomicType):
    """Binary (byte array) data type.
    """

    __metaclass__ = DataTypeSingleton


class BooleanType(AtomicType):
    """Boolean data type.
    """

    __metaclass__ = DataTypeSingleton


class DateType(AtomicType):
    """Date (datetime.date) data type.
    """

    __metaclass__ = DataTypeSingleton

    EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()

    def needConversion(self):
        return True

    def toInternal(self, d):
        if d is not None:
            return d.toordinal() - self.EPOCH_ORDINAL

    def fromInternal(self, v):
        if v is not None:
            return datetime.date.fromordinal(v + self.EPOCH_ORDINAL)


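# Brief sketch (added for illustration; not part of the original module):
# DateType represents a date internally as a count of whole days since the
# Unix epoch (1970-01-01).
#
#   >>> DateType().toInternal(datetime.date(1970, 1, 2))
#   1
#   >>> DateType().fromInternal(1)
#   datetime.date(1970, 1, 2)

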
class TimestampType(AtomicType):
    """Timestamp (datetime.datetime) data type.
    """

    __metaclass__ = DataTypeSingleton

    def needConversion(self):
        return True

    def toInternal(self, dt):
        if dt is not None:
            seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
                       else time.mktime(dt.timetuple()))
            return int(seconds) * 1000000 + dt.microsecond

    def fromInternal(self, ts):
        if ts is not None:
            # integer division keeps full precision of the microsecond part
            return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000)


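# Brief sketch (added for illustration; not part of the original module):
# timestamps are stored as microseconds since the Unix epoch. Naive datetimes
# go through time.mktime, so the full integer depends on the local timezone,
# but the microsecond component always survives the round trip:
#
#   >>> ts = TimestampType().toInternal(datetime.datetime(2018, 1, 1, 12, 0, 0, 123456))
#   >>> ts % 1000000
#   123456

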
class DecimalType(FractionalType):
    """Decimal (decimal.Decimal) data type.

    The DecimalType must have fixed precision (the maximum total number of digits)
    and scale (the number of digits on the right of the dot). For example, (5, 2) can
    support values in the range [-999.99, 999.99].

    The precision can be up to 38; the scale must be less than or equal to the precision.

    When creating a DecimalType, the default precision and scale is (10, 0). When inferring
    schema from decimal.Decimal objects, it will be DecimalType(38, 18).

    :param precision: the maximum (i.e. total) number of digits (default: 10)
    :param scale: the number of digits on the right side of the dot (default: 0)
    """

    def __init__(self, precision=10, scale=0):
        self.precision = precision
        self.scale = scale
        self.hasPrecisionInfo = True

    def simpleString(self):
        return "decimal(%d,%d)" % (self.precision, self.scale)

    def jsonValue(self):
        return "decimal(%d,%d)" % (self.precision, self.scale)

    def __repr__(self):
        return "DecimalType(%d,%d)" % (self.precision, self.scale)


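# Brief sketch (added for illustration; not part of the original module):
# precision and scale flow into the simple-string and JSON forms, e.g.
#
#   >>> DecimalType(5, 2).simpleString()
#   'decimal(5,2)'
#   >>> DecimalType().jsonValue()
#   'decimal(10,0)'

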
class DoubleType(FractionalType):
    """Double data type, representing double precision floats.
    """

    __metaclass__ = DataTypeSingleton


class FloatType(FractionalType):
    """Float data type, representing single precision floats.
    """

    __metaclass__ = DataTypeSingleton


class ByteType(IntegralType):
    """Byte data type, i.e. a signed integer in a single byte.
    """
    def simpleString(self):
        return 'tinyint'


class IntegerType(IntegralType):
    """Int data type, i.e. a signed 32-bit integer.
    """
    def simpleString(self):
        return 'int'


class LongType(IntegralType):
    """Long data type, i.e. a signed 64-bit integer.

    If the values are beyond the range of [-9223372036854775808, 9223372036854775807],
    please use :class:`DecimalType`.
    """
    def simpleString(self):
        return 'bigint'


class ShortType(IntegralType):
    """Short data type, i.e. a signed 16-bit integer.
    """
    def simpleString(self):
        return 'smallint'


class ArrayType(DataType):
    """Array data type.

    :param elementType: :class:`DataType` of each element in the array.
    :param containsNull: boolean, whether the array can contain null (None) values.
    """

    def __init__(self, elementType, containsNull=True):
        """
        >>> ArrayType(StringType()) == ArrayType(StringType(), True)
        True
        >>> ArrayType(StringType(), False) == ArrayType(StringType())
        False
        """
        assert isinstance(elementType, DataType),\
            "elementType %s should be an instance of %s" % (elementType, DataType)
        self.elementType = elementType
        self.containsNull = containsNull

    def simpleString(self):
        return 'array<%s>' % self.elementType.simpleString()

    def __repr__(self):
        return "ArrayType(%s,%s)" % (self.elementType,
                                     str(self.containsNull).lower())

    def jsonValue(self):
        return {"type": self.typeName(),
                "elementType": self.elementType.jsonValue(),
                "containsNull": self.containsNull}

    @classmethod
    def fromJson(cls, json):
        return ArrayType(_parse_datatype_json_value(json["elementType"]),
                         json["containsNull"])

    def needConversion(self):
        return self.elementType.needConversion()

    def toInternal(self, obj):
        if not self.needConversion():
            return obj
        return obj and [self.elementType.toInternal(v) for v in obj]

    def fromInternal(self, obj):
        if not self.needConversion():
            return obj
        return obj and [self.elementType.fromInternal(v) for v in obj]


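# Brief sketch (added for illustration; not part of the original module):
# ArrayType only converts when its element type needs conversion, and then
# simply maps the element conversion over the list:
#
#   >>> ArrayType(LongType()).needConversion()
#   False
#   >>> ArrayType(DateType()).toInternal([datetime.date(1970, 1, 2)])
#   [1]

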
class MapType(DataType):
    """Map data type.

    :param keyType: :class:`DataType` of the keys in the map.
    :param valueType: :class:`DataType` of the values in the map.
    :param valueContainsNull: indicates whether values can contain null (None) values.

    Keys in a map data type are not allowed to be null (None).
    """

    def __init__(self, keyType, valueType, valueContainsNull=True):
        """
        >>> (MapType(StringType(), IntegerType())
        ...     == MapType(StringType(), IntegerType(), True))
        True
        >>> (MapType(StringType(), IntegerType(), False)
        ...     == MapType(StringType(), FloatType()))
        False
        """
        assert isinstance(keyType, DataType),\
            "keyType %s should be an instance of %s" % (keyType, DataType)
        assert isinstance(valueType, DataType),\
            "valueType %s should be an instance of %s" % (valueType, DataType)
        self.keyType = keyType
        self.valueType = valueType
        self.valueContainsNull = valueContainsNull

    def simpleString(self):
        return 'map<%s,%s>' % (self.keyType.simpleString(), self.valueType.simpleString())

    def __repr__(self):
        return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
                                      str(self.valueContainsNull).lower())

    def jsonValue(self):
        return {"type": self.typeName(),
                "keyType": self.keyType.jsonValue(),
                "valueType": self.valueType.jsonValue(),
                "valueContainsNull": self.valueContainsNull}

    @classmethod
    def fromJson(cls, json):
        return MapType(_parse_datatype_json_value(json["keyType"]),
                       _parse_datatype_json_value(json["valueType"]),
                       json["valueContainsNull"])

    def needConversion(self):
        return self.keyType.needConversion() or self.valueType.needConversion()

    def toInternal(self, obj):
        if not self.needConversion():
            return obj
        return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v))
                            for k, v in obj.items())

    def fromInternal(self, obj):
        if not self.needConversion():
            return obj
        return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v))
                            for k, v in obj.items())


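# Brief sketch (added for illustration; not part of the original module):
# MapType converts keys and values independently, mirroring ArrayType:
#
#   >>> MapType(StringType(), DateType()).toInternal({"d": datetime.date(1970, 1, 2)})
#   {'d': 1}

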
class StructField(DataType):
    """A field in :class:`StructType`.

    :param name: string, name of the field.
    :param dataType: :class:`DataType` of the field.
    :param nullable: boolean, whether the field can be null (None) or not.
    :param metadata: a dict from string to simple type that can be converted to JSON
        automatically
    """

    def __init__(self, name, dataType, nullable=True, metadata=None):
        """
        >>> (StructField("f1", StringType(), True)
        ...     == StructField("f1", StringType(), True))
        True
        >>> (StructField("f1", StringType(), True)
        ...     == StructField("f2", StringType(), True))
        False
        """
        assert isinstance(dataType, DataType),\
            "dataType %s should be an instance of %s" % (dataType, DataType)
        assert isinstance(name, basestring), "field name %s should be a string" % (name)
        if not isinstance(name, str):
            name = name.encode('utf-8')
        self.name = name
        self.dataType = dataType
        self.nullable = nullable
        self.metadata = metadata or {}

    def simpleString(self):
        return '%s:%s' % (self.name, self.dataType.simpleString())

    def __repr__(self):
        return "StructField(%s,%s,%s)" % (self.name, self.dataType,
                                          str(self.nullable).lower())

    def jsonValue(self):
        return {"name": self.name,
                "type": self.dataType.jsonValue(),
                "nullable": self.nullable,
                "metadata": self.metadata}

    @classmethod
    def fromJson(cls, json):
        return StructField(json["name"],
                           _parse_datatype_json_value(json["type"]),
                           json["nullable"],
                           json["metadata"])

    def needConversion(self):
        return self.dataType.needConversion()

    def toInternal(self, obj):
        return self.dataType.toInternal(obj)

    def fromInternal(self, obj):
        return self.dataType.fromInternal(obj)

    def typeName(self):
        raise TypeError(
            "StructField does not have typeName. "
            "Use typeName on its type explicitly instead.")


class StructType(DataType):
    """Struct type, consisting of a list of :class:`StructField`.

    This is the data type representing a :class:`Row`.

    Iterating a :class:`StructType` will iterate over its :class:`StructField`\\s.
    A contained :class:`StructField` can be accessed by its name or position.

    >>> struct1 = StructType([StructField("f1", StringType(), True)])
    >>> struct1["f1"]
    StructField(f1,StringType,true)
    >>> struct1[0]
    StructField(f1,StringType,true)
    """
    def __init__(self, fields=None):
        """
        >>> struct1 = StructType([StructField("f1", StringType(), True)])
        >>> struct2 = StructType([StructField("f1", StringType(), True)])
        >>> struct1 == struct2
        True
        >>> struct1 = StructType([StructField("f1", StringType(), True)])
        >>> struct2 = StructType([StructField("f1", StringType(), True),
        ...     StructField("f2", IntegerType(), False)])
        >>> struct1 == struct2
        False
        """
        if not fields:
            self.fields = []
            self.names = []
        else:
            self.fields = fields
            self.names = [f.name for f in fields]
            assert all(isinstance(f, StructField) for f in fields),\
                "fields should be a list of StructField"
        # Precalculated list of fields that need conversion with fromInternal/toInternal
        self._needConversion = [f.needConversion() for f in self]
        self._needSerializeAnyField = any(self._needConversion)

    def add(self, field, data_type=None, nullable=True, metadata=None):
        """
        Construct a StructType by adding new elements to it, to define the schema.
        The method accepts either:

            a) A single parameter which is a StructField object.
            b) Between 2 and 4 parameters as (name, data_type, nullable (optional),
               metadata (optional)). The data_type parameter may be either a String or a
               DataType object.

        >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
        >>> struct2 = StructType([StructField("f1", StringType(), True), \\
        ...     StructField("f2", StringType(), True, None)])
        >>> struct1 == struct2
        True
        >>> struct1 = StructType().add(StructField("f1", StringType(), True))
        >>> struct2 = StructType([StructField("f1", StringType(), True)])
        >>> struct1 == struct2
        True
        >>> struct1 = StructType().add("f1", "string", True)
        >>> struct2 = StructType([StructField("f1", StringType(), True)])
        >>> struct1 == struct2
        True

        :param field: Either the name of the field or a StructField object
        :param data_type: If present, the DataType of the StructField to create
        :param nullable: Whether the field to add should be nullable (default True)
        :param metadata: Any additional metadata (default None)
        :return: a new updated StructType
        """
        if isinstance(field, StructField):
            self.fields.append(field)
            self.names.append(field.name)
        else:
            if isinstance(field, str) and data_type is None:
                raise ValueError("Must specify DataType if passing name of struct_field to create.")

            if isinstance(data_type, str):
                data_type_f = _parse_datatype_json_value(data_type)
            else:
                data_type_f = data_type
            self.fields.append(StructField(field, data_type_f, nullable, metadata))
            self.names.append(field)
        # Precalculated list of fields that need conversion with fromInternal/toInternal
        self._needConversion = [f.needConversion() for f in self]
        self._needSerializeAnyField = any(self._needConversion)
        return self

    def __iter__(self):
        """Iterate the fields"""
        return iter(self.fields)

    def __len__(self):
        """Return the number of fields."""
        return len(self.fields)

    def __getitem__(self, key):
        """Access fields by name or slice."""
        if isinstance(key, str):
            for field in self:
                if field.name == key:
                    return field
            raise KeyError('No StructField named {0}'.format(key))
        elif isinstance(key, int):
            try:
                return self.fields[key]
            except IndexError:
                raise IndexError('StructType index out of range')
        elif isinstance(key, slice):
            return StructType(self.fields[key])
        else:
            raise TypeError('StructType keys should be strings, integers or slices')

    def simpleString(self):
        return 'struct<%s>' % (','.join(f.simpleString() for f in self))

    def __repr__(self):
        return ("StructType(List(%s))" %
                ",".join(str(field) for field in self))

    def jsonValue(self):
        return {"type": self.typeName(),
                "fields": [f.jsonValue() for f in self]}

    @classmethod
    def fromJson(cls, json):
        return StructType([StructField.fromJson(f) for f in json["fields"]])

    def fieldNames(self):
        """
        Returns all field names in a list.

        >>> struct = StructType([StructField("f1", StringType(), True)])
        >>> struct.fieldNames()
        ['f1']
        """
        return list(self.names)

    def needConversion(self):
        # We need to convert Row()/namedtuple into tuple()
        return True

    def toInternal(self, obj):
        if obj is None:
            return

        if self._needSerializeAnyField:
            # Only call toInternal for fields that need conversion
            if isinstance(obj, dict):
                return tuple(f.toInternal(obj.get(n)) if c else obj.get(n)
                             for n, f, c in zip(self.names, self.fields, self._needConversion))
            elif isinstance(obj, (tuple, list)):
                return tuple(f.toInternal(v) if c else v
                             for f, v, c in zip(self.fields, obj, self._needConversion))
            elif hasattr(obj, "__dict__"):
                d = obj.__dict__
                return tuple(f.toInternal(d.get(n)) if c else d.get(n)
                             for n, f, c in zip(self.names, self.fields, self._needConversion))
            else:
                raise ValueError("Unexpected tuple %r with StructType" % obj)
        else:
            if isinstance(obj, dict):
                return tuple(obj.get(n) for n in self.names)
            elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
                return tuple(obj[n] for n in self.names)
            elif isinstance(obj, (list, tuple)):
                return tuple(obj)
            elif hasattr(obj, "__dict__"):
                d = obj.__dict__
                return tuple(d.get(n) for n in self.names)
            else:
                raise ValueError("Unexpected tuple %r with StructType" % obj)

    def fromInternal(self, obj):
        if obj is None:
            return
        if isinstance(obj, Row):
            # it's already converted by pickler
            return obj
        if self._needSerializeAnyField:
            # Only call fromInternal for fields that need conversion
            values = [f.fromInternal(v) if c else v
                      for f, v, c in zip(self.fields, obj, self._needConversion)]
        else:
            values = obj
        return _create_row(self.names, values)


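# Brief sketch (added for illustration; not part of the original module):
# StructType flattens a row-like object into a plain tuple for the internal
# representation and rebuilds a Row on the way out:
#
#   >>> schema = StructType([StructField("d", DateType(), True)])
#   >>> schema.toInternal({"d": datetime.date(1970, 1, 2)})
#   (1,)
#   >>> schema.fromInternal((1,))
#   Row(d=datetime.date(1970, 1, 2))

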
class UserDefinedType(DataType):
    """User-defined type (UDT).

    .. note:: WARN: Spark Internal Use Only
    """

    @classmethod
    def typeName(cls):
        return cls.__name__.lower()

    @classmethod
    def sqlType(cls):
        """
        Underlying SQL storage type for this UDT.
        """
        raise NotImplementedError("UDT must implement sqlType().")

    @classmethod
    def module(cls):
        """
        The Python module of the UDT.
        """
        raise NotImplementedError("UDT must implement module().")

    @classmethod
    def scalaUDT(cls):
        """
        The class name of the paired Scala UDT (could be '', if there
        is no corresponding one).
        """
        return ''

    def needConversion(self):
        return True

    @classmethod
    def _cachedSqlType(cls):
        """
        Cache the sqlType() into class, because it's heavily used in `toInternal`.
        """
        if not hasattr(cls, "_cached_sql_type"):
            cls._cached_sql_type = cls.sqlType()
        return cls._cached_sql_type

    def toInternal(self, obj):
        if obj is not None:
            return self._cachedSqlType().toInternal(self.serialize(obj))

    def fromInternal(self, obj):
        v = self._cachedSqlType().fromInternal(obj)
        if v is not None:
            return self.deserialize(v)

    def serialize(self, obj):
        """
        Converts a user-type object into a SQL datum.
        """
        raise NotImplementedError("UDT must implement toInternal().")

    def deserialize(self, datum):
        """
        Converts a SQL datum into a user-type object.
        """
        raise NotImplementedError("UDT must implement fromInternal().")

    def simpleString(self):
        return 'udt'

    def json(self):
        return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)

    def jsonValue(self):
        if self.scalaUDT():
            assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT'
            schema = {
                "type": "udt",
                "class": self.scalaUDT(),
                "pyClass": "%s.%s" % (self.module(), type(self).__name__),
                "sqlType": self.sqlType().jsonValue()
            }
        else:
            ser = CloudPickleSerializer()
            b = ser.dumps(type(self))
            schema = {
                "type": "udt",
                "pyClass": "%s.%s" % (self.module(), type(self).__name__),
                "serializedClass": base64.b64encode(b).decode('utf8'),
                "sqlType": self.sqlType().jsonValue()
            }
        return schema

    @classmethod
    def fromJson(cls, json):
        pyUDT = str(json["pyClass"])
        split = pyUDT.rfind(".")
        pyModule = pyUDT[:split]
        pyClass = pyUDT[split+1:]
        m = __import__(pyModule, globals(), locals(), [pyClass])
        if not hasattr(m, pyClass):
            s = base64.b64decode(json['serializedClass'].encode('utf-8'))
            UDT = CloudPickleSerializer().loads(s)
        else:
            UDT = getattr(m, pyClass)
        return UDT()

    def __eq__(self, other):
        return type(self) == type(other)


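# Minimal UDT sketch (hypothetical example added for illustration; `Point` and
# `PointUDT` are made-up names, not part of Spark): a user type pairs a
# storage type (sqlType) with serialize/deserialize functions.
#
#   class Point(object):
#       __UDT__ = None  # set to PointUDT() below
#
#       def __init__(self, x, y):
#           self.x, self.y = x, y
#
#   class PointUDT(UserDefinedType):
#       @classmethod
#       def sqlType(cls):
#           return ArrayType(DoubleType(), False)
#
#       @classmethod
#       def module(cls):
#           return "__main__"
#
#       def serialize(self, obj):
#           return [obj.x, obj.y]
#
#       def deserialize(self, datum):
#           return Point(datum[0], datum[1])
#
#   Point.__UDT__ = PointUDT()

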
_atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType,
                 ByteType, ShortType, IntegerType, LongType, DateType, TimestampType, NullType]
_all_atomic_types = dict((t.typeName(), t) for t in _atomic_types)
_all_complex_types = dict((v.typeName(), v)
                          for v in [ArrayType, MapType, StructType])


_FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")


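# Brief sketch (added for illustration; not part of the original module):
# _FIXED_DECIMAL captures the precision and scale out of strings such as
# "decimal(10,2)":
#
#   >>> _FIXED_DECIMAL.match("decimal(10,2)").groups()
#   ('10', '2')

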
def _parse_datatype_string(s):
    """
    Parses the given data type string to a :class:`DataType`. The data type string format equals
    :class:`DataType.simpleString`, except that the top level struct type can omit
    the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead
    of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name
    for :class:`IntegerType`. Since Spark 2.3, this also supports a schema in a DDL-formatted
    string and case-insensitive strings.

    >>> _parse_datatype_string("int ")
    IntegerType
    >>> _parse_datatype_string("INT ")
    IntegerType
    >>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ")
    StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
    >>> _parse_datatype_string("a DOUBLE, b STRING")
    StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true)))
    >>> _parse_datatype_string("a: array< short>")
    StructType(List(StructField(a,ArrayType(ShortType,true),true)))
    >>> _parse_datatype_string(" map<string , string > ")
    MapType(StringType,StringType,true)

    >>> # Error cases
    >>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ParseException:...
    >>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ParseException:...
    >>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ParseException:...
    >>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ParseException:...
    """
    sc = SparkContext._active_spark_context

    def from_ddl_schema(type_str):
        return _parse_datatype_json_string(
            sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json())

    def from_ddl_datatype(type_str):
        return _parse_datatype_json_string(
            sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json())

    try:
        # DDL format, "fieldname datatype, fieldname datatype".
        return from_ddl_schema(s)
    except Exception as e:
        try:
            # For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc.
            return from_ddl_datatype(s)
        except:
            try:
                # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
                return from_ddl_datatype("struct<%s>" % s.strip())
            except:
                raise e


def _parse_datatype_json_string(json_string):
    """Parses the given data type JSON string.
    >>> import pickle
    >>> def check_datatype(datatype):
    ...     pickled = pickle.loads(pickle.dumps(datatype))
    ...     assert datatype == pickled
    ...     scala_datatype = spark._jsparkSession.parseDataType(datatype.json())
    ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
    ...     assert datatype == python_datatype
    >>> for cls in _all_atomic_types.values():
    ...     check_datatype(cls())

    >>> # Simple ArrayType.
    >>> simple_arraytype = ArrayType(StringType(), True)
    >>> check_datatype(simple_arraytype)

    >>> # Simple MapType.
    >>> simple_maptype = MapType(StringType(), LongType())
    >>> check_datatype(simple_maptype)

    >>> # Simple StructType.
    >>> simple_structtype = StructType([
    ...     StructField("a", DecimalType(), False),
    ...     StructField("b", BooleanType(), True),
    ...     StructField("c", LongType(), True),
    ...     StructField("d", BinaryType(), False)])
    >>> check_datatype(simple_structtype)

    >>> # Complex StructType.
    >>> complex_structtype = StructType([
    ...     StructField("simpleArray", simple_arraytype, True),
    ...     StructField("simpleMap", simple_maptype, True),
    ...     StructField("simpleStruct", simple_structtype, True),
    ...     StructField("boolean", BooleanType(), False),
    ...     StructField("withMeta", DoubleType(), False, {"name": "age"})])
    >>> check_datatype(complex_structtype)

    >>> # Complex ArrayType.
    >>> complex_arraytype = ArrayType(complex_structtype, True)
    >>> check_datatype(complex_arraytype)

    >>> # Complex MapType.
    >>> complex_maptype = MapType(complex_structtype,
    ...                           complex_arraytype, False)
    >>> check_datatype(complex_maptype)
    """
    return _parse_datatype_json_value(json.loads(json_string))


def _parse_datatype_json_value(json_value):
    if not isinstance(json_value, dict):
        if json_value in _all_atomic_types.keys():
            return _all_atomic_types[json_value]()
        elif json_value == 'decimal':
            return DecimalType()
        elif _FIXED_DECIMAL.match(json_value):
            m = _FIXED_DECIMAL.match(json_value)
            return DecimalType(int(m.group(1)), int(m.group(2)))
        else:
            raise ValueError("Could not parse datatype: %s" % json_value)
    else:
        tpe = json_value["type"]
        if tpe in _all_complex_types:
            return _all_complex_types[tpe].fromJson(json_value)
        elif tpe == 'udt':
            return UserDefinedType.fromJson(json_value)
        else:
            raise ValueError("not supported type: %s" % tpe)


# Mapping Python types to Spark SQL DataType
_type_mappings = {
    type(None): NullType,
    bool: BooleanType,
    int: LongType,
    float: DoubleType,
    str: StringType,
    bytearray: BinaryType,
    decimal.Decimal: DecimalType,
    datetime.date: DateType,
    datetime.datetime: TimestampType,
    datetime.time: TimestampType,
}

if sys.version < "3":
    _type_mappings.update({
        unicode: StringType,
        long: LongType,
    })

if sys.version >= "3":
    _type_mappings.update({
        bytes: BinaryType,
    })


# Mapping Python array typecodes to Spark SQL DataType. We have to be careful
# here: the sizes of these C integer types vary across platforms, so the
# mappings for the integer typecodes are computed below from ctypes.sizeof()
# rather than hard-coded. The JVM only supports signed integral types, so an
# unsigned typecode needs one extra bit when stored in a signed type.

_array_signed_int_typecode_ctype_mappings = {
    'b': ctypes.c_byte,
    'h': ctypes.c_short,
    'i': ctypes.c_int,
    'l': ctypes.c_long,
}

_array_unsigned_int_typecode_ctype_mappings = {
    'B': ctypes.c_ubyte,
    'H': ctypes.c_ushort,
    'I': ctypes.c_uint,
    'L': ctypes.c_ulong
}


def _int_size_to_type(size):
    """
    Return the Catalyst datatype from the size of integers.
    """
    if size <= 8:
        return ByteType
    if size <= 16:
        return ShortType
    if size <= 32:
        return IntegerType
    if size <= 64:
        return LongType


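# Brief sketch (added for illustration; not part of the original module): the
# helper returns the smallest Catalyst integral type class (not an instance)
# that can hold a signed integer of the given bit width, or None beyond 64:
#
#   >>> _int_size_to_type(16) is ShortType
#   True
#   >>> _int_size_to_type(33) is LongType
#   True

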
# The list of all supported array typecodes is stored here
_array_type_mappings = {
    # Warning: the C standard does not pin down the exact representation of
    # float and double. On essentially every system supported by both Python
    # and the JVM they are IEEE 754 single- and double-precision formats,
    # and we assume the same thing here for now.
    'f': FloatType,
    'd': DoubleType
}

# compute array typecode mappings for signed integer types
for _typecode in _array_signed_int_typecode_ctype_mappings.keys():
    size = ctypes.sizeof(_array_signed_int_typecode_ctype_mappings[_typecode]) * 8
    dt = _int_size_to_type(size)
    if dt is not None:
        _array_type_mappings[_typecode] = dt

# compute array typecode mappings for unsigned integer types
for _typecode in _array_unsigned_int_typecode_ctype_mappings.keys():
    # JVM does not have unsigned types; use a signed type that is at least
    # 1 bit bigger to store the value
    size = ctypes.sizeof(_array_unsigned_int_typecode_ctype_mappings[_typecode]) * 8 + 1
    dt = _int_size_to_type(size)
    if dt is not None:
        _array_type_mappings[_typecode] = dt

# Type code 'u' in Python's array is deprecated since version 3.3, and will be
# removed in version 4.0. See: https://docs.python.org/3/library/array.html
if sys.version_info[0] < 4:
    _array_type_mappings['u'] = StringType

# Type code 'c' is only available in Python 2
if sys.version_info[0] < 3:
    _array_type_mappings['c'] = StringType

# SPARK-21465:
# In Python 2, arrays of typecode 'L' happened to be mistakenly, partially
# supported. To avoid breaking user code, keep this partial support
# (CPython only, where c_ulong is wider than 32 bits).
import platform
if sys.version_info[0] < 3 and platform.python_implementation() != 'PyPy':
    if 'L' not in _array_type_mappings.keys():
        _array_type_mappings['L'] = LongType
        _array_unsigned_int_typecode_ctype_mappings['L'] = ctypes.c_uint


def _infer_type(obj):
    """Infer the DataType from obj
    """
    if obj is None:
        return NullType()

    if hasattr(obj, '__UDT__'):
        return obj.__UDT__

    dataType = _type_mappings.get(type(obj))
    if dataType is DecimalType:
        # the precision and scale of `obj` may be different from row to row
        return DecimalType(38, 18)
    elif dataType is not None:
        return dataType()

    if isinstance(obj, dict):
        for key, value in obj.items():
            if key is not None and value is not None:
                return MapType(_infer_type(key), _infer_type(value), True)
        return MapType(NullType(), NullType(), True)
    elif isinstance(obj, list):
        for v in obj:
            # infer the element type from the first non-None element
            if v is not None:
                return ArrayType(_infer_type(v), True)
        return ArrayType(NullType(), True)
    elif isinstance(obj, array):
        if obj.typecode in _array_type_mappings:
            return ArrayType(_array_type_mappings[obj.typecode](), False)
        else:
            raise TypeError("not supported type: array(%s)" % obj.typecode)
    else:
        try:
            return _infer_schema(obj)
        except TypeError:
            raise TypeError("not supported type: %s" % type(obj))


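# Brief sketch (added for illustration; not part of the original module):
# inference maps plain Python values onto Catalyst types, recursing into
# containers:
#
#   >>> _infer_type(1)
#   LongType
#   >>> _infer_type([datetime.date(1970, 1, 2)])
#   ArrayType(DateType,true)
#   >>> _infer_type({"a": 1.0})
#   MapType(StringType,DoubleType,true)

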
def _infer_schema(row, names=None):
    """Infer the schema from dict/namedtuple/object"""
    if isinstance(row, dict):
        items = sorted(row.items())

    elif isinstance(row, (tuple, list)):
        if hasattr(row, "__fields__"):  # Row
            items = zip(row.__fields__, tuple(row))
        elif hasattr(row, "_fields"):  # namedtuple
            items = zip(row._fields, tuple(row))
        else:
            if names is None:
                names = ['_%d' % i for i in range(1, len(row) + 1)]
            elif len(names) < len(row):
                names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1))
            items = zip(names, row)

    elif hasattr(row, "__dict__"):  # object
        items = sorted(row.__dict__.items())

    else:
        raise TypeError("Can not infer schema for type: %s" % type(row))

    fields = [StructField(k, _infer_type(v), True) for k, v in items]
    return StructType(fields)


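# Brief sketch (added for illustration; not part of the original module):
# dict keys become (sorted) field names, tuple elements get positional names:
#
#   >>> _infer_schema({"a": 1})
#   StructType(List(StructField(a,LongType,true)))
#   >>> _infer_schema((1, "x"))
#   StructType(List(StructField(_1,LongType,true),StructField(_2,StringType,true)))

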
def _has_nulltype(dt):
    """ Return whether there is a NullType in `dt` or not """
    if isinstance(dt, StructType):
        return any(_has_nulltype(f.dataType) for f in dt.fields)
    elif isinstance(dt, ArrayType):
        return _has_nulltype(dt.elementType)
    elif isinstance(dt, MapType):
        return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
    else:
        return isinstance(dt, NullType)


def _merge_type(a, b, name=None):
    if name is None:
        new_msg = lambda msg: msg
        new_name = lambda n: "field %s" % n
    else:
        new_msg = lambda msg: "%s: %s" % (name, msg)
        new_name = lambda n: "field %s in %s" % (n, name)

    if isinstance(a, NullType):
        return b
    elif isinstance(b, NullType):
        return a
    elif type(a) is not type(b):
        # TODO: type cast (such as int -> long)
        raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b))))

    # same type
    if isinstance(a, StructType):
        nfs = dict((f.name, f.dataType) for f in b.fields)
        fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()),
                                                  name=new_name(f.name)))
                  for f in a.fields]
        names = set([f.name for f in fields])
        for n in nfs:
            if n not in names:
                fields.append(StructField(n, nfs[n]))
        return StructType(fields)

    elif isinstance(a, ArrayType):
        return ArrayType(_merge_type(a.elementType, b.elementType,
                                     name='element in array %s' % name), True)

    elif isinstance(a, MapType):
        return MapType(_merge_type(a.keyType, b.keyType, name='key of map %s' % name),
                       _merge_type(a.valueType, b.valueType, name='value of map %s' % name),
                       True)
    else:
        return a


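# Brief sketch (added for illustration; not part of the original module):
# NullType is the identity element of the merge, and container types merge
# element-wise:
#
#   >>> _merge_type(LongType(), NullType())
#   LongType
#   >>> _merge_type(ArrayType(NullType()), ArrayType(LongType()))
#   ArrayType(LongType,true)

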
def _need_converter(dataType):
    if isinstance(dataType, StructType):
        return True
    elif isinstance(dataType, ArrayType):
        return _need_converter(dataType.elementType)
    elif isinstance(dataType, MapType):
        return _need_converter(dataType.keyType) or _need_converter(dataType.valueType)
    elif isinstance(dataType, NullType):
        return True
    else:
        return False


def _create_converter(dataType):
    """Create a converter to drop the names of fields in obj"""
    if not _need_converter(dataType):
        return lambda x: x

    if isinstance(dataType, ArrayType):
        conv = _create_converter(dataType.elementType)
        return lambda row: [conv(v) for v in row]

    elif isinstance(dataType, MapType):
        kconv = _create_converter(dataType.keyType)
        vconv = _create_converter(dataType.valueType)
        return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items())

    elif isinstance(dataType, NullType):
        return lambda x: None

    elif not isinstance(dataType, StructType):
        return lambda x: x

    # dataType must be StructType
    names = [f.name for f in dataType.fields]
    converters = [_create_converter(f.dataType) for f in dataType.fields]
    convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)

    def convert_struct(obj):
        if obj is None:
            return

        if isinstance(obj, (tuple, list)):
            if convert_fields:
                return tuple(conv(v) for v, conv in zip(obj, converters))
            else:
                return tuple(obj)

        if isinstance(obj, dict):
            d = obj
        elif hasattr(obj, "__dict__"):  # object
            d = obj.__dict__
        else:
            raise TypeError("Unexpected obj type: %s" % type(obj))

        if convert_fields:
            return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
        else:
            return tuple([d.get(name) for name in names])

    return convert_struct


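# Brief sketch (added for illustration; not part of the original module): the
# converter strips field names, reducing a dict (or object) to a plain tuple
# in schema order:
#
#   >>> conv = _create_converter(StructType([StructField("a", LongType(), True)]))
#   >>> conv({"a": 1})
#   (1,)

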
_acceptable_types = {
    BooleanType: (bool,),
    ByteType: (int, long),
    ShortType: (int, long),
    IntegerType: (int, long),
    LongType: (int, long),
    FloatType: (float,),
    DoubleType: (float,),
    DecimalType: (decimal.Decimal,),
    StringType: (str, unicode),
    BinaryType: (bytearray, bytes),
    DateType: (datetime.date, datetime.datetime),
    TimestampType: (datetime.datetime,),
    ArrayType: (list, tuple, array),
    MapType: (dict,),
    StructType: (tuple, list, dict),
}


def _make_type_verifier(dataType, nullable=True, name=None):
    """
    Make a verifier that checks the type of obj against dataType and raises a TypeError if they do
    not match.

    This verifier also checks the value of obj against datatype and raises a ValueError if it's not
    within the allowed range, e.g. using 128 as ByteType will overflow. Note that Python float is
    not checked, so it will become infinity when cast to Java float, if it overflows.

    >>> _make_type_verifier(StructType([]))(None)
    >>> _make_type_verifier(StringType())("")
    >>> _make_type_verifier(LongType())(0)
    >>> _make_type_verifier(LongType())(1 << 64) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _make_type_verifier(ArrayType(ShortType()))(list(range(3)))
    >>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    TypeError:...
    >>> _make_type_verifier(MapType(StringType(), IntegerType()))({})
    >>> _make_type_verifier(StructType([]))(())
    >>> _make_type_verifier(StructType([]))([])
    >>> _make_type_verifier(StructType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> # Check if numeric values are within the allowed range.
    >>> _make_type_verifier(ByteType())(12)
    >>> _make_type_verifier(ByteType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _make_type_verifier(ByteType(), False)(None) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _make_type_verifier(
    ...     ArrayType(ShortType(), False))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _make_type_verifier(
    ...     MapType(StringType(), IntegerType()))({None: 1}) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False)
    >>> _make_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    """

    if name is None:
        new_msg = lambda msg: msg
        new_name = lambda n: "field %s" % n
    else:
        new_msg = lambda msg: "%s: %s" % (name, msg)
        new_name = lambda n: "field %s in %s" % (n, name)

    def verify_nullability(obj):
        if obj is None:
            if nullable:
                return True
            else:
                raise ValueError(new_msg("This field is not nullable, but got None"))
        else:
            return False

    _type = type(dataType)

    def assert_acceptable_types(obj):
        assert _type in _acceptable_types, \
            new_msg("unknown datatype: %s for object %r" % (dataType, obj))

    def verify_acceptable_types(obj):
        # subclasses of the acceptable types can not be fromInternal'd in JVM
        if type(obj) not in _acceptable_types[_type]:
            raise TypeError(new_msg("%s can not accept object %r in type %s"
                                    % (dataType, obj, type(obj))))

    if isinstance(dataType, StringType):
        # StringType can work with any types
        verify_value = lambda _: _

    elif isinstance(dataType, UserDefinedType):
        verifier = _make_type_verifier(dataType.sqlType(), name=name)

        def verify_udf(obj):
            if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
                raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType)))
            verifier(dataType.toInternal(obj))

        verify_value = verify_udf

    elif isinstance(dataType, ByteType):
        def verify_byte(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            if obj < -128 or obj > 127:
                raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj))

        verify_value = verify_byte

    elif isinstance(dataType, ShortType):
        def verify_short(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            if obj < -32768 or obj > 32767:
                raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj))

        verify_value = verify_short

    elif isinstance(dataType, IntegerType):
        def verify_integer(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            if obj < -2147483648 or obj > 2147483647:
                raise ValueError(
                    new_msg("object of IntegerType out of range, got: %s" % obj))

        verify_value = verify_integer

    elif isinstance(dataType, LongType):
        def verify_long(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            if obj < -9223372036854775808 or obj > 9223372036854775807:
                raise ValueError(
                    new_msg("object of LongType out of range, got: %s" % obj))

        verify_value = verify_long

    elif isinstance(dataType, ArrayType):
        element_verifier = _make_type_verifier(
            dataType.elementType, dataType.containsNull, name="element in array %s" % name)

        def verify_array(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            for i in obj:
                element_verifier(i)

        verify_value = verify_array

    elif isinstance(dataType, MapType):
        key_verifier = _make_type_verifier(dataType.keyType, False, name="key of map %s" % name)
        value_verifier = _make_type_verifier(
            dataType.valueType, dataType.valueContainsNull, name="value of map %s" % name)

        def verify_map(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            for k, v in obj.items():
                key_verifier(k)
                value_verifier(v)

        verify_value = verify_map

    elif isinstance(dataType, StructType):
        verifiers = []
        for f in dataType.fields:
            verifier = _make_type_verifier(f.dataType, f.nullable, name=new_name(f.name))
            verifiers.append((f.name, verifier))

        def verify_struct(obj):
            assert_acceptable_types(obj)

            if isinstance(obj, dict):
                for f, verifier in verifiers:
                    verifier(obj.get(f))
            elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
                # the order in obj could be different than dataType.fields
                for f, verifier in verifiers:
                    verifier(obj[f])
            elif isinstance(obj, (tuple, list)):
                if len(obj) != len(verifiers):
                    raise ValueError(
                        new_msg("Length of object (%d) does not match with "
                                "length of fields (%d)" % (len(obj), len(verifiers))))
                for v, (_, verifier) in zip(obj, verifiers):
                    verifier(v)
            elif hasattr(obj, "__dict__"):
                d = obj.__dict__
                for f, verifier in verifiers:
                    verifier(d.get(f))
            else:
                raise TypeError(new_msg("StructType can not accept object %r in type %s"
                                        % (obj, type(obj))))
        verify_value = verify_struct

    else:
        def verify_default(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)

        verify_value = verify_default

    def verify(obj):
        if not verify_nullability(obj):
            verify_value(obj)

    return verify


def _create_row_inbound_converter(dataType):
    return lambda *a: dataType.fromInternal(a)


def _create_row(fields, values):
    row = Row(*values)
    row.__fields__ = fields
    return row


class Row(tuple):

    """
    A row in :class:`DataFrame`.
    The fields in it can be accessed:

    * like attributes (``row.key``)
    * like dictionary values (``row[key]``)

    ``key in row`` will search through row keys.

    Row can be used to create a row object by using named arguments.
    It is not allowed to omit a named argument to represent that the value is
    None or missing. This should be explicitly set to None in this case.

    NOTE: As of Spark 3.0.0, Rows created from named arguments no longer have
    field names sorted alphabetically and will be ordered in the position as
    entered. To enable sorting for Rows compatible with Spark 2.x, set the
    environment variable "PYSPARK_ROW_FIELD_SORTING_ENABLED" to "true". This
    option is deprecated and will be removed in future versions of Spark. For
    Python versions < 3.6, the order of named arguments is not guaranteed to
    be the same as entered, see https://www.python.org/dev/peps/pep-0468. In
    this case, a warning will be issued and the Row will fallback to sort the
    field names automatically.

    NOTE: Examples with Row in pydocs are run with the environment variable
    "PYSPARK_ROW_FIELD_SORTING_ENABLED" set to "true" which results in output
    where fields are sorted.

    >>> row = Row(name="Alice", age=11)
    >>> row
    Row(age=11, name='Alice')
    >>> row['name'], row['age']
    ('Alice', 11)
    >>> row.name, row.age
    ('Alice', 11)
    >>> 'name' in row
    True
    >>> 'wrong_key' in row
    False

    Row can also be used to create another Row-like class, which can then
    be used to create Row objects, such as

    >>> Person = Row("name", "age")
    >>> Person
    <Row('name', 'age')>
    >>> 'name' in Person
    True
    >>> 'wrong_key' in Person
    False
    >>> Person("Alice", 11)
    Row(name='Alice', age=11)

    This form can also be used to create rows as tuple values, i.e. with unnamed
    fields. Beware that such Row objects have different equality semantics:

    >>> row1 = Row("Alice", 11)
    >>> row2 = Row(name="Alice", age=11)
    >>> row1 == row2
    False
    >>> row3 = Row(a="Alice", b=11)
    >>> row1 == row3
    True
    """

    # Remove after Python < 3.6 support is dropped, see SPARK-29748
    _row_field_sorting_enabled = \
        os.environ.get('PYSPARK_ROW_FIELD_SORTING_ENABLED', 'false').lower() == 'true'

    if _row_field_sorting_enabled:
        warnings.warn("The environment variable 'PYSPARK_ROW_FIELD_SORTING_ENABLED' "
                      "is deprecated and will be removed in future versions of Spark")

    def __new__(cls, *args, **kwargs):
        if args and kwargs:
            raise ValueError("Can not use both args "
                             "and kwargs to create Row")
        if kwargs:
            if not Row._row_field_sorting_enabled and sys.version_info[:2] < (3, 6):
                warnings.warn("To use named arguments for Python version < 3.6, Row fields will be "
                              "automatically sorted. This warning can be skipped by setting the "
                              "environment variable 'PYSPARK_ROW_FIELD_SORTING_ENABLED' to 'true'.")
                Row._row_field_sorting_enabled = True

            # create row objects
            if Row._row_field_sorting_enabled:
                # Remove sorting when Python < 3.6 support is dropped, see SPARK-29748
                names = sorted(kwargs.keys())
                row = tuple.__new__(cls, [kwargs[n] for n in names])
                row.__fields__ = names
                row.__from_dict__ = True
            else:
                row = tuple.__new__(cls, list(kwargs.values()))
                row.__fields__ = list(kwargs.keys())

            return row
        else:
            # create row class or objects
            return tuple.__new__(cls, args)

    def asDict(self, recursive=False):
        """
        Return as a dict

        :param recursive: turns the nested Rows to dict (default: False).

        .. note:: If a row contains duplicate field names, e.g., the rows of a join
            between two :class:`DataFrame` that both have the fields of same names,
            one of the duplicate fields will be selected by ``asDict``. ``__getitem__``
            will also return one of the duplicate fields, however returned value might
            be different to ``asDict``.

        >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}
        True
        >>> row = Row(key=1, value=Row(name='a', age=2))
        >>> row.asDict() == {'key': 1, 'value': Row(age=2, name='a')}
        True
        >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
        True
        """
        if not hasattr(self, "__fields__"):
            raise TypeError("Cannot convert a Row class into dict")

        if recursive:
            def conv(obj):
                if isinstance(obj, Row):
                    return obj.asDict(True)
                elif isinstance(obj, list):
                    return [conv(o) for o in obj]
                elif isinstance(obj, dict):
                    return dict((k, conv(v)) for k, v in obj.items())
                else:
                    return obj
            return dict(zip(self.__fields__, (conv(o) for o in self)))
        else:
            return dict(zip(self.__fields__, self))

    def __contains__(self, item):
        if hasattr(self, "__fields__"):
            return item in self.__fields__
        else:
            return super(Row, self).__contains__(item)

    # let the object act like a class
    def __call__(self, *args):
        """create new Row object"""
        if len(args) > len(self):
            raise ValueError("Can not create Row with fields %s, expected %d values "
                             "but got %s" % (self, len(self), args))
        return _create_row(self, args)

    def __getitem__(self, item):
        if isinstance(item, (int, slice)):
            return super(Row, self).__getitem__(item)
        try:
            # it will be slow when it has many fields,
            # but this will not be used in normal cases
            idx = self.__fields__.index(item)
            return super(Row, self).__getitem__(idx)
        except IndexError:
            raise KeyError(item)
        except ValueError:
            raise ValueError(item)

    def __getattr__(self, item):
        if item.startswith("__"):
            raise AttributeError(item)
        try:
            # it will be slow when it has many fields,
            # but this will not be used in normal cases
            idx = self.__fields__.index(item)
            return self[idx]
        except IndexError:
            raise AttributeError(item)
        except ValueError:
            raise AttributeError(item)

    def __setattr__(self, key, value):
        if key != '__fields__' and key != "__from_dict__":
            raise Exception("Row is read-only")
        self.__dict__[key] = value

    def __reduce__(self):
        """Returns a tuple so Python knows how to pickle Row."""
        if hasattr(self, "__fields__"):
            return (_create_row, (self.__fields__, tuple(self)))
        else:
            return tuple.__reduce__(self)

    def __repr__(self):
        """Printable representation of Row used in Python REPL."""
        if hasattr(self, "__fields__"):
            return "Row(%s)" % ", ".join("%s=%r" % (k, v)
                                         for k, v in zip(self.__fields__, tuple(self)))
        else:
            return "<Row(%s)>" % ", ".join("%r" % field for field in self)


class DateConverter(object):
    def can_convert(self, obj):
        return isinstance(obj, datetime.date)

    def convert(self, obj, gateway_client):
        Date = JavaClass("java.sql.Date", gateway_client)
        return Date.valueOf(obj.strftime("%Y-%m-%d"))


class DatetimeConverter(object):
    def can_convert(self, obj):
        return isinstance(obj, datetime.datetime)

    def convert(self, obj, gateway_client):
        Timestamp = JavaClass("java.sql.Timestamp", gateway_client)
        seconds = (calendar.timegm(obj.utctimetuple()) if obj.tzinfo
                   else time.mktime(obj.timetuple()))
        t = Timestamp(int(seconds) * 1000)
        t.setNanos(obj.microsecond * 1000)
        return t


# datetime is a subclass of date, so DatetimeConverter must be registered first
register_input_converter(DatetimeConverter())
register_input_converter(DateConverter())


def _test():
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import SparkSession
    globs = globals()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['sc'] = sc
    globs['spark'] = SparkSession.builder.getOrCreate()
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()