/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package test.org.apache.spark.sql;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import static java.util.stream.Collectors.toList;

import static scala.collection.JavaConverters.mapAsScalaMap;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.*;
import static org.apache.spark.sql.types.DataTypes.*;

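// Java API tests for Spark SQL's higher-order functions on array and map
// columns: transform, filter, exists, forall, aggregate, zip_with, and the
// map_* variants.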
public class JavaHigherOrderFunctionsSuite {
    private transient TestSparkSession spark;
    private Dataset<Row> arrDf;
    private Dataset<Row> mapDf;

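    // Asserts that the Dataset collects to the expected rows, value by value.
    // Array-typed cells come back from collectAsList as Scala array wrappers,
    // so they are unwrapped via their array() method (by reflection) and
    // compared with assertArrayEquals; everything else uses assertEquals.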
    private void checkAnswer(Dataset<Row> actualDS, List<Row> expected) throws Exception {
        List<Row> actual = actualDS.collectAsList();
        Assert.assertEquals(expected.size(), actual.size());
        for (int i = 0; i < expected.size(); i++) {
            Row expectedRow = expected.get(i);
            Row actualRow = actual.get(i);
            Assert.assertEquals(expectedRow.size(), actualRow.size());
            for (int j = 0; j < expectedRow.size(); j++) {
                Object expectedValue = expectedRow.get(j);
                Object actualValue = actualRow.get(j);
                if (expectedValue != null && expectedValue.getClass().isArray()) {
                    actualValue = actualValue.getClass().getMethod("array").invoke(actualValue);
                    Assert.assertArrayEquals((Object[]) expectedValue, (Object[]) actualValue);
                } else {
                    Assert.assertEquals(expectedValue, actualValue);
                }
            }
        }
    }

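    // Wraps each argument in a single-column Row.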
    @SafeVarargs
    private static <T> List<Row> toRows(T... objs) {
        return Arrays.stream(objs)
            .map(RowFactory::create)
            .collect(toList());
    }

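    // Returns its varargs as an array, allowing typed array literals inline.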
    @SafeVarargs
    private static <T> T[] makeArray(T... ts) {
        return ts;
    }

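    // Builds arrDf with a single nullable array<int> column "x" holding two
    // non-empty arrays, an empty array, and a null.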
    private void setUpArrDf() {
        List<Row> data = toRows(
            makeArray(1, 9, 8, 7),
            makeArray(5, 8, 9, 7, 2),
            JavaHigherOrderFunctionsSuite.<Integer>makeArray(),
            null
        );
        StructType schema = new StructType()
            .add("x", new ArrayType(IntegerType, true), true);
        arrDf = spark.createDataFrame(data, schema);
    }

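    // Builds mapDf with a single map<int, int> column "x" holding
    // {1 -> 1, 2 -> 2} and a null.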
    private void setUpMapDf() {
        List<Row> data = toRows(
            new HashMap<Integer, Integer>() {{
                put(1, 1);
                put(2, 2);
            }},
            null
        );
        StructType schema = new StructType()
            .add("x", new MapType(IntegerType, IntegerType, true));
        mapDf = spark.createDataFrame(data, schema);
    }

    @Before
    public void setUp() {
        spark = new TestSparkSession();
        setUpArrDf();
        setUpMapDf();
    }

    @After
    public void tearDown() {
        spark.stop();
        spark = null;
    }

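    // transform(col, f) applies f to every array element; the two-argument
    // variant also passes the element's index as the second lambda parameter.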
    @Test
    public void testTransform() throws Exception {
        checkAnswer(
            arrDf.select(transform(col("x"), x -> x.plus(1))),
            toRows(
                makeArray(2, 10, 9, 8),
                makeArray(6, 9, 10, 8, 3),
                JavaHigherOrderFunctionsSuite.<Integer>makeArray(),
                null
            )
        );
        checkAnswer(
            arrDf.select(transform(col("x"), (x, i) -> x.plus(i))),
            toRows(
                makeArray(1, 10, 10, 10),
                makeArray(5, 9, 11, 10, 6),
                JavaHigherOrderFunctionsSuite.<Integer>makeArray(),
                null
            )
        );
    }

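    // filter(col, f) keeps the elements for which the predicate holds; the
    // two-argument variant also receives the element's index.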
    @Test
    public void testFilter() throws Exception {
        checkAnswer(
            arrDf.select(filter(col("x"), x -> x.plus(1).equalTo(10))),
            toRows(
                makeArray(9),
                makeArray(9),
                JavaHigherOrderFunctionsSuite.<Integer>makeArray(),
                null
            )
        );
        checkAnswer(
            arrDf.select(filter(col("x"), (x, i) -> x.plus(i).equalTo(10))),
            toRows(
                makeArray(9, 8, 7),
                makeArray(7),
                JavaHigherOrderFunctionsSuite.<Integer>makeArray(),
                null
            )
        );
    }

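    // exists(col, f) is true if the predicate holds for at least one element,
    // false for the empty array, and null for a null array.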
    @Test
    public void testExists() throws Exception {
        checkAnswer(
            arrDf.select(exists(col("x"), x -> x.plus(1).equalTo(10))),
            toRows(
                true,
                true,
                false,
                null
            )
        );
    }

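    // forall(col, f) is true only if the predicate holds for every element;
    // it is vacuously true for the empty array and null for a null array.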
    @Test
    public void testForall() throws Exception {
        checkAnswer(
            arrDf.select(forall(col("x"), x -> x.plus(1).equalTo(10))),
            toRows(
                false,
                false,
                true,
                null
            )
        );
    }

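    // aggregate(col, init, merge[, finish]) folds the array into the lit(0)
    // accumulator; the optional finish function transforms the final result
    // (the identity here, so both calls expect the same sums).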
    @Test
    public void testAggregate() throws Exception {
        checkAnswer(
            arrDf.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x))),
            toRows(
                25,
                31,
                0,
                null
            )
        );
        checkAnswer(
            arrDf.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x), x -> x)),
            toRows(
                25,
                31,
                0,
                null
            )
        );
    }

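    // zip_with(left, right, f) combines the two arrays element-wise; here
    // every pair is mapped to the constant 42, so only the lengths vary.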
    @Test
    public void testZipWith() throws Exception {
        checkAnswer(
            arrDf.select(zip_with(col("x"), col("x"), (a, b) -> lit(42))),
            toRows(
                makeArray(42, 42, 42, 42),
                makeArray(42, 42, 42, 42, 42),
                JavaHigherOrderFunctionsSuite.<Integer>makeArray(),
                null
            )
        );
    }

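    // transform_keys(col, f) rewrites each key as f(key, value) and leaves the
    // values unchanged, so {1 -> 1, 2 -> 2} becomes {2 -> 1, 4 -> 2}.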
    @Test
    public void testTransformKeys() throws Exception {
        checkAnswer(
            mapDf.select(transform_keys(col("x"), (k, v) -> k.plus(v))),
            toRows(
                mapAsScalaMap(new HashMap<Integer, Integer>() {{
                    put(2, 1);
                    put(4, 2);
                }}),
                null
            )
        );
    }

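    // transform_values(col, f) rewrites each value as f(key, value) and leaves
    // the keys unchanged, so {1 -> 1, 2 -> 2} becomes {1 -> 2, 2 -> 4}.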
    @Test
    public void testTransformValues() throws Exception {
        checkAnswer(
            mapDf.select(transform_values(col("x"), (k, v) -> k.plus(v))),
            toRows(
                mapAsScalaMap(new HashMap<Integer, Integer>() {{
                    put(1, 2);
                    put(2, 4);
                }}),
                null
            )
        );
    }

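    // map_filter(col, f) keeps the entries whose predicate is true; a constant
    // false predicate yields an empty map.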
    @Test
    public void testMapFilter() throws Exception {
        checkAnswer(
            mapDf.select(map_filter(col("x"), (k, v) -> lit(false))),
            toRows(
                mapAsScalaMap(new HashMap<Integer, Integer>()),
                null
            )
        );
    }

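    // map_zip_with(left, right, f) merges the two maps by key, producing each
    // result value as f(key, value1, value2).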
    @Test
    public void testMapZipWith() throws Exception {
        checkAnswer(
            mapDf.select(map_zip_with(col("x"), col("x"), (k, v1, v2) -> lit(false))),
            toRows(
                mapAsScalaMap(new HashMap<Integer, Boolean>() {{
                    put(1, false);
                    put(2, false);
                }}),
                null
            )
        );
    }
}