0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.ml.feature;
0019
import static org.junit.Assert.assertEquals;

import java.util.Arrays;

import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
0029
0030 public class JavaNormalizerSuite extends SharedSparkSession {
0031
0032 @Test
0033 public void normalizer() {
0034
0035 JavaRDD<VectorIndexerSuite.FeatureData> points = jsc.parallelize(Arrays.asList(
0036 new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)),
0037 new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
0038 new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
0039 ));
0040 Dataset<Row> dataFrame = spark.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
0041 Normalizer normalizer = new Normalizer()
0042 .setInputCol("features")
0043 .setOutputCol("normFeatures");
0044
0045
0046 Dataset<Row> l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2));
0047 l2NormData.count();
0048
0049
0050 Dataset<Row> lInfNormData =
0051 normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY));
0052 lInfNormData.count();
0053 }
0054 }