Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package test.org.apache.spark.sql;
0019 
0020 import java.util.Arrays;
0021 
0022 import scala.Tuple2;
0023 
0024 import org.junit.Assert;
0025 import org.junit.Test;
0026 
0027 import org.apache.spark.sql.Dataset;
0028 import org.apache.spark.sql.Encoder;
0029 import org.apache.spark.sql.Encoders;
0030 import org.apache.spark.sql.KeyValueGroupedDataset;
0031 import org.apache.spark.sql.expressions.Aggregator;
0032 
0033 /**
0034  * Suite for testing the aggregate functionality of Datasets in Java.
0035  */
0036 public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
0037   @Test
0038   public void testTypedAggregationAnonClass() {
0039     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
0040 
0041     Dataset<Tuple2<String, Integer>> agged = grouped.agg(new IntSumOf().toColumn());
0042     Assert.assertEquals(
0043         Arrays.asList(new Tuple2<>("a", 3), new Tuple2<>("b", 3)),
0044         agged.collectAsList());
0045 
0046     Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(new IntSumOf().toColumn())
0047       .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
0048     Assert.assertEquals(
0049       Arrays.asList(
0050         new Tuple2<>("a", 3),
0051         new Tuple2<>("b", 3)),
0052       agged2.collectAsList());
0053   }
0054 
0055   static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
0056     @Override
0057     public Integer zero() {
0058       return 0;
0059     }
0060 
0061     @Override
0062     public Integer reduce(Integer l, Tuple2<String, Integer> t) {
0063       return l + t._2();
0064     }
0065 
0066     @Override
0067     public Integer merge(Integer b1, Integer b2) {
0068       return b1 + b2;
0069     }
0070 
0071     @Override
0072     public Integer finish(Integer reduction) {
0073       return reduction;
0074     }
0075 
0076     @Override
0077     public Encoder<Integer> bufferEncoder() {
0078       return Encoders.INT();
0079     }
0080 
0081     @Override
0082     public Encoder<Integer> outputEncoder() {
0083       return Encoders.INT();
0084     }
0085   }
0086 
0087   @SuppressWarnings("deprecation")
0088   @Test
0089   public void testTypedAggregationAverage() {
0090     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
0091     Dataset<Tuple2<String, Double>> agged = grouped.agg(
0092       org.apache.spark.sql.expressions.javalang.typed.avg(value -> value._2() * 2.0));
0093     Assert.assertEquals(
0094         Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 6.0)),
0095         agged.collectAsList());
0096   }
0097 
0098   @SuppressWarnings("deprecation")
0099   @Test
0100   public void testTypedAggregationCount() {
0101     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
0102     Dataset<Tuple2<String, Long>> agged = grouped.agg(
0103       org.apache.spark.sql.expressions.javalang.typed.count(value -> value));
0104     Assert.assertEquals(
0105         Arrays.asList(new Tuple2<>("a", 2L), new Tuple2<>("b", 1L)),
0106         agged.collectAsList());
0107   }
0108 
0109   @SuppressWarnings("deprecation")
0110   @Test
0111   public void testTypedAggregationSumDouble() {
0112     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
0113     Dataset<Tuple2<String, Double>> agged = grouped.agg(
0114       org.apache.spark.sql.expressions.javalang.typed.sum(value -> (double) value._2()));
0115     Assert.assertEquals(
0116         Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 3.0)),
0117         agged.collectAsList());
0118   }
0119 
0120   @SuppressWarnings("deprecation")
0121   @Test
0122   public void testTypedAggregationSumLong() {
0123     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
0124     Dataset<Tuple2<String, Long>> agged = grouped.agg(
0125       org.apache.spark.sql.expressions.javalang.typed.sumLong(value -> (long) value._2()));
0126     Assert.assertEquals(
0127         Arrays.asList(new Tuple2<>("a", 3L), new Tuple2<>("b", 3L)),
0128         agged.collectAsList());
0129   }
0130 }