0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package test.org.apache.spark.sql;
0019
0020 import java.io.Serializable;
0021 import java.math.BigDecimal;
0022 import java.util.ArrayList;
0023 import java.util.Arrays;
0024 import java.util.List;
0025
0026 import org.junit.After;
0027 import org.junit.Assert;
0028 import org.junit.Before;
0029 import org.junit.Test;
0030
0031 import org.apache.spark.api.java.JavaRDD;
0032 import org.apache.spark.api.java.JavaSparkContext;
0033 import org.apache.spark.sql.Dataset;
0034 import org.apache.spark.sql.Encoders;
0035 import org.apache.spark.sql.Row;
0036 import org.apache.spark.sql.RowFactory;
0037 import org.apache.spark.sql.SparkSession;
0038 import org.apache.spark.sql.types.DataTypes;
0039 import org.apache.spark.sql.types.StructField;
0040 import org.apache.spark.sql.types.StructType;
0041
0042
0043
0044
0045 public class JavaApplySchemaSuite implements Serializable {
0046 private transient SparkSession spark;
0047 private transient JavaSparkContext jsc;
0048
0049 @Before
0050 public void setUp() {
0051 spark = SparkSession.builder()
0052 .master("local[*]")
0053 .appName("testing")
0054 .getOrCreate();
0055 jsc = new JavaSparkContext(spark.sparkContext());
0056 }
0057
0058 @After
0059 public void tearDown() {
0060 spark.stop();
0061 spark = null;
0062 }
0063
0064 public static class Person implements Serializable {
0065 private String name;
0066 private int age;
0067
0068 public String getName() {
0069 return name;
0070 }
0071
0072 public void setName(String name) {
0073 this.name = name;
0074 }
0075
0076 public int getAge() {
0077 return age;
0078 }
0079
0080 public void setAge(int age) {
0081 this.age = age;
0082 }
0083 }
0084
0085 @Test
0086 public void applySchema() {
0087 List<Person> personList = new ArrayList<>(2);
0088 Person person1 = new Person();
0089 person1.setName("Michael");
0090 person1.setAge(29);
0091 personList.add(person1);
0092 Person person2 = new Person();
0093 person2.setName("Yin");
0094 person2.setAge(28);
0095 personList.add(person2);
0096
0097 JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
0098 person -> RowFactory.create(person.getName(), person.getAge()));
0099
0100 List<StructField> fields = new ArrayList<>(2);
0101 fields.add(DataTypes.createStructField("name", DataTypes.StringType, false));
0102 fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
0103 StructType schema = DataTypes.createStructType(fields);
0104
0105 Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
0106 df.createOrReplaceTempView("people");
0107 List<Row> actual = spark.sql("SELECT * FROM people").collectAsList();
0108
0109 List<Row> expected = new ArrayList<>(2);
0110 expected.add(RowFactory.create("Michael", 29));
0111 expected.add(RowFactory.create("Yin", 28));
0112
0113 Assert.assertEquals(expected, actual);
0114 }
0115
0116 @Test
0117 public void dataFrameRDDOperations() {
0118 List<Person> personList = new ArrayList<>(2);
0119 Person person1 = new Person();
0120 person1.setName("Michael");
0121 person1.setAge(29);
0122 personList.add(person1);
0123 Person person2 = new Person();
0124 person2.setName("Yin");
0125 person2.setAge(28);
0126 personList.add(person2);
0127
0128 JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
0129 person -> RowFactory.create(person.getName(), person.getAge()));
0130
0131 List<StructField> fields = new ArrayList<>(2);
0132 fields.add(DataTypes.createStructField("", DataTypes.StringType, false));
0133 fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
0134 StructType schema = DataTypes.createStructType(fields);
0135
0136 Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
0137 df.createOrReplaceTempView("people");
0138 List<String> actual = spark.sql("SELECT * FROM people").toJavaRDD()
0139 .map(row -> row.getString(0) + "_" + row.get(1)).collect();
0140
0141 List<String> expected = new ArrayList<>(2);
0142 expected.add("Michael_29");
0143 expected.add("Yin_28");
0144
0145 Assert.assertEquals(expected, actual);
0146 }
0147
0148 @Test
0149 public void applySchemaToJSON() {
0150 Dataset<String> jsonDS = spark.createDataset(Arrays.asList(
0151 "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " +
0152 "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " +
0153 "\"boolean\":true, \"null\":null}",
0154 "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " +
0155 "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " +
0156 "\"boolean\":false, \"null\":null}"), Encoders.STRING());
0157 List<StructField> fields = new ArrayList<>(7);
0158 fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0),
0159 true));
0160 fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true));
0161 fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true));
0162 fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true));
0163 fields.add(DataTypes.createStructField("long", DataTypes.LongType, true));
0164 fields.add(DataTypes.createStructField("null", DataTypes.StringType, true));
0165 fields.add(DataTypes.createStructField("string", DataTypes.StringType, true));
0166 StructType expectedSchema = DataTypes.createStructType(fields);
0167 List<Row> expectedResult = new ArrayList<>(2);
0168 expectedResult.add(
0169 RowFactory.create(
0170 new BigDecimal("92233720368547758070"),
0171 true,
0172 1.7976931348623157E308,
0173 10,
0174 21474836470L,
0175 null,
0176 "this is a simple string."));
0177 expectedResult.add(
0178 RowFactory.create(
0179 new BigDecimal("92233720368547758069"),
0180 false,
0181 1.7976931348623157E305,
0182 11,
0183 21474836469L,
0184 null,
0185 "this is another simple string."));
0186
0187 Dataset<Row> df1 = spark.read().json(jsonDS);
0188 StructType actualSchema1 = df1.schema();
0189 Assert.assertEquals(expectedSchema, actualSchema1);
0190 df1.createOrReplaceTempView("jsonTable1");
0191 List<Row> actual1 = spark.sql("select * from jsonTable1").collectAsList();
0192 Assert.assertEquals(expectedResult, actual1);
0193
0194 Dataset<Row> df2 = spark.read().schema(expectedSchema).json(jsonDS);
0195 StructType actualSchema2 = df2.schema();
0196 Assert.assertEquals(expectedSchema, actualSchema2);
0197 df2.createOrReplaceTempView("jsonTable2");
0198 List<Row> actual2 = spark.sql("select * from jsonTable2").collectAsList();
0199 Assert.assertEquals(expectedResult, actual2);
0200 }
0201 }