0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.ml.feature;
0019
0020 import java.util.Arrays;
0021 import java.util.List;
0022
0023 import org.junit.Test;
0024
0025 import org.apache.spark.SharedSparkSession;
0026 import org.apache.spark.ml.linalg.Vectors;
0027 import org.apache.spark.sql.Dataset;
0028 import org.apache.spark.sql.Row;
0029
0030 public class JavaStandardScalerSuite extends SharedSparkSession {
0031
0032 @Test
0033 public void standardScaler() {
0034
0035 List<VectorIndexerSuite.FeatureData> points = Arrays.asList(
0036 new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)),
0037 new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
0038 new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
0039 );
0040 Dataset<Row> dataFrame = spark.createDataFrame(jsc.parallelize(points, 2),
0041 VectorIndexerSuite.FeatureData.class);
0042 StandardScaler scaler = new StandardScaler()
0043 .setInputCol("features")
0044 .setOutputCol("scaledFeatures")
0045 .setWithStd(true)
0046 .setWithMean(false);
0047
0048
0049 StandardScalerModel scalerModel = scaler.fit(dataFrame);
0050
0051
0052 Dataset<Row> scaledData = scalerModel.transform(dataFrame);
0053 scaledData.count();
0054 }
0055 }