0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.sql.hive;
0019
0020 import java.io.IOException;
0021 import java.util.ArrayList;
0022 import java.util.List;
0023
0024 import org.junit.After;
0025 import org.junit.Before;
0026 import org.junit.Test;
0027
0028 import org.apache.spark.sql.*;
0029 import org.apache.spark.sql.expressions.Window;
0030 import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
0031 import static org.apache.spark.sql.functions.*;
0032 import org.apache.spark.sql.hive.test.TestHive$;
0033 import test.org.apache.spark.sql.MyDoubleSum;
0034
0035 public class JavaDataFrameSuite {
0036 private transient SQLContext hc;
0037
0038 Dataset<Row> df;
0039
0040 private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
0041 QueryTest$.MODULE$.checkAnswer(actual, expected);
0042 }
0043
0044 @Before
0045 public void setUp() throws IOException {
0046 hc = TestHive$.MODULE$;
0047 List<String> jsonObjects = new ArrayList<>(10);
0048 for (int i = 0; i < 10; i++) {
0049 jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}");
0050 }
0051 df = hc.read().json(hc.createDataset(jsonObjects, Encoders.STRING()));
0052 df.createOrReplaceTempView("window_table");
0053 }
0054
0055 @After
0056 public void tearDown() throws IOException {
0057
0058 if (hc != null) {
0059 hc.sql("DROP TABLE IF EXISTS window_table");
0060 }
0061 }
0062
0063 @Test
0064 public void saveTableAndQueryIt() {
0065 checkAnswer(
0066 df.select(avg("key").over(
0067 Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))),
0068 hc.sql("SELECT avg(key) " +
0069 "OVER (PARTITION BY value " +
0070 " ORDER BY key " +
0071 " ROWS BETWEEN 1 preceding and 1 following) " +
0072 "FROM window_table").collectAsList());
0073 }
0074
0075 @Test
0076 public void testUDAF() {
0077 Dataset<Row> df = hc.range(0, 100).union(hc.range(0, 100)).select(col("id").as("value"));
0078 UserDefinedAggregateFunction udaf = new MyDoubleSum();
0079 UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf);
0080
0081
0082 Dataset<Row> aggregatedDF =
0083 df.groupBy()
0084 .agg(
0085 udaf.distinct(col("value")),
0086 udaf.apply(col("value")),
0087 registeredUDAF.apply(col("value")),
0088 callUDF("mydoublesum", col("value")));
0089
0090 List<Row> expectedResult = new ArrayList<>();
0091 expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0));
0092 checkAnswer(
0093 aggregatedDF,
0094 expectedResult);
0095 }
0096 }