0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package test.org.apache.spark.sql;
0019
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
0021
0022 import org.junit.Assert;
0023 import org.junit.Test;
0024 import scala.Tuple2;
0025
0026 import org.apache.spark.sql.Dataset;
0027 import org.apache.spark.sql.KeyValueGroupedDataset;
0028
0029
0030
0031
0032 public class Java8DatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
0033 @SuppressWarnings("deprecation")
0034 @Test
0035 public void testTypedAggregationAverage() {
0036 KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
0037 Dataset<Tuple2<String, Double>> agged = grouped.agg(
0038 org.apache.spark.sql.expressions.javalang.typed.avg(v -> (double)(v._2() * 2)));
0039 Assert.assertEquals(
0040 Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 6.0)),
0041 agged.collectAsList());
0042 }
0043
0044 @SuppressWarnings("deprecation")
0045 @Test
0046 public void testTypedAggregationCount() {
0047 KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
0048 Dataset<Tuple2<String, Long>> agged = grouped.agg(
0049 org.apache.spark.sql.expressions.javalang.typed.count(v -> v));
0050 Assert.assertEquals(
0051 Arrays.asList(new Tuple2<>("a", 2L), new Tuple2<>("b", 1L)),
0052 agged.collectAsList());
0053 }
0054
0055 @SuppressWarnings("deprecation")
0056 @Test
0057 public void testTypedAggregationSumDouble() {
0058 KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
0059 Dataset<Tuple2<String, Double>> agged = grouped.agg(
0060 org.apache.spark.sql.expressions.javalang.typed.sum(v -> (double)v._2()));
0061 Assert.assertEquals(
0062 Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 3.0)),
0063 agged.collectAsList());
0064 }
0065
0066 @SuppressWarnings("deprecation")
0067 @Test
0068 public void testTypedAggregationSumLong() {
0069 KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
0070 Dataset<Tuple2<String, Long>> agged = grouped.agg(
0071 org.apache.spark.sql.expressions.javalang.typed.sumLong(v -> (long)v._2()));
0072 Assert.assertEquals(
0073 Arrays.asList(new Tuple2<>("a", 3L), new Tuple2<>("b", 3L)),
0074 agged.collectAsList());
0075 }
0076 }