Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.sql.hive;
0019 
0020 import java.io.IOException;
0021 import java.util.ArrayList;
0022 import java.util.List;
0023 
0024 import org.junit.After;
0025 import org.junit.Before;
0026 import org.junit.Test;
0027 
0028 import org.apache.spark.sql.*;
0029 import org.apache.spark.sql.expressions.Window;
0030 import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
0031 import static org.apache.spark.sql.functions.*;
0032 import org.apache.spark.sql.hive.test.TestHive$;
0033 import test.org.apache.spark.sql.MyDoubleSum;
0034 
0035 public class JavaDataFrameSuite {
0036   private transient SQLContext hc;
0037 
0038   Dataset<Row> df;
0039 
0040   private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
0041     QueryTest$.MODULE$.checkAnswer(actual, expected);
0042   }
0043 
0044   @Before
0045   public void setUp() throws IOException {
0046     hc = TestHive$.MODULE$;
0047     List<String> jsonObjects = new ArrayList<>(10);
0048     for (int i = 0; i < 10; i++) {
0049       jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}");
0050     }
0051     df = hc.read().json(hc.createDataset(jsonObjects, Encoders.STRING()));
0052     df.createOrReplaceTempView("window_table");
0053   }
0054 
0055   @After
0056   public void tearDown() throws IOException {
0057     // Clean up tables.
0058     if (hc != null) {
0059       hc.sql("DROP TABLE IF EXISTS window_table");
0060     }
0061   }
0062 
0063   @Test
0064   public void saveTableAndQueryIt() {
0065     checkAnswer(
0066       df.select(avg("key").over(
0067         Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))),
0068       hc.sql("SELECT avg(key) " +
0069         "OVER (PARTITION BY value " +
0070         "      ORDER BY key " +
0071         "      ROWS BETWEEN 1 preceding and 1 following) " +
0072         "FROM window_table").collectAsList());
0073   }
0074 
0075   @Test
0076   public void testUDAF() {
0077     Dataset<Row> df = hc.range(0, 100).union(hc.range(0, 100)).select(col("id").as("value"));
0078     UserDefinedAggregateFunction udaf = new MyDoubleSum();
0079     UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf);
0080     // Create Columns for the UDAF. For now, callUDF does not take an argument to specific if
0081     // we want to use distinct aggregation.
0082     Dataset<Row> aggregatedDF =
0083       df.groupBy()
0084         .agg(
0085           udaf.distinct(col("value")),
0086           udaf.apply(col("value")),
0087           registeredUDAF.apply(col("value")),
0088           callUDF("mydoublesum", col("value")));
0089 
0090     List<Row> expectedResult = new ArrayList<>();
0091     expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0));
0092     checkAnswer(
0093       aggregatedDF,
0094       expectedResult);
0095   }
0096 }