0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.ml.source.libsvm;
0019
0020 import java.io.File;
0021 import java.io.IOException;
0022 import java.nio.charset.StandardCharsets;
0023
0024 import com.google.common.io.Files;
0025
0026 import org.junit.Assert;
0027 import org.junit.Test;
0028
0029 import org.apache.spark.SharedSparkSession;
0030 import org.apache.spark.ml.linalg.DenseVector;
0031 import org.apache.spark.ml.linalg.Vectors;
0032 import org.apache.spark.sql.Dataset;
0033 import org.apache.spark.sql.Row;
0034 import org.apache.spark.util.Utils;
0035
0036
0037
0038
0039
0040 public class JavaLibSVMRelationSuite extends SharedSparkSession {
0041
0042 private File tempDir;
0043 private String path;
0044
0045 @Override
0046 public void setUp() throws IOException {
0047 super.setUp();
0048 tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
0049 File file = new File(tempDir, "part-00000");
0050 String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
0051 Files.write(s, file, StandardCharsets.UTF_8);
0052 path = tempDir.toURI().toString();
0053 }
0054
0055 @Override
0056 public void tearDown() {
0057 super.tearDown();
0058 Utils.deleteRecursively(tempDir);
0059 }
0060
0061 @Test
0062 public void verifyLibSVMDF() {
0063 Dataset<Row> dataset = spark.read().format("libsvm").option("vectorType", "dense")
0064 .load(path);
0065 Assert.assertEquals("label", dataset.columns()[0]);
0066 Assert.assertEquals("features", dataset.columns()[1]);
0067 Row r = dataset.first();
0068 Assert.assertEquals(1.0, r.getDouble(0), 1e-15);
0069 DenseVector v = r.getAs(1);
0070 Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v);
0071 }
0072 }