/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package test.org.apache.spark.sql;

import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.StructType;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.util.*;

import static org.apache.spark.sql.types.DataTypes.*;

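/**
 * Java API tests for {@link org.apache.spark.sql.Column} expressions,
 * currently covering {@code Column.isInCollection}.
 */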
public class JavaColumnExpressionSuite {
  private transient TestSparkSession spark;

  @Before
  public void setUp() {
    spark = new TestSparkSession();
  }

  @After
  public void tearDown() {
    spark.stop();
    spark = null;
  }

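  // Filtering with isInCollection should return the same rows as the equivalent
  // row-level predicate, whichever concrete java.util.Collection is supplied.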
  @Test
  public void isInCollectionWorksCorrectlyOnJava() {
    List<Row> rows = Arrays.asList(
      RowFactory.create(1, "x"),
      RowFactory.create(2, "y"),
      RowFactory.create(3, "z"));
    StructType schema = createStructType(Arrays.asList(
      createStructField("a", IntegerType, false),
      createStructField("b", StringType, false)));
    Dataset<Row> df = spark.createDataFrame(rows, schema);
    // Test with different types of collections
    Assert.assertTrue(Arrays.equals(
      (Row[]) df.filter(df.col("a").isInCollection(Arrays.asList(1, 2))).collect(),
      (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect()
    ));
    Assert.assertTrue(Arrays.equals(
      (Row[]) df.filter(df.col("a").isInCollection(new HashSet<>(Arrays.asList(1, 2)))).collect(),
      (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect()
    ));
    Assert.assertTrue(Arrays.equals(
      (Row[]) df.filter(df.col("a").isInCollection(new ArrayList<>(Arrays.asList(3, 1)))).collect(),
      (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 3 || r.getInt(0) == 1).collect()
    ));
  }

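  // Comparing the int column "a" against the array-typed column "b" should fail
  // analysis with a "data type mismatch" error rather than filter silently.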
  @Test
  public void isInCollectionCheckExceptionMessage() {
    List<Row> rows = Arrays.asList(
      RowFactory.create(1, Arrays.asList(1)),
      RowFactory.create(2, Arrays.asList(2)),
      RowFactory.create(3, Arrays.asList(3)));
    StructType schema = createStructType(Arrays.asList(
      createStructField("a", IntegerType, false),
      createStructField("b", createArrayType(IntegerType, false), false)));
    Dataset<Row> df = spark.createDataFrame(rows, schema);
    try {
      df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b"))));
      Assert.fail("Expected org.apache.spark.sql.AnalysisException");
    } catch (Exception e) {
      Arrays.asList("cannot resolve",
        "due to data type mismatch: Arguments must be same type but were")
        .forEach(s -> Assert.assertTrue(
          e.getMessage().toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))));
    }
  }
}