0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.util;
0019
0020 import java.util.Arrays;
0021 import java.util.Collections;
0022
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.mllib.linalg.*;
0028 import org.apache.spark.mllib.regression.LabeledPoint;
0029 import org.apache.spark.sql.Dataset;
0030 import org.apache.spark.sql.Row;
0031 import org.apache.spark.sql.RowFactory;
0032 import org.apache.spark.sql.types.DataTypes;
0033 import org.apache.spark.sql.types.Metadata;
0034 import org.apache.spark.sql.types.StructField;
0035 import org.apache.spark.sql.types.StructType;
0036
0037 public class JavaMLUtilsSuite extends SharedSparkSession {
0038
0039 @Test
0040 public void testConvertVectorColumnsToAndFromML() {
0041 Vector x = Vectors.dense(2.0);
0042 Dataset<Row> dataset = spark.createDataFrame(
0043 Collections.singletonList(new LabeledPoint(1.0, x)), LabeledPoint.class
0044 ).select("label", "features");
0045 Dataset<Row> newDataset1 = MLUtils.convertVectorColumnsToML(dataset);
0046 Row new1 = newDataset1.first();
0047 Assert.assertEquals(RowFactory.create(1.0, x.asML()), new1);
0048 Row new2 = MLUtils.convertVectorColumnsToML(dataset, "features").first();
0049 Assert.assertEquals(new1, new2);
0050 Row old1 = MLUtils.convertVectorColumnsFromML(newDataset1).first();
0051 Assert.assertEquals(RowFactory.create(1.0, x), old1);
0052 }
0053
0054 @Test
0055 public void testConvertMatrixColumnsToAndFromML() {
0056 Matrix x = Matrices.dense(2, 1, new double[]{1.0, 2.0});
0057 StructType schema = new StructType(new StructField[]{
0058 new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
0059 new StructField("features", new MatrixUDT(), false, Metadata.empty())
0060 });
0061 Dataset<Row> dataset = spark.createDataFrame(
0062 Arrays.asList(
0063 RowFactory.create(1.0, x)),
0064 schema);
0065
0066 Dataset<Row> newDataset1 = MLUtils.convertMatrixColumnsToML(dataset);
0067 Row new1 = newDataset1.first();
0068 Assert.assertEquals(RowFactory.create(1.0, x.asML()), new1);
0069 Row new2 = MLUtils.convertMatrixColumnsToML(dataset, "features").first();
0070 Assert.assertEquals(new1, new2);
0071 Row old1 = MLUtils.convertMatrixColumnsFromML(newDataset1).first();
0072 Assert.assertEquals(RowFactory.create(1.0, x), old1);
0073 }
0074 }