0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package test.org.apache.spark.sql;
0019
0020 import java.util.Arrays;
0021
0022 import scala.Tuple2;
0023
0024 import org.junit.Assert;
0025 import org.junit.Test;
0026
0027 import org.apache.spark.sql.Dataset;
0028 import org.apache.spark.sql.Encoder;
0029 import org.apache.spark.sql.Encoders;
0030 import org.apache.spark.sql.KeyValueGroupedDataset;
0031 import org.apache.spark.sql.expressions.Aggregator;
0032
0033
0034
0035
0036 public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
0037 @Test
0038 public void testTypedAggregationAnonClass() {
0039 KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
0040
0041 Dataset<Tuple2<String, Integer>> agged = grouped.agg(new IntSumOf().toColumn());
0042 Assert.assertEquals(
0043 Arrays.asList(new Tuple2<>("a", 3), new Tuple2<>("b", 3)),
0044 agged.collectAsList());
0045
0046 Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(new IntSumOf().toColumn())
0047 .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
0048 Assert.assertEquals(
0049 Arrays.asList(
0050 new Tuple2<>("a", 3),
0051 new Tuple2<>("b", 3)),
0052 agged2.collectAsList());
0053 }
0054
0055 static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
0056 @Override
0057 public Integer zero() {
0058 return 0;
0059 }
0060
0061 @Override
0062 public Integer reduce(Integer l, Tuple2<String, Integer> t) {
0063 return l + t._2();
0064 }
0065
0066 @Override
0067 public Integer merge(Integer b1, Integer b2) {
0068 return b1 + b2;
0069 }
0070
0071 @Override
0072 public Integer finish(Integer reduction) {
0073 return reduction;
0074 }
0075
0076 @Override
0077 public Encoder<Integer> bufferEncoder() {
0078 return Encoders.INT();
0079 }
0080
0081 @Override
0082 public Encoder<Integer> outputEncoder() {
0083 return Encoders.INT();
0084 }
0085 }
0086
0087 @SuppressWarnings("deprecation")
0088 @Test
0089 public void testTypedAggregationAverage() {
0090 KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
0091 Dataset<Tuple2<String, Double>> agged = grouped.agg(
0092 org.apache.spark.sql.expressions.javalang.typed.avg(value -> value._2() * 2.0));
0093 Assert.assertEquals(
0094 Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 6.0)),
0095 agged.collectAsList());
0096 }
0097
0098 @SuppressWarnings("deprecation")
0099 @Test
0100 public void testTypedAggregationCount() {
0101 KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
0102 Dataset<Tuple2<String, Long>> agged = grouped.agg(
0103 org.apache.spark.sql.expressions.javalang.typed.count(value -> value));
0104 Assert.assertEquals(
0105 Arrays.asList(new Tuple2<>("a", 2L), new Tuple2<>("b", 1L)),
0106 agged.collectAsList());
0107 }
0108
0109 @SuppressWarnings("deprecation")
0110 @Test
0111 public void testTypedAggregationSumDouble() {
0112 KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
0113 Dataset<Tuple2<String, Double>> agged = grouped.agg(
0114 org.apache.spark.sql.expressions.javalang.typed.sum(value -> (double) value._2()));
0115 Assert.assertEquals(
0116 Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 3.0)),
0117 agged.collectAsList());
0118 }
0119
0120 @SuppressWarnings("deprecation")
0121 @Test
0122 public void testTypedAggregationSumLong() {
0123 KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
0124 Dataset<Tuple2<String, Long>> agged = grouped.agg(
0125 org.apache.spark.sql.expressions.javalang.typed.sumLong(value -> (long) value._2()));
0126 Assert.assertEquals(
0127 Arrays.asList(new Tuple2<>("a", 3L), new Tuple2<>("b", 3L)),
0128 agged.collectAsList());
0129 }
0130 }