/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package test.org.apache.spark.sql;

import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.StructType;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.util.*;

import static org.apache.spark.sql.types.DataTypes.*;

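/**
 * Java-side tests for {@code Column.isInCollection}: it should accept plain
 * java.util collections and behave like an equivalent row-level filter, and
 * it should fail analysis cleanly when the argument types do not match.
 */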
public class JavaColumnExpressionSuite {
  private transient TestSparkSession spark;

  @Before
  public void setUp() {
    spark = new TestSparkSession();
  }

  @After
  public void tearDown() {
    spark.stop();
    spark = null;
  }

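  // isInCollection should accept any java.util collection (List, Set,
  // ArrayList) and return the same rows as a hand-written FilterFunction
  // that tests membership on the same values.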
  @Test
  public void isInCollectionWorksCorrectlyOnJava() {
    List<Row> rows = Arrays.asList(
      RowFactory.create(1, "x"),
      RowFactory.create(2, "y"),
      RowFactory.create(3, "z"));
    StructType schema = createStructType(Arrays.asList(
      createStructField("a", IntegerType, false),
      createStructField("b", StringType, false)));
    Dataset<Row> df = spark.createDataFrame(rows, schema);

    Assert.assertTrue(Arrays.equals(
      (Row[]) df.filter(df.col("a").isInCollection(Arrays.asList(1, 2))).collect(),
      (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect()
    ));
    Assert.assertTrue(Arrays.equals(
      (Row[]) df.filter(df.col("a").isInCollection(new HashSet<>(Arrays.asList(1, 2)))).collect(),
      (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect()
    ));
    Assert.assertTrue(Arrays.equals(
      (Row[]) df.filter(df.col("a").isInCollection(new ArrayList<>(Arrays.asList(3, 1)))).collect(),
      (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 3 || r.getInt(0) == 1).collect()
    ));
  }

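  // When an element of the collection cannot be compared with the column
  // (here a Column referencing an array-typed field), analysis should fail
  // with a "data type mismatch" error naming the incompatible types.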
  @Test
  public void isInCollectionCheckExceptionMessage() {
    List<Row> rows = Arrays.asList(
      RowFactory.create(1, Arrays.asList(1)),
      RowFactory.create(2, Arrays.asList(2)),
      RowFactory.create(3, Arrays.asList(3)));
    StructType schema = createStructType(Arrays.asList(
      createStructField("a", IntegerType, false),
      createStructField("b", createArrayType(IntegerType, false), false)));
    Dataset<Row> df = spark.createDataFrame(rows, schema);
    try {
      df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b"))));
      Assert.fail("Expected org.apache.spark.sql.AnalysisException");
    } catch (Exception e) {
      // The analysis error should mention both the failed resolution and the
      // type mismatch between the integer column and the array-typed argument.
      Arrays.asList("cannot resolve",
        "due to data type mismatch: Arguments must be same type but were")
        .forEach(s -> Assert.assertTrue(
          e.getMessage().toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))));
    }
  }
}