0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.ml.feature;
0019
0020 import java.util.Arrays;
0021 import java.util.List;
0022
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.ml.linalg.Vector;
0028 import org.apache.spark.ml.linalg.VectorUDT;
0029 import org.apache.spark.ml.linalg.Vectors;
0030 import org.apache.spark.sql.Dataset;
0031 import org.apache.spark.sql.Row;
0032 import org.apache.spark.sql.RowFactory;
0033 import org.apache.spark.sql.types.Metadata;
0034 import org.apache.spark.sql.types.StructField;
0035 import org.apache.spark.sql.types.StructType;
0036
0037 public class JavaPolynomialExpansionSuite extends SharedSparkSession {
0038
0039 @Test
0040 public void polynomialExpansionTest() {
0041 PolynomialExpansion polyExpansion = new PolynomialExpansion()
0042 .setInputCol("features")
0043 .setOutputCol("polyFeatures")
0044 .setDegree(3);
0045
0046 List<Row> data = Arrays.asList(
0047 RowFactory.create(
0048 Vectors.dense(-2.0, 2.3),
0049 Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)
0050 ),
0051 RowFactory.create(Vectors.dense(0.0, 0.0), Vectors.dense(new double[9])),
0052 RowFactory.create(
0053 Vectors.dense(0.6, -1.1),
0054 Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331)
0055 )
0056 );
0057
0058 StructType schema = new StructType(new StructField[]{
0059 new StructField("features", new VectorUDT(), false, Metadata.empty()),
0060 new StructField("expected", new VectorUDT(), false, Metadata.empty())
0061 });
0062
0063 Dataset<Row> dataset = spark.createDataFrame(data, schema);
0064
0065 List<Row> pairs = polyExpansion.transform(dataset)
0066 .select("polyFeatures", "expected")
0067 .collectAsList();
0068
0069 for (Row r : pairs) {
0070 double[] polyFeatures = ((Vector) r.get(0)).toArray();
0071 double[] expected = ((Vector) r.get(1)).toArray();
0072 Assert.assertArrayEquals(polyFeatures, expected, 1e-1);
0073 }
0074 }
0075 }