/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package test.org.apache.spark.sql;

import java.io.Serializable;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.*;
import java.math.BigInteger;
import java.math.BigDecimal;

import scala.collection.JavaConverters;
import scala.collection.Seq;

import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
import org.junit.*;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.*;
import org.apache.spark.util.sketch.BloomFilter;
import org.apache.spark.util.sketch.CountMinSketch;
import static org.apache.spark.sql.functions.*;
import static org.apache.spark.sql.types.DataTypes.*;

public class JavaDataFrameSuite {
  private transient TestSparkSession spark;
  private transient JavaSparkContext jsc;

  @Before
  public void setUp() {
    spark = new TestSparkSession();
    jsc = new JavaSparkContext(spark.sparkContext());
    spark.loadTestData();
  }

  @After
  public void tearDown() {
    spark.stop();
    spark = null;
  }

  @Test
  public void testExecution() {
    Dataset<Row> df = spark.table("testData").filter("key = 1");
    Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0));
  }

  @Test
  public void testCollectAndTake() {
    Dataset<Row> df = spark.table("testData").filter("key = 1 or key = 2 or key = 3");
    Assert.assertEquals(3, df.select("key").collectAsList().size());
    Assert.assertEquals(2, df.select("key").takeAsList(2).size());
  }

  /**
   * Check that the Java-friendly vararg overloads compile and run end-to-end.
   */
  @Test
  public void testVarargMethods() {
    Dataset<Row> df = spark.table("testData");

    df.toDF("key1", "value1");

    df.select("key", "value");
    df.select(col("key"), col("value"));
    df.selectExpr("key", "value + 1");

    df.sort("key", "value");
    df.sort(col("key"), col("value"));
    df.orderBy("key", "value");
    df.orderBy(col("key"), col("value"));

    df.groupBy("key", "value").agg(col("key"), col("value"), sum("value"));
    df.groupBy(col("key"), col("value")).agg(col("key"), col("value"), sum("value"));
    df.agg(first("key"), sum("value"));

    df.groupBy().avg("key");
    df.groupBy().mean("key");
    df.groupBy().max("key");
    df.groupBy().min("key");
    df.groupBy().sum("key");

    // Varargs in column expressions
    df.groupBy().agg(countDistinct("key", "value"));
    df.groupBy().agg(countDistinct(col("key"), col("value")));
    df.select(coalesce(col("key")));

    // Varargs with math functions
    Dataset<Row> df2 = spark.table("testData2");
    df2.select(exp("a"), exp("b"));
    df2.select(exp(log("a")));
    df2.select(pow("a", "a"), pow("b", 2.0));
    df2.select(pow(col("a"), col("b")), exp("b"));
    df2.select(sin("a"), acos("b"));

    df2.select(rand(), acos("b"));
    df2.select(col("*"), randn(5L));
  }

  @Ignore
  public void testShow() {
    // Intentionally ignored: show() writes to stdout, so this only checks that it compiles.
    Dataset<Row> df = spark.table("testData");
    df.show();
    df.show(1000);
  }

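  // The bean below exercises the supported property mappings: a primitive double, a boxed
  // Integer array, a Map, a List, a BigInteger, a nested bean, and a null nested bean.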
  public static class Bean implements Serializable {
    private double a = 0.0;
    private Integer[] b = { 0, 1 };
    private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
    private List<String> d = Arrays.asList("floppy", "disk");
    private BigInteger e = new BigInteger("1234567");
    private NestedBean f = new NestedBean();
    private NestedBean g = null;

    public double getA() {
      return a;
    }

    public Integer[] getB() {
      return b;
    }

    public Map<String, int[]> getC() {
      return c;
    }

    public List<String> getD() {
      return d;
    }

    public BigInteger getE() { return e; }

    public NestedBean getF() {
      return f;
    }

    public NestedBean getG() {
      return g;
    }

    public static class NestedBean implements Serializable {
      private int a = 1;

      public int getA() {
        return a;
      }
    }
  }

  private void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
    StructType schema = df.schema();
    Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
      schema.apply("a"));
    Assert.assertEquals(
      new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
      schema.apply("b"));
    ArrayType valueType = new ArrayType(DataTypes.IntegerType, false);
    MapType mapType = new MapType(DataTypes.StringType, valueType, true);
    Assert.assertEquals(
      new StructField("c", mapType, true, Metadata.empty()),
      schema.apply("c"));
    Assert.assertEquals(
      new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
      schema.apply("d"));
    Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38, 0), true,
      Metadata.empty()), schema.apply("e"));
    StructType nestedBeanType =
      DataTypes.createStructType(Collections.singletonList(new StructField(
        "a", IntegerType$.MODULE$, false, Metadata.empty())));
    Assert.assertEquals(new StructField("f", nestedBeanType, true, Metadata.empty()),
      schema.apply("f"));
    Assert.assertEquals(new StructField("g", nestedBeanType, true, Metadata.empty()),
      schema.apply("g"));
    Row first = df.select("a", "b", "c", "d", "e", "f", "g").first();
    Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
    // Java arrays, lists and maps come back as Scala Seqs and Maps; verify the converted
    // values element by element.
    Seq<Integer> result = first.getAs(1);
    Assert.assertEquals(bean.getB().length, result.length());
    for (int i = 0; i < result.length(); i++) {
      Assert.assertEquals(bean.getB()[i], result.apply(i));
    }
    @SuppressWarnings("unchecked")
    Seq<Integer> outputBuffer = (Seq<Integer>) first.getJavaMap(2).get("hello");
    Assert.assertArrayEquals(
      bean.getC().get("hello"),
      Ints.toArray(JavaConverters.seqAsJavaListConverter(outputBuffer).asJava()));
    Seq<String> d = first.getAs(3);
    Assert.assertEquals(bean.getD().size(), d.length());
    for (int i = 0; i < d.length(); i++) {
      Assert.assertEquals(bean.getD().get(i), d.apply(i));
    }
    // java.math.BigInteger is mapped to Spark's Decimal(38, 0)
    Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4));
    Row nested = first.getStruct(5);
    Assert.assertEquals(bean.getF().getA(), nested.getInt(0));
    Assert.assertTrue(first.isNullAt(6));
  }

  @Test
  public void testCreateDataFrameFromLocalJavaBeans() {
    Bean bean = new Bean();
    List<Bean> data = Arrays.asList(bean);
    Dataset<Row> df = spark.createDataFrame(data, Bean.class);
    validateDataFrameWithBeans(bean, df);
  }

  @Test
  public void testCreateDataFrameFromJavaBeans() {
    Bean bean = new Bean();
    JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
    Dataset<Row> df = spark.createDataFrame(rdd, Bean.class);
    validateDataFrameWithBeans(bean, df);
  }

  @Test
  public void testCreateDataFrameFromList() {
    StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
    List<Row> rows = Arrays.asList(RowFactory.create(0));
    Dataset<Row> df = spark.createDataFrame(rows, schema);
    List<Row> result = df.collectAsList();
    Assert.assertEquals(1, result.size());
  }

  @Test
  public void testCreateStructTypeFromList() {
    List<StructField> fields1 = new ArrayList<>();
    fields1.add(new StructField("id", DataTypes.StringType, true, Metadata.empty()));
    StructType schema1 = StructType$.MODULE$.apply(fields1);
    Assert.assertEquals(0, schema1.fieldIndex("id"));

    List<StructField> fields2 =
      Arrays.asList(new StructField("id", DataTypes.StringType, true, Metadata.empty()));
    StructType schema2 = StructType$.MODULE$.apply(fields2);
    Assert.assertEquals(0, schema2.fieldIndex("id"));
  }

  private static final Comparator<Row> crosstabRowComparator = (row1, row2) -> {
    String item1 = row1.getString(0);
    String item2 = row2.getString(0);
    return item1.compareTo(item2);
  };

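  // crosstab() computes a pairwise frequency table: the first column is named "<col1>_<col2>"
  // and the remaining columns are the distinct values of the second column.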
  @Test
  public void testCrosstab() {
    Dataset<Row> df = spark.table("testData2");
    Dataset<Row> crosstab = df.stat().crosstab("a", "b");
    String[] columnNames = crosstab.schema().fieldNames();
    Assert.assertEquals("a_b", columnNames[0]);
    Assert.assertEquals("1", columnNames[1]);
    Assert.assertEquals("2", columnNames[2]);
    List<Row> rows = crosstab.collectAsList();
    rows.sort(crosstabRowComparator);
    Integer count = 1;
    for (Row row : rows) {
      Assert.assertEquals(row.get(0).toString(), count.toString());
      Assert.assertEquals(1L, row.getLong(1));
      Assert.assertEquals(1L, row.getLong(2));
      count++;
    }
  }

  @Test
  public void testFrequentItems() {
    Dataset<Row> df = spark.table("testData2");
    String[] cols = {"a"};
    Dataset<Row> results = df.stat().freqItems(cols, 0.2);
    Assert.assertTrue(results.collectAsList().get(0).getSeq(0).contains(1));
  }

  @Test
  public void testCorrelation() {
    Dataset<Row> df = spark.table("testData2");
    Double pearsonCorr = df.stat().corr("a", "b", "pearson");
    Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6);
  }

  @Test
  public void testCovariance() {
    Dataset<Row> df = spark.table("testData2");
    Double result = df.stat().cov("a", "b");
    Assert.assertTrue(Math.abs(result) < 1.0e-6);
  }

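  // sampleBy() does stratified sampling: each key is kept with its own fraction. The resulting
  // counts are random, so the assertions below only check loose bounds around the expectation.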
  @Test
  public void testSampleBy() {
    Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
    Dataset<Row> sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
    List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
    Assert.assertEquals(0, actual.get(0).getLong(0));
    Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8);
    Assert.assertEquals(1, actual.get(1).getLong(0));
    Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13);
  }

  @Test
  public void testSampleByColumn() {
    Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
    Dataset<Row> sampled = df.stat().sampleBy(col("key"), ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
    List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
    Assert.assertEquals(0, actual.get(0).getLong(0));
    Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8);
    Assert.assertEquals(1, actual.get(1).getLong(0));
    Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13);
  }

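  // Passing an explicit value list to pivot() skips the extra pass Spark would otherwise run
  // to collect the distinct pivot values, and fixes the output column order.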
  @Test
  public void pivot() {
    Dataset<Row> df = spark.table("courseSales");
    List<Row> actual = df.groupBy("year")
      .pivot("course", Arrays.asList("dotNET", "Java"))
      .agg(sum("earnings")).orderBy("year").collectAsList();

    Assert.assertEquals(2012, actual.get(0).getInt(0));
    Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01);
    Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01);

    Assert.assertEquals(2013, actual.get(1).getInt(0));
    Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01);
    Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
  }

  @Test
  public void pivotColumnValues() {
    Dataset<Row> df = spark.table("courseSales");
    List<Row> actual = df.groupBy("year")
      .pivot(col("course"), Arrays.asList(lit("dotNET"), lit("Java")))
      .agg(sum("earnings")).orderBy("year").collectAsList();

    Assert.assertEquals(2012, actual.get(0).getInt(0));
    Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01);
    Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01);

    Assert.assertEquals(2013, actual.get(1).getInt(0));
    Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01);
    Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
  }

  private String getResource(String resource) {
    try {
      // Depending on how the suite is launched, getResource() may return a URL whose path is
      // percent-encoded (e.g. spaces become "%20"). Converting to a URI and calling getPath()
      // decodes it back into a plain filesystem path.
      URL url = Thread.currentThread().getContextClassLoader().getResource(resource);
      return url.toURI().getPath();
    } catch (URISyntaxException e) {
      throw new RuntimeException(e);
    }
  }

  @Test
  public void testGenericLoad() {
    Dataset<Row> df1 = spark.read().format("text").load(getResource("test-data/text-suite.txt"));
    Assert.assertEquals(4L, df1.count());

    Dataset<Row> df2 = spark.read().format("text").load(
      getResource("test-data/text-suite.txt"),
      getResource("test-data/text-suite2.txt"));
    Assert.assertEquals(5L, df2.count());
  }

  @Test
  public void testTextLoad() {
    Dataset<String> ds1 = spark.read().textFile(getResource("test-data/text-suite.txt"));
    Assert.assertEquals(4L, ds1.count());

    Dataset<String> ds2 = spark.read().textFile(
      getResource("test-data/text-suite.txt"),
      getResource("test-data/text-suite2.txt"));
    Assert.assertEquals(5L, ds2.count());
  }

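  // countMinSketch() accepts either explicit (depth, width) dimensions or a target
  // (relativeError, confidence) pair; both overloads take a seed for determinism.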
  @Test
  public void testCountMinSketch() {
    Dataset<Long> df = spark.range(1000);

    CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
    Assert.assertEquals(1000, sketch1.totalCount());
    Assert.assertEquals(10, sketch1.depth());
    Assert.assertEquals(20, sketch1.width());

    CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42);
    Assert.assertEquals(1000, sketch2.totalCount());
    Assert.assertEquals(10, sketch2.depth());
    Assert.assertEquals(20, sketch2.width());

    CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42);
    Assert.assertEquals(1000, sketch3.totalCount());
    Assert.assertEquals(0.001, sketch3.relativeError(), 1.0e-4);
    Assert.assertEquals(0.99, sketch3.confidence(), 5.0e-3);

    CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42);
    Assert.assertEquals(1000, sketch4.totalCount());
    Assert.assertEquals(0.001, sketch4.relativeError(), 1.0e-4);
    Assert.assertEquals(0.99, sketch4.confidence(), 5.0e-3);
  }

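  // Bloom filters may report false positives but never false negatives, so every inserted
  // value must satisfy mightContain(). Sizing is by (expectedNumItems, fpp) or explicit bits.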
  @Test
  public void testBloomFilter() {
    Dataset<Long> df = spark.range(1000);

    BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
    Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3);
    for (int i = 0; i < 1000; i++) {
      Assert.assertTrue(filter1.mightContain(i));
    }

    BloomFilter filter2 = df.stat().bloomFilter(col("id").multiply(3), 1000, 0.03);
    Assert.assertTrue(filter2.expectedFpp() - 0.03 < 1e-3);
    for (int i = 0; i < 1000; i++) {
      Assert.assertTrue(filter2.mightContain(i * 3));
    }

    BloomFilter filter3 = df.stat().bloomFilter("id", 1000, 64 * 5);
    Assert.assertEquals(64 * 5, filter3.bitSize());
    for (int i = 0; i < 1000; i++) {
      Assert.assertTrue(filter3.mightContain(i));
    }

    BloomFilter filter4 = df.stat().bloomFilter(col("id").multiply(3), 1000, 64 * 5);
    Assert.assertEquals(64 * 5, filter4.bitSize());
    for (int i = 0; i < 1000; i++) {
      Assert.assertTrue(filter4.mightContain(i * 3));
    }
  }

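  // Bean properties are discovered through getters, so a bean exposing only a setter maps to
  // an empty schema while still producing one row per element.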
  public static class BeanWithoutGetter implements Serializable {
    private String a;

    public void setA(String a) {
      this.a = a;
    }
  }

  @Test
  public void testBeanWithoutGetter() {
    BeanWithoutGetter bean = new BeanWithoutGetter();
    List<BeanWithoutGetter> data = Arrays.asList(bean);
    Dataset<Row> df = spark.createDataFrame(data, BeanWithoutGetter.class);
    Assert.assertEquals(0, df.schema().length());
    Assert.assertEquals(1, df.collectAsList().size());
  }

  @SuppressWarnings("deprecation")
  @Test
  public void testJsonRDDToDataFrame() {
    // Covers the deprecated json(JavaRDD<String>) API.
    JavaRDD<String> rdd = jsc.parallelize(Arrays.asList("{\"a\": 2}"));
    Dataset<Row> df = spark.read().json(rdd);
    Assert.assertEquals(1L, df.count());
    Assert.assertEquals(2L, df.collectAsList().get(0).getLong(0));
  }

  public class CircularReference1Bean implements Serializable {
    private CircularReference2Bean child;

    public CircularReference2Bean getChild() {
      return child;
    }

    public void setChild(CircularReference2Bean child) {
      this.child = child;
    }
  }

  public class CircularReference2Bean implements Serializable {
    private CircularReference1Bean child;

    public CircularReference1Bean getChild() {
      return child;
    }

    public void setChild(CircularReference1Bean child) {
      this.child = child;
    }
  }

  // Beans whose properties reference each other cyclically cannot be mapped to a schema, so
  // DataFrame creation is expected to fail fast with UnsupportedOperationException.
  @Test(expected = UnsupportedOperationException.class)
  public void testCircularReferenceBean() {
    CircularReference1Bean bean = new CircularReference1Bean();
    spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class);
  }

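  // functions.udf() builds an unregistered UDF that can be applied directly to Columns,
  // without going through spark.udf().register().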
  @Test
  public void testUDF() {
    UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType);
    Dataset<Row> df = spark.table("testData").select(foo.apply(col("key"), col("value")));
    String[] result = df.collectAsList().stream().map(row -> row.getString(0))
      .toArray(String[]::new);
    String[] expected = spark.table("testData").collectAsList().stream()
      .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new);
    Assert.assertArrayEquals(expected, result);
  }
}