0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package test.org.apache.spark.sql;
0019
0020 import java.io.Serializable;
0021 import java.time.LocalDate;
0022 import java.util.List;
0023
0024 import org.apache.spark.sql.internal.SQLConf;
0025 import org.junit.After;
0026 import org.junit.Assert;
0027 import org.junit.Before;
0028 import org.junit.Test;
0029
0030 import org.apache.spark.sql.AnalysisException;
0031 import org.apache.spark.sql.Row;
0032 import org.apache.spark.sql.SparkSession;
0033 import org.apache.spark.sql.api.java.UDF2;
0034 import org.apache.spark.sql.types.DataTypes;
0035
0036
0037
0038
0039 public class JavaUDFSuite implements Serializable {
0040 private transient SparkSession spark;
0041
0042 @Before
0043 public void setUp() {
0044 spark = SparkSession.builder()
0045 .master("local[*]")
0046 .appName("testing")
0047 .getOrCreate();
0048 }
0049
0050 @After
0051 public void tearDown() {
0052 spark.stop();
0053 spark = null;
0054 }
0055
0056 @SuppressWarnings("unchecked")
0057 @Test
0058 public void udf1Test() {
0059 spark.udf().register("stringLengthTest", (String str) -> str.length(), DataTypes.IntegerType);
0060
0061 Row result = spark.sql("SELECT stringLengthTest('test')").head();
0062 Assert.assertEquals(4, result.getInt(0));
0063 }
0064
0065 @SuppressWarnings("unchecked")
0066 @Test
0067 public void udf2Test() {
0068 spark.udf().register("stringLengthTest",
0069 (String str1, String str2) -> str1.length() + str2.length(), DataTypes.IntegerType);
0070
0071 Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
0072 Assert.assertEquals(9, result.getInt(0));
0073 }
0074
0075 public static class StringLengthTest implements UDF2<String, String, Integer> {
0076 @Override
0077 public Integer call(String str1, String str2) {
0078 return str1.length() + str2.length();
0079 }
0080 }
0081
0082 @SuppressWarnings("unchecked")
0083 @Test
0084 public void udf3Test() {
0085 spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(),
0086 DataTypes.IntegerType);
0087 Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
0088 Assert.assertEquals(9, result.getInt(0));
0089
0090
0091 spark.udf().registerJava("stringLengthTest2", StringLengthTest.class.getName(), null);
0092 result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
0093 Assert.assertEquals(9, result.getInt(0));
0094 }
0095
0096 @SuppressWarnings("unchecked")
0097 @Test
0098 public void udf4Test() {
0099 spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType);
0100
0101 spark.range(10).toDF("x").createOrReplaceTempView("tmp");
0102
0103 List<Row> results = spark.sql("SELECT inc(x) FROM tmp GROUP BY inc(x)").collectAsList();
0104 Assert.assertEquals(10, results.size());
0105 long sum = 0;
0106 for (Row result : results) {
0107 sum += result.getLong(0);
0108 }
0109 Assert.assertEquals(55, sum);
0110 }
0111
0112 @SuppressWarnings("unchecked")
0113 @Test(expected = AnalysisException.class)
0114 public void udf5Test() {
0115 spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType);
0116 List<Row> results = spark.sql("SELECT inc(1, 5)").collectAsList();
0117 }
0118
0119 @SuppressWarnings("unchecked")
0120 @Test
0121 public void udf6Test() {
0122 spark.udf().register("returnOne", () -> 1, DataTypes.IntegerType);
0123 Row result = spark.sql("SELECT returnOne()").head();
0124 Assert.assertEquals(1, result.getInt(0));
0125 }
0126
0127 @SuppressWarnings("unchecked")
0128 @Test
0129 public void udf7Test() {
0130 String originConf = spark.conf().get(SQLConf.DATETIME_JAVA8API_ENABLED().key());
0131 try {
0132 spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), "true");
0133 spark.udf().register(
0134 "plusDay",
0135 (java.time.LocalDate ld) -> ld.plusDays(1), DataTypes.DateType);
0136 Row result = spark.sql("SELECT plusDay(DATE '2019-02-26')").head();
0137 Assert.assertEquals(LocalDate.parse("2019-02-27"), result.get(0));
0138 } finally {
0139 spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), originConf);
0140 }
0141 }
0142 }