0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package test.org.apache.spark.sql;
0019
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static java.util.stream.Collectors.toList;

import static scala.collection.JavaConverters.mapAsScalaMap;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.*;
import static org.apache.spark.sql.types.DataTypes.*;
0039
0040 public class JavaHigherOrderFunctionsSuite {
0041 private transient TestSparkSession spark;
0042 private Dataset<Row> arrDf;
0043 private Dataset<Row> mapDf;
0044
0045 private void checkAnswer(Dataset<Row> actualDS, List<Row> expected) throws Exception {
0046 List<Row> actual = actualDS.collectAsList();
0047 Assert.assertEquals(expected.size(), actual.size());
0048 for (int i = 0; i < expected.size(); i++) {
0049 Row expectedRow = expected.get(i);
0050 Row actualRow = actual.get(i);
0051 Assert.assertEquals(expectedRow.size(), actualRow.size());
0052 for (int j = 0; j < expectedRow.size(); j++) {
0053 Object expectedValue = expectedRow.get(j);
0054 Object actualValue = actualRow.get(j);
0055 if (expectedValue != null && expectedValue.getClass().isArray()) {
0056 actualValue = actualValue.getClass().getMethod("array").invoke(actualValue);
0057 Assert.assertArrayEquals((Object[]) expectedValue, (Object[]) actualValue);
0058 } else {
0059 Assert.assertEquals(expectedValue, actualValue);
0060 }
0061 }
0062 }
0063 }
0064
0065 @SafeVarargs
0066 private static <T> List<Row> toRows(T... objs) {
0067 return Arrays.stream(objs)
0068 .map(RowFactory::create)
0069 .collect(toList());
0070 }
0071
0072 @SafeVarargs
0073 private static <T> T[] makeArray(T... ts) {
0074 return ts;
0075 }
0076
0077 private void setUpArrDf() {
0078 List<Row> data = toRows(
0079 makeArray(1, 9, 8, 7),
0080 makeArray(5, 8, 9, 7, 2),
0081 JavaHigherOrderFunctionsSuite.<Integer>makeArray(),
0082 null
0083 );
0084 StructType schema = new StructType()
0085 .add("x", new ArrayType(IntegerType, true), true);
0086 arrDf = spark.createDataFrame(data, schema);
0087 }
0088
0089 private void setUpMapDf() {
0090 List<Row> data = toRows(
0091 new HashMap<Integer, Integer>() {{
0092 put(1, 1);
0093 put(2, 2);
0094 }},
0095 null
0096 );
0097 StructType schema = new StructType()
0098 .add("x", new MapType(IntegerType, IntegerType, true));
0099 mapDf = spark.createDataFrame(data, schema);
0100 }
0101
0102 @Before
0103 public void setUp() {
0104 spark = new TestSparkSession();
0105 setUpArrDf();
0106 setUpMapDf();
0107 }
0108
0109 @After
0110 public void tearDown() {
0111 spark.stop();
0112 spark = null;
0113 }
0114
0115 @Test
0116 public void testTransform() throws Exception {
0117 checkAnswer(
0118 arrDf.select(transform(col("x"), x -> x.plus(1))),
0119 toRows(
0120 makeArray(2, 10, 9, 8),
0121 makeArray(6, 9, 10, 8, 3),
0122 JavaHigherOrderFunctionsSuite.<Integer>makeArray(),
0123 null
0124 )
0125 );
0126 checkAnswer(
0127 arrDf.select(transform(col("x"), (x, i) -> x.plus(i))),
0128 toRows(
0129 makeArray(1, 10, 10, 10),
0130 makeArray(5, 9, 11, 10, 6),
0131 JavaHigherOrderFunctionsSuite.<Integer>makeArray(),
0132 null
0133 )
0134 );
0135 }
0136
0137 @Test
0138 public void testFilter() throws Exception {
0139 checkAnswer(
0140 arrDf.select(filter(col("x"), x -> x.plus(1).equalTo(10))),
0141 toRows(
0142 makeArray(9),
0143 makeArray(9),
0144 JavaHigherOrderFunctionsSuite.<Integer>makeArray(),
0145 null
0146 )
0147 );
0148 checkAnswer(
0149 arrDf.select(filter(col("x"), (x, i) -> x.plus(i).equalTo(10))),
0150 toRows(
0151 makeArray(9, 8, 7),
0152 makeArray(7),
0153 JavaHigherOrderFunctionsSuite.<Integer>makeArray(),
0154 null
0155 )
0156 );
0157 }
0158
0159 @Test
0160 public void testExists() throws Exception {
0161 checkAnswer(
0162 arrDf.select(exists(col("x"), x -> x.plus(1).equalTo(10))),
0163 toRows(
0164 true,
0165 true,
0166 false,
0167 null
0168 )
0169 );
0170 }
0171
0172 @Test
0173 public void testForall() throws Exception {
0174 checkAnswer(
0175 arrDf.select(forall(col("x"), x -> x.plus(1).equalTo(10))),
0176 toRows(
0177 false,
0178 false,
0179 true,
0180 null
0181 )
0182 );
0183 }
0184
0185 @Test
0186 public void testAggregate() throws Exception {
0187 checkAnswer(
0188 arrDf.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x))),
0189 toRows(
0190 25,
0191 31,
0192 0,
0193 null
0194 )
0195 );
0196 checkAnswer(
0197 arrDf.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x), x -> x)),
0198 toRows(
0199 25,
0200 31,
0201 0,
0202 null
0203 )
0204 );
0205 }
0206
0207 @Test
0208 public void testZipWith() throws Exception {
0209 checkAnswer(
0210 arrDf.select(zip_with(col("x"), col("x"), (a, b) -> lit(42))),
0211 toRows(
0212 makeArray(42, 42, 42, 42),
0213 makeArray(42, 42, 42, 42, 42),
0214 JavaHigherOrderFunctionsSuite.<Integer>makeArray(),
0215 null
0216 )
0217 );
0218 }
0219
0220 @Test
0221 public void testTransformKeys() throws Exception {
0222 checkAnswer(
0223 mapDf.select(transform_keys(col("x"), (k, v) -> k.plus(v))),
0224 toRows(
0225 mapAsScalaMap(new HashMap<Integer, Integer>() {{
0226 put(2, 1);
0227 put(4, 2);
0228 }}),
0229 null
0230 )
0231 );
0232 }
0233
0234 @Test
0235 public void testTransformValues() throws Exception {
0236 checkAnswer(
0237 mapDf.select(transform_values(col("x"), (k, v) -> k.plus(v))),
0238 toRows(
0239 mapAsScalaMap(new HashMap<Integer, Integer>() {{
0240 put(1, 2);
0241 put(2, 4);
0242 }}),
0243 null
0244 )
0245 );
0246 }
0247
0248 @Test
0249 public void testMapFilter() throws Exception {
0250 checkAnswer(
0251 mapDf.select(map_filter(col("x"), (k, v) -> lit(false))),
0252 toRows(
0253 mapAsScalaMap(new HashMap<Integer, Integer>()),
0254 null
0255 )
0256 );
0257 }
0258
0259 @Test
0260 public void testMapZipWith() throws Exception {
0261 checkAnswer(
0262 mapDf.select(map_zip_with(col("x"), col("x"), (k, v1, v2) -> lit(false))),
0263 toRows(
0264 mapAsScalaMap(new HashMap<Integer, Boolean>() {{
0265 put(1, false);
0266 put(2, false);
0267 }}),
0268 null
0269 )
0270 );
0271 }
0272 }