from pyspark import since, SparkContext
from pyspark.sql.column import Column, _to_java_column


@since("3.0.0")
def vector_to_array(col, dtype="float64"):
    """
    Converts a column of MLlib sparse/dense vectors into a column of dense arrays.

    :param col: A string of the column name or a Column
    :param dtype: The data type of the output array. Valid values: "float64" or "float32".
    :return: The converted column of dense arrays.

    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.functions import vector_to_array
    >>> from pyspark.mllib.linalg import Vectors as OldVectors
    >>> df = spark.createDataFrame([
    ...     (Vectors.dense(1.0, 2.0, 3.0), OldVectors.dense(10.0, 20.0, 30.0)),
    ...     (Vectors.sparse(3, [(0, 2.0), (2, 3.0)]),
    ...      OldVectors.sparse(3, [(0, 20.0), (2, 30.0)]))],
    ...     ["vec", "oldVec"])
    >>> df1 = df.select(vector_to_array("vec").alias("vec"),
    ...                 vector_to_array("oldVec").alias("oldVec"))
    >>> df1.collect()
    [Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]),
     Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])]
    >>> df2 = df.select(vector_to_array("vec", "float32").alias("vec"),
    ...                 vector_to_array("oldVec", "float32").alias("oldVec"))
    >>> df2.collect()
    [Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]),
     Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])]
    >>> df1.schema.fields
    [StructField(vec,ArrayType(DoubleType,false),false),
     StructField(oldVec,ArrayType(DoubleType,false),false)]
    >>> df2.schema.fields
    [StructField(vec,ArrayType(FloatType,false),false),
     StructField(oldVec,ArrayType(FloatType,false),false)]
    """
    sc = SparkContext._active_spark_context
    return Column(
        sc._jvm.org.apache.spark.ml.functions.vector_to_array(_to_java_column(col), dtype))


def _test():
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.ml.functions
    import sys
    globs = pyspark.ml.functions.__dict__.copy()
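    # Start a local SparkSession so the doctest examples can create DataFrames.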
    spark = SparkSession.builder \
        .master("local[2]") \
        .appName("ml.functions tests") \
        .getOrCreate()
    sc = spark.sparkContext
    globs['sc'] = sc
    globs['spark'] = spark

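    # Run this module's doctests; NORMALIZE_WHITESPACE lets the multi-line
    # expected outputs in the docstrings match.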
    (failure_count, test_count) = doctest.testmod(
        pyspark.ml.functions, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()