Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package test.org.apache.spark.sql;
0019 
0020 import java.io.Serializable;
0021 import java.net.URISyntaxException;
0022 import java.net.URL;
0023 import java.util.*;
0024 import java.math.BigInteger;
0025 import java.math.BigDecimal;
0026 
0027 import scala.collection.JavaConverters;
0028 import scala.collection.Seq;
0029 
0030 import com.google.common.collect.ImmutableMap;
0031 import com.google.common.primitives.Ints;
0032 import org.junit.*;
0033 
0034 import org.apache.spark.api.java.JavaRDD;
0035 import org.apache.spark.api.java.JavaSparkContext;
0036 import org.apache.spark.sql.Dataset;
0037 import org.apache.spark.sql.Row;
0038 import org.apache.spark.sql.RowFactory;
0039 import org.apache.spark.sql.expressions.UserDefinedFunction;
0040 import org.apache.spark.sql.test.TestSparkSession;
0041 import org.apache.spark.sql.types.*;
0042 import org.apache.spark.util.sketch.BloomFilter;
0043 import org.apache.spark.util.sketch.CountMinSketch;
0044 import static org.apache.spark.sql.functions.*;
0045 import static org.apache.spark.sql.types.DataTypes.*;
0046 
0047 public class JavaDataFrameSuite {
0048   private transient TestSparkSession spark;
0049   private transient JavaSparkContext jsc;
0050 
0051   @Before
0052   public void setUp() {
0053     // Trigger static initializer of TestData
0054     spark = new TestSparkSession();
0055     jsc = new JavaSparkContext(spark.sparkContext());
0056     spark.loadTestData();
0057   }
0058 
0059   @After
0060   public void tearDown() {
0061     spark.stop();
0062     spark = null;
0063   }
0064 
0065   @Test
0066   public void testExecution() {
0067     Dataset<Row> df = spark.table("testData").filter("key = 1");
0068     Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0));
0069   }
0070 
0071   @Test
0072   public void testCollectAndTake() {
0073     Dataset<Row> df = spark.table("testData").filter("key = 1 or key = 2 or key = 3");
0074     Assert.assertEquals(3, df.select("key").collectAsList().size());
0075     Assert.assertEquals(2, df.select("key").takeAsList(2).size());
0076   }
0077 
0078   /**
0079    * See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java.
0080    */
0081   @Test
0082   public void testVarargMethods() {
0083     Dataset<Row> df = spark.table("testData");
0084 
0085     df.toDF("key1", "value1");
0086 
0087     df.select("key", "value");
0088     df.select(col("key"), col("value"));
0089     df.selectExpr("key", "value + 1");
0090 
0091     df.sort("key", "value");
0092     df.sort(col("key"), col("value"));
0093     df.orderBy("key", "value");
0094     df.orderBy(col("key"), col("value"));
0095 
0096     df.groupBy("key", "value").agg(col("key"), col("value"), sum("value"));
0097     df.groupBy(col("key"), col("value")).agg(col("key"), col("value"), sum("value"));
0098     df.agg(first("key"), sum("value"));
0099 
0100     df.groupBy().avg("key");
0101     df.groupBy().mean("key");
0102     df.groupBy().max("key");
0103     df.groupBy().min("key");
0104     df.groupBy().sum("key");
0105 
0106     // Varargs in column expressions
0107     df.groupBy().agg(countDistinct("key", "value"));
0108     df.groupBy().agg(countDistinct(col("key"), col("value")));
0109     df.select(coalesce(col("key")));
0110 
0111     // Varargs with mathfunctions
0112     Dataset<Row> df2 = spark.table("testData2");
0113     df2.select(exp("a"), exp("b"));
0114     df2.select(exp(log("a")));
0115     df2.select(pow("a", "a"), pow("b", 2.0));
0116     df2.select(pow(col("a"), col("b")), exp("b"));
0117     df2.select(sin("a"), acos("b"));
0118 
0119     df2.select(rand(), acos("b"));
0120     df2.select(col("*"), randn(5L));
0121   }
0122 
0123   @Ignore
0124   public void testShow() {
0125     // This test case is intended ignored, but to make sure it compiles correctly
0126     Dataset<Row> df = spark.table("testData");
0127     df.show();
0128     df.show(1000);
0129   }
0130 
0131   public static class Bean implements Serializable {
0132     private double a = 0.0;
0133     private Integer[] b = { 0, 1 };
0134     private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
0135     private List<String> d = Arrays.asList("floppy", "disk");
0136     private BigInteger e = new BigInteger("1234567");
0137     private NestedBean f = new NestedBean();
0138     private NestedBean g = null;
0139 
0140     public double getA() {
0141       return a;
0142     }
0143 
0144     public Integer[] getB() {
0145       return b;
0146     }
0147 
0148     public Map<String, int[]> getC() {
0149       return c;
0150     }
0151 
0152     public List<String> getD() {
0153       return d;
0154     }
0155 
0156     public BigInteger getE() { return e; }
0157 
0158     public NestedBean getF() {
0159       return f;
0160     }
0161 
0162     public NestedBean getG() {
0163       return g;
0164     }
0165 
0166     public static class NestedBean implements Serializable {
0167       private int a = 1;
0168 
0169       public int getA() {
0170         return a;
0171       }
0172     }
0173   }
0174 
0175   void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
0176     StructType schema = df.schema();
0177     Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
0178       schema.apply("a"));
0179     Assert.assertEquals(
0180       new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
0181       schema.apply("b"));
0182     ArrayType valueType = new ArrayType(DataTypes.IntegerType, false);
0183     MapType mapType = new MapType(DataTypes.StringType, valueType, true);
0184     Assert.assertEquals(
0185       new StructField("c", mapType, true, Metadata.empty()),
0186       schema.apply("c"));
0187     Assert.assertEquals(
0188       new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
0189       schema.apply("d"));
0190     Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true,
0191       Metadata.empty()), schema.apply("e"));
0192     StructType nestedBeanType =
0193       DataTypes.createStructType(Collections.singletonList(new StructField(
0194         "a", IntegerType$.MODULE$, false, Metadata.empty())));
0195     Assert.assertEquals(new StructField("f", nestedBeanType, true, Metadata.empty()),
0196       schema.apply("f"));
0197     Assert.assertEquals(new StructField("g", nestedBeanType, true, Metadata.empty()),
0198       schema.apply("g"));
0199     Row first = df.select("a", "b", "c", "d", "e", "f", "g").first();
0200     Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
0201     // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below,
0202     // verify that it has the expected length, and contains expected elements.
0203     Seq<Integer> result = first.getAs(1);
0204     Assert.assertEquals(bean.getB().length, result.length());
0205     for (int i = 0; i < result.length(); i++) {
0206       Assert.assertEquals(bean.getB()[i], result.apply(i));
0207     }
0208     @SuppressWarnings("unchecked")
0209     Seq<Integer> outputBuffer = (Seq<Integer>) first.getJavaMap(2).get("hello");
0210     Assert.assertArrayEquals(
0211       bean.getC().get("hello"),
0212       Ints.toArray(JavaConverters.seqAsJavaListConverter(outputBuffer).asJava()));
0213     Seq<String> d = first.getAs(3);
0214     Assert.assertEquals(bean.getD().size(), d.length());
0215     for (int i = 0; i < d.length(); i++) {
0216       Assert.assertEquals(bean.getD().get(i), d.apply(i));
0217     }
0218     // Java.math.BigInteger is equivalent to Spark Decimal(38,0)
0219     Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4));
0220     Row nested = first.getStruct(5);
0221     Assert.assertEquals(bean.getF().getA(), nested.getInt(0));
0222     Assert.assertTrue(first.isNullAt(6));
0223   }
0224 
0225   @Test
0226   public void testCreateDataFrameFromLocalJavaBeans() {
0227     Bean bean = new Bean();
0228     List<Bean> data = Arrays.asList(bean);
0229     Dataset<Row> df = spark.createDataFrame(data, Bean.class);
0230     validateDataFrameWithBeans(bean, df);
0231   }
0232 
0233   @Test
0234   public void testCreateDataFrameFromJavaBeans() {
0235     Bean bean = new Bean();
0236     JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
0237     Dataset<Row> df = spark.createDataFrame(rdd, Bean.class);
0238     validateDataFrameWithBeans(bean, df);
0239   }
0240 
0241   @Test
0242   public void testCreateDataFromFromList() {
0243     StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
0244     List<Row> rows = Arrays.asList(RowFactory.create(0));
0245     Dataset<Row> df = spark.createDataFrame(rows, schema);
0246     List<Row> result = df.collectAsList();
0247     Assert.assertEquals(1, result.size());
0248   }
0249 
0250   @Test
0251   public void testCreateStructTypeFromList(){
0252     List<StructField> fields1 = new ArrayList<>();
0253     fields1.add(new StructField("id", DataTypes.StringType, true, Metadata.empty()));
0254     StructType schema1 = StructType$.MODULE$.apply(fields1);
0255     Assert.assertEquals(0, schema1.fieldIndex("id"));
0256 
0257     List<StructField> fields2 =
0258         Arrays.asList(new StructField("id", DataTypes.StringType, true, Metadata.empty()));
0259     StructType schema2 = StructType$.MODULE$.apply(fields2);
0260     Assert.assertEquals(0, schema2.fieldIndex("id"));
0261   }
0262 
0263   private static final Comparator<Row> crosstabRowComparator = (row1, row2) -> {
0264     String item1 = row1.getString(0);
0265     String item2 = row2.getString(0);
0266     return item1.compareTo(item2);
0267   };
0268 
0269   @Test
0270   public void testCrosstab() {
0271     Dataset<Row> df = spark.table("testData2");
0272     Dataset<Row> crosstab = df.stat().crosstab("a", "b");
0273     String[] columnNames = crosstab.schema().fieldNames();
0274     Assert.assertEquals("a_b", columnNames[0]);
0275     Assert.assertEquals("1", columnNames[1]);
0276     Assert.assertEquals("2", columnNames[2]);
0277     List<Row> rows = crosstab.collectAsList();
0278     rows.sort(crosstabRowComparator);
0279     Integer count = 1;
0280     for (Row row : rows) {
0281       Assert.assertEquals(row.get(0).toString(), count.toString());
0282       Assert.assertEquals(1L, row.getLong(1));
0283       Assert.assertEquals(1L, row.getLong(2));
0284       count++;
0285     }
0286   }
0287 
0288   @Test
0289   public void testFrequentItems() {
0290     Dataset<Row> df = spark.table("testData2");
0291     String[] cols = {"a"};
0292     Dataset<Row> results = df.stat().freqItems(cols, 0.2);
0293     Assert.assertTrue(results.collectAsList().get(0).getSeq(0).contains(1));
0294   }
0295 
0296   @Test
0297   public void testCorrelation() {
0298     Dataset<Row> df = spark.table("testData2");
0299     Double pearsonCorr = df.stat().corr("a", "b", "pearson");
0300     Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6);
0301   }
0302 
0303   @Test
0304   public void testCovariance() {
0305     Dataset<Row> df = spark.table("testData2");
0306     Double result = df.stat().cov("a", "b");
0307     Assert.assertTrue(Math.abs(result) < 1.0e-6);
0308   }
0309 
0310   @Test
0311   public void testSampleBy() {
0312     Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
0313     Dataset<Row> sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
0314     List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
0315     Assert.assertEquals(0, actual.get(0).getLong(0));
0316     Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8);
0317     Assert.assertEquals(1, actual.get(1).getLong(0));
0318     Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13);
0319   }
0320 
0321   @Test
0322   public void testSampleByColumn() {
0323     Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
0324     Dataset<Row> sampled = df.stat().sampleBy(col("key"), ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
0325     List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
0326     Assert.assertEquals(0, actual.get(0).getLong(0));
0327     Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8);
0328     Assert.assertEquals(1, actual.get(1).getLong(0));
0329     Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13);
0330   }
0331 
0332   @Test
0333   public void pivot() {
0334     Dataset<Row> df = spark.table("courseSales");
0335     List<Row> actual = df.groupBy("year")
0336       .pivot("course", Arrays.asList("dotNET", "Java"))
0337       .agg(sum("earnings")).orderBy("year").collectAsList();
0338 
0339     Assert.assertEquals(2012, actual.get(0).getInt(0));
0340     Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01);
0341     Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01);
0342 
0343     Assert.assertEquals(2013, actual.get(1).getInt(0));
0344     Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01);
0345     Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
0346   }
0347 
0348   @Test
0349   public void pivotColumnValues() {
0350     Dataset<Row> df = spark.table("courseSales");
0351     List<Row> actual = df.groupBy("year")
0352       .pivot(col("course"), Arrays.asList(lit("dotNET"), lit("Java")))
0353       .agg(sum("earnings")).orderBy("year").collectAsList();
0354 
0355     Assert.assertEquals(2012, actual.get(0).getInt(0));
0356     Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01);
0357     Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01);
0358 
0359     Assert.assertEquals(2013, actual.get(1).getInt(0));
0360     Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01);
0361     Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
0362   }
0363 
0364   private String getResource(String resource) {
0365     try {
0366       // The following "getResource" has different behaviors in SBT and Maven.
0367       // When running in Jenkins, the file path may contain "@" when there are multiple
0368       // SparkPullRequestBuilders running in the same worker
0369       // (e.g., /home/jenkins/workspace/SparkPullRequestBuilder@2)
0370       // When running in SBT, "@" in the file path will be returned as "@", however,
0371       // when running in Maven, "@" will be encoded as "%40".
0372       // Therefore, we convert it to URI then call "getPath" to decode it back so that it can both
0373       // work both in SBT and Maven.
0374       URL url = Thread.currentThread().getContextClassLoader().getResource(resource);
0375       return url.toURI().getPath();
0376     } catch (URISyntaxException e) {
0377       throw new RuntimeException(e);
0378     }
0379   }
0380 
0381   @Test
0382   public void testGenericLoad() {
0383     Dataset<Row> df1 = spark.read().format("text").load(getResource("test-data/text-suite.txt"));
0384     Assert.assertEquals(4L, df1.count());
0385 
0386     Dataset<Row> df2 = spark.read().format("text").load(
0387       getResource("test-data/text-suite.txt"),
0388       getResource("test-data/text-suite2.txt"));
0389     Assert.assertEquals(5L, df2.count());
0390   }
0391 
0392   @Test
0393   public void testTextLoad() {
0394     Dataset<String> ds1 = spark.read().textFile(getResource("test-data/text-suite.txt"));
0395     Assert.assertEquals(4L, ds1.count());
0396 
0397     Dataset<String> ds2 = spark.read().textFile(
0398       getResource("test-data/text-suite.txt"),
0399       getResource("test-data/text-suite2.txt"));
0400     Assert.assertEquals(5L, ds2.count());
0401   }
0402 
0403   @Test
0404   public void testCountMinSketch() {
0405     Dataset<Long> df = spark.range(1000);
0406 
0407     CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
0408     Assert.assertEquals(1000, sketch1.totalCount());
0409     Assert.assertEquals(10, sketch1.depth());
0410     Assert.assertEquals(20, sketch1.width());
0411 
0412     CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42);
0413     Assert.assertEquals(1000, sketch2.totalCount());
0414     Assert.assertEquals(10, sketch2.depth());
0415     Assert.assertEquals(20, sketch2.width());
0416 
0417     CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42);
0418     Assert.assertEquals(1000, sketch3.totalCount());
0419     Assert.assertEquals(0.001, sketch3.relativeError(), 1.0e-4);
0420     Assert.assertEquals(0.99, sketch3.confidence(), 5.0e-3);
0421 
0422     CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42);
0423     Assert.assertEquals(1000, sketch4.totalCount());
0424     Assert.assertEquals(0.001, sketch4.relativeError(), 1.0e-4);
0425     Assert.assertEquals(0.99, sketch4.confidence(), 5.0e-3);
0426   }
0427 
0428   @Test
0429   public void testBloomFilter() {
0430     Dataset<Long> df = spark.range(1000);
0431 
0432     BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
0433     Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3);
0434     for (int i = 0; i < 1000; i++) {
0435       Assert.assertTrue(filter1.mightContain(i));
0436     }
0437 
0438     BloomFilter filter2 = df.stat().bloomFilter(col("id").multiply(3), 1000, 0.03);
0439     Assert.assertTrue(filter2.expectedFpp() - 0.03 < 1e-3);
0440     for (int i = 0; i < 1000; i++) {
0441       Assert.assertTrue(filter2.mightContain(i * 3));
0442     }
0443 
0444     BloomFilter filter3 = df.stat().bloomFilter("id", 1000, 64 * 5);
0445     Assert.assertEquals(64 * 5, filter3.bitSize());
0446     for (int i = 0; i < 1000; i++) {
0447       Assert.assertTrue(filter3.mightContain(i));
0448     }
0449 
0450     BloomFilter filter4 = df.stat().bloomFilter(col("id").multiply(3), 1000, 64 * 5);
0451     Assert.assertEquals(64 * 5, filter4.bitSize());
0452     for (int i = 0; i < 1000; i++) {
0453       Assert.assertTrue(filter4.mightContain(i * 3));
0454     }
0455   }
0456 
0457   public static class BeanWithoutGetter implements Serializable {
0458     private String a;
0459 
0460     public void setA(String a) {
0461       this.a = a;
0462     }
0463   }
0464 
0465   @Test
0466   public void testBeanWithoutGetter() {
0467     BeanWithoutGetter bean = new BeanWithoutGetter();
0468     List<BeanWithoutGetter> data = Arrays.asList(bean);
0469     Dataset<Row> df = spark.createDataFrame(data, BeanWithoutGetter.class);
0470     Assert.assertEquals(0, df.schema().length());
0471     Assert.assertEquals(1, df.collectAsList().size());
0472   }
0473 
0474   @SuppressWarnings("deprecation")
0475   @Test
0476   public void testJsonRDDToDataFrame() {
0477     // This is a test for the deprecated API in SPARK-15615.
0478     JavaRDD<String> rdd = jsc.parallelize(Arrays.asList("{\"a\": 2}"));
0479     Dataset<Row> df = spark.read().json(rdd);
0480     Assert.assertEquals(1L, df.count());
0481     Assert.assertEquals(2L, df.collectAsList().get(0).getLong(0));
0482   }
0483 
0484   public class CircularReference1Bean implements Serializable {
0485     private CircularReference2Bean child;
0486 
0487     public CircularReference2Bean getChild() {
0488       return child;
0489     }
0490 
0491     public void setChild(CircularReference2Bean child) {
0492       this.child = child;
0493     }
0494   }
0495 
0496   public class CircularReference2Bean implements Serializable {
0497     private CircularReference1Bean child;
0498 
0499     public CircularReference1Bean getChild() {
0500       return child;
0501     }
0502 
0503     public void setChild(CircularReference1Bean child) {
0504       this.child = child;
0505     }
0506   }
0507 
0508   // Checks a simple case for DataFrame here and put exhaustive tests for the issue
0509   // of circular references in `JavaDatasetSuite`.
0510   @Test(expected = UnsupportedOperationException.class)
0511   public void testCircularReferenceBean() {
0512     CircularReference1Bean bean = new CircularReference1Bean();
0513     spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class);
0514   }
0515 
0516   @Test
0517   public void testUDF() {
0518     UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType);
0519     Dataset<Row> df = spark.table("testData").select(foo.apply(col("key"), col("value")));
0520     String[] result = df.collectAsList().stream().map(row -> row.getString(0))
0521       .toArray(String[]::new);
0522     String[] expected = spark.table("testData").collectAsList().stream()
0523       .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new);
0524     Assert.assertArrayEquals(expected, result);
0525   }
0526 }