0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package test.org.apache.spark.sql;
0019
0020 import org.apache.spark.sql.Row;
0021 import org.apache.spark.sql.SparkSession;
0022 import org.junit.After;
0023 import org.junit.Assert;
0024 import org.junit.Before;
0025 import org.junit.Test;
0026
0027
0028 public class JavaUDAFSuite {
0029
0030 private transient SparkSession spark;
0031
0032 @Before
0033 public void setUp() {
0034 spark = SparkSession.builder()
0035 .master("local[*]")
0036 .appName("testing")
0037 .getOrCreate();
0038 }
0039
0040 @After
0041 public void tearDown() {
0042 spark.stop();
0043 spark = null;
0044 }
0045
0046 @SuppressWarnings("unchecked")
0047 @Test
0048 public void udf1Test() {
0049 spark.range(1, 10).toDF("value").createOrReplaceTempView("df");
0050 spark.udf().registerJavaUDAF("myDoubleAvg", MyDoubleAvg.class.getName());
0051 Row result = spark.sql("SELECT myDoubleAvg(value) as my_avg from df").head();
0052 Assert.assertEquals(105.0, result.getDouble(0), 1.0e-6);
0053 }
0054
0055 }