Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package test.org.apache.spark.sql;
0019 
0020 import java.io.Serializable;
0021 import java.time.LocalDate;
0022 import java.util.List;
0023 
0024 import org.apache.spark.sql.internal.SQLConf;
0025 import org.junit.After;
0026 import org.junit.Assert;
0027 import org.junit.Before;
0028 import org.junit.Test;
0029 
0030 import org.apache.spark.sql.AnalysisException;
0031 import org.apache.spark.sql.Row;
0032 import org.apache.spark.sql.SparkSession;
0033 import org.apache.spark.sql.api.java.UDF2;
0034 import org.apache.spark.sql.types.DataTypes;
0035 
0036 // The test suite itself is Serializable so that anonymous Function implementations can be
0037 // serialized, as an alternative to converting these anonymous classes to static inner classes;
0038 // see http://stackoverflow.com/questions/758570/.
0039 public class JavaUDFSuite implements Serializable {
0040   private transient SparkSession spark;
0041 
0042   @Before
0043   public void setUp() {
0044     spark = SparkSession.builder()
0045       .master("local[*]")
0046       .appName("testing")
0047       .getOrCreate();
0048   }
0049 
0050   @After
0051   public void tearDown() {
0052     spark.stop();
0053     spark = null;
0054   }
0055 
0056   @SuppressWarnings("unchecked")
0057   @Test
0058   public void udf1Test() {
0059     spark.udf().register("stringLengthTest", (String str) -> str.length(), DataTypes.IntegerType);
0060 
0061     Row result = spark.sql("SELECT stringLengthTest('test')").head();
0062     Assert.assertEquals(4, result.getInt(0));
0063   }
0064 
0065   @SuppressWarnings("unchecked")
0066   @Test
0067   public void udf2Test() {
0068     spark.udf().register("stringLengthTest",
0069         (String str1, String str2) -> str1.length() + str2.length(), DataTypes.IntegerType);
0070 
0071     Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
0072     Assert.assertEquals(9, result.getInt(0));
0073   }
0074 
0075   public static class StringLengthTest implements UDF2<String, String, Integer> {
0076     @Override
0077     public Integer call(String str1, String str2) {
0078       return str1.length() + str2.length();
0079     }
0080   }
0081 
0082   @SuppressWarnings("unchecked")
0083   @Test
0084   public void udf3Test() {
0085     spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(),
0086         DataTypes.IntegerType);
0087     Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
0088     Assert.assertEquals(9, result.getInt(0));
0089 
0090     // returnType is not provided
0091     spark.udf().registerJava("stringLengthTest2", StringLengthTest.class.getName(), null);
0092     result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
0093     Assert.assertEquals(9, result.getInt(0));
0094   }
0095 
0096   @SuppressWarnings("unchecked")
0097   @Test
0098   public void udf4Test() {
0099     spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType);
0100 
0101     spark.range(10).toDF("x").createOrReplaceTempView("tmp");
0102     // This tests when Java UDFs are required to be the semantically same (See SPARK-9435).
0103     List<Row> results = spark.sql("SELECT inc(x) FROM tmp GROUP BY inc(x)").collectAsList();
0104     Assert.assertEquals(10, results.size());
0105     long sum = 0;
0106     for (Row result : results) {
0107       sum += result.getLong(0);
0108     }
0109     Assert.assertEquals(55, sum);
0110   }
0111 
0112   @SuppressWarnings("unchecked")
0113   @Test(expected = AnalysisException.class)
0114   public void udf5Test() {
0115     spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType);
0116     List<Row> results = spark.sql("SELECT inc(1, 5)").collectAsList();
0117   }
0118 
0119   @SuppressWarnings("unchecked")
0120   @Test
0121   public void udf6Test() {
0122     spark.udf().register("returnOne", () -> 1, DataTypes.IntegerType);
0123     Row result = spark.sql("SELECT returnOne()").head();
0124     Assert.assertEquals(1, result.getInt(0));
0125   }
0126 
0127   @SuppressWarnings("unchecked")
0128   @Test
0129   public void udf7Test() {
0130     String originConf = spark.conf().get(SQLConf.DATETIME_JAVA8API_ENABLED().key());
0131     try {
0132       spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), "true");
0133       spark.udf().register(
0134           "plusDay",
0135           (java.time.LocalDate ld) -> ld.plusDays(1), DataTypes.DateType);
0136       Row result = spark.sql("SELECT plusDay(DATE '2019-02-26')").head();
0137       Assert.assertEquals(LocalDate.parse("2019-02-27"), result.get(0));
0138     } finally {
0139       spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), originConf);
0140     }
0141   }
0142 }