0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 from __future__ import print_function
0019
0020
0021 from pyspark.ml.linalg import Vectors
0022 from pyspark.ml.feature import (VectorSizeHint, VectorAssembler)
0023
0024 from pyspark.sql import SparkSession
0025
if __name__ == "__main__":
    # Create (or reuse) the Spark session that drives this example.
    session_builder = SparkSession.builder.appName("VectorSizeHintExample")
    spark = session_builder.getOrCreate()

    # Two sample rows: the first has a 3-element "userFeatures" vector,
    # the second only a 2-element one.
    sample_rows = [
        (0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0),
        (0, 18, 1.0, Vectors.dense([0.0, 10.0]), 0.0),
    ]
    column_names = ["id", "hour", "mobile", "userFeatures", "clicked"]
    dataset = spark.createDataFrame(sample_rows, column_names)

    # Declare that "userFeatures" should hold vectors of size 3; the
    # transformer is configured with handleInvalid="skip" for rows that
    # do not match.
    size_hint = VectorSizeHint(inputCol="userFeatures", handleInvalid="skip", size=3)
    sized_dataset = size_hint.transform(dataset)
    print("Rows where 'userFeatures' is not the right size are filtered out")
    sized_dataset.show(truncate=False)

    # Combine the scalar columns and the vector column into a single
    # "features" vector column.
    feature_columns = ["hour", "mobile", "userFeatures"]
    assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
    assembled = assembler.transform(sized_dataset)
    print("Assembled columns 'hour', 'mobile', 'userFeatures' to vector column 'features'")
    assembled.select("features", "clicked").show(truncate=False)

    # Shut the session down once the example has finished.
    spark.stop()