Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package test.org.apache.spark.sql;
0019 
0020 import java.io.Serializable;
0021 import java.math.BigDecimal;
0022 import java.util.ArrayList;
0023 import java.util.Arrays;
0024 import java.util.List;
0025 
0026 import org.junit.After;
0027 import org.junit.Assert;
0028 import org.junit.Before;
0029 import org.junit.Test;
0030 
0031 import org.apache.spark.api.java.JavaRDD;
0032 import org.apache.spark.api.java.JavaSparkContext;
0033 import org.apache.spark.sql.Dataset;
0034 import org.apache.spark.sql.Encoders;
0035 import org.apache.spark.sql.Row;
0036 import org.apache.spark.sql.RowFactory;
0037 import org.apache.spark.sql.SparkSession;
0038 import org.apache.spark.sql.types.DataTypes;
0039 import org.apache.spark.sql.types.StructField;
0040 import org.apache.spark.sql.types.StructType;
0041 
0042 // The test suite itself is Serializable so that anonymous Function implementations can be
0043 // serialized, as an alternative to converting these anonymous classes to static inner classes;
0044 // see http://stackoverflow.com/questions/758570/.
0045 public class JavaApplySchemaSuite implements Serializable {
0046   private transient SparkSession spark;
0047   private transient JavaSparkContext jsc;
0048 
0049   @Before
0050   public void setUp() {
0051     spark = SparkSession.builder()
0052       .master("local[*]")
0053       .appName("testing")
0054       .getOrCreate();
0055     jsc = new JavaSparkContext(spark.sparkContext());
0056   }
0057 
0058   @After
0059   public void tearDown() {
0060     spark.stop();
0061     spark = null;
0062   }
0063 
0064   public static class Person implements Serializable {
0065     private String name;
0066     private int age;
0067 
0068     public String getName() {
0069       return name;
0070     }
0071 
0072     public void setName(String name) {
0073       this.name = name;
0074     }
0075 
0076     public int getAge() {
0077       return age;
0078     }
0079 
0080     public void setAge(int age) {
0081       this.age = age;
0082     }
0083   }
0084 
0085   @Test
0086   public void applySchema() {
0087     List<Person> personList = new ArrayList<>(2);
0088     Person person1 = new Person();
0089     person1.setName("Michael");
0090     person1.setAge(29);
0091     personList.add(person1);
0092     Person person2 = new Person();
0093     person2.setName("Yin");
0094     person2.setAge(28);
0095     personList.add(person2);
0096 
0097     JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
0098         person -> RowFactory.create(person.getName(), person.getAge()));
0099 
0100     List<StructField> fields = new ArrayList<>(2);
0101     fields.add(DataTypes.createStructField("name", DataTypes.StringType, false));
0102     fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
0103     StructType schema = DataTypes.createStructType(fields);
0104 
0105     Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
0106     df.createOrReplaceTempView("people");
0107     List<Row> actual = spark.sql("SELECT * FROM people").collectAsList();
0108 
0109     List<Row> expected = new ArrayList<>(2);
0110     expected.add(RowFactory.create("Michael", 29));
0111     expected.add(RowFactory.create("Yin", 28));
0112 
0113     Assert.assertEquals(expected, actual);
0114   }
0115 
0116   @Test
0117   public void dataFrameRDDOperations() {
0118     List<Person> personList = new ArrayList<>(2);
0119     Person person1 = new Person();
0120     person1.setName("Michael");
0121     person1.setAge(29);
0122     personList.add(person1);
0123     Person person2 = new Person();
0124     person2.setName("Yin");
0125     person2.setAge(28);
0126     personList.add(person2);
0127 
0128     JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
0129         person -> RowFactory.create(person.getName(), person.getAge()));
0130 
0131     List<StructField> fields = new ArrayList<>(2);
0132     fields.add(DataTypes.createStructField("", DataTypes.StringType, false));
0133     fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
0134     StructType schema = DataTypes.createStructType(fields);
0135 
0136     Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
0137     df.createOrReplaceTempView("people");
0138     List<String> actual = spark.sql("SELECT * FROM people").toJavaRDD()
0139       .map(row -> row.getString(0) + "_" + row.get(1)).collect();
0140 
0141     List<String> expected = new ArrayList<>(2);
0142     expected.add("Michael_29");
0143     expected.add("Yin_28");
0144 
0145     Assert.assertEquals(expected, actual);
0146   }
0147 
0148   @Test
0149   public void applySchemaToJSON() {
0150     Dataset<String> jsonDS = spark.createDataset(Arrays.asList(
0151       "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " +
0152         "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " +
0153         "\"boolean\":true, \"null\":null}",
0154       "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " +
0155         "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " +
0156         "\"boolean\":false, \"null\":null}"), Encoders.STRING());
0157     List<StructField> fields = new ArrayList<>(7);
0158     fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0),
0159       true));
0160     fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true));
0161     fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true));
0162     fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true));
0163     fields.add(DataTypes.createStructField("long", DataTypes.LongType, true));
0164     fields.add(DataTypes.createStructField("null", DataTypes.StringType, true));
0165     fields.add(DataTypes.createStructField("string", DataTypes.StringType, true));
0166     StructType expectedSchema = DataTypes.createStructType(fields);
0167     List<Row> expectedResult = new ArrayList<>(2);
0168     expectedResult.add(
0169       RowFactory.create(
0170         new BigDecimal("92233720368547758070"),
0171         true,
0172         1.7976931348623157E308,
0173         10,
0174         21474836470L,
0175         null,
0176         "this is a simple string."));
0177     expectedResult.add(
0178       RowFactory.create(
0179         new BigDecimal("92233720368547758069"),
0180         false,
0181         1.7976931348623157E305,
0182         11,
0183         21474836469L,
0184         null,
0185         "this is another simple string."));
0186 
0187     Dataset<Row> df1 = spark.read().json(jsonDS);
0188     StructType actualSchema1 = df1.schema();
0189     Assert.assertEquals(expectedSchema, actualSchema1);
0190     df1.createOrReplaceTempView("jsonTable1");
0191     List<Row> actual1 = spark.sql("select * from jsonTable1").collectAsList();
0192     Assert.assertEquals(expectedResult, actual1);
0193 
0194     Dataset<Row> df2 = spark.read().schema(expectedSchema).json(jsonDS);
0195     StructType actualSchema2 = df2.schema();
0196     Assert.assertEquals(expectedSchema, actualSchema2);
0197     df2.createOrReplaceTempView("jsonTable2");
0198     List<Row> actual2 = spark.sql("select * from jsonTable2").collectAsList();
0199     Assert.assertEquals(expectedResult, actual2);
0200   }
0201 }