0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package test.org.apache.spark.sql;
0019
0020 import java.io.Serializable;
0021 import java.sql.Timestamp;
0022 import java.text.SimpleDateFormat;
0023 import java.time.Instant;
0024 import java.time.LocalDate;
0025 import java.util.*;
0026
0027 import org.apache.commons.lang3.builder.ToStringBuilder;
0028 import org.apache.commons.lang3.builder.ToStringStyle;
0029 import org.junit.*;
0030
0031 import org.apache.spark.sql.*;
0032 import org.apache.spark.sql.catalyst.expressions.GenericRow;
0033 import org.apache.spark.sql.catalyst.util.DateTimeUtils;
0034 import org.apache.spark.sql.catalyst.util.TimestampFormatter;
0035 import org.apache.spark.sql.internal.SQLConf;
0036 import org.apache.spark.sql.types.DataTypes;
0037 import org.apache.spark.sql.types.StructType;
0038
0039 import org.apache.spark.sql.test.TestSparkSession;
0040
0041 public class JavaBeanDeserializationSuite implements Serializable {
0042
0043 private TestSparkSession spark;
0044
0045 @Before
0046 public void setUp() {
0047 spark = new TestSparkSession();
0048 }
0049
0050 @After
0051 public void tearDown() {
0052 spark.stop();
0053 spark = null;
0054 }
0055
/** Records that {@code testBeanWithArrayFieldDeserialization} expects to read back. */
private static final List<ArrayRecord> ARRAY_RECORDS = new ArrayList<>();

static {
  // Record <id> carries intervals [110+id,210+id] and [120+id,220+id]
  // and the ints 10*id+1 .. 10*id+4, for id = 1, 2, 3.
  for (int id = 1; id <= 3; id++) {
    List<Interval> intervals =
      Arrays.asList(new Interval(110 + id, 210 + id), new Interval(120 + id, 220 + id));
    int base = 10 * id;
    int[] ints = new int[] { base + 1, base + 2, base + 3, base + 4 };
    ARRAY_RECORDS.add(new ArrayRecord(id, intervals, ints));
  }
}
0072
0073 @Test
0074 public void testBeanWithArrayFieldDeserialization() {
0075 Encoder<ArrayRecord> encoder = Encoders.bean(ArrayRecord.class);
0076
0077 Dataset<ArrayRecord> dataset = spark
0078 .read()
0079 .format("json")
0080 .schema("id int, intervals array<struct<startTime: bigint, endTime: bigint>>, " +
0081 "ints array<int>")
0082 .load("src/test/resources/test-data/with-array-fields.json")
0083 .as(encoder);
0084
0085 List<ArrayRecord> records = dataset.collectAsList();
0086 Assert.assertEquals(ARRAY_RECORDS, records);
0087 }
0088
/** Records that {@code testBeanWithMapFieldsDeserialization} expects to read back. */
private static final List<MapRecord> MAP_RECORDS = new ArrayList<>();

static {
  // Records 1..3 map "a" -> [110+id,210+id] and "b" -> [120+id,220+id].
  for (int id = 1; id <= 3; id++) {
    List<Interval> values =
      Arrays.asList(new Interval(110 + id, 210 + id), new Interval(120 + id, 220 + id));
    MAP_RECORDS.add(new MapRecord(id, toMap(Arrays.asList("a", "b"), values)));
  }
  // Record 4 has an empty map, record 5 has no map at all.
  MAP_RECORDS.add(new MapRecord(4, new HashMap<>()));
  MAP_RECORDS.add(new MapRecord(5, null));
}
0104
0105 private static <K, V> Map<K, V> toMap(Collection<K> keys, Collection<V> values) {
0106 Map<K, V> map = new HashMap<>();
0107 Iterator<K> keyI = keys.iterator();
0108 Iterator<V> valueI = values.iterator();
0109 while (keyI.hasNext() && valueI.hasNext()) {
0110 map.put(keyI.next(), valueI.next());
0111 }
0112 return map;
0113 }
0114
0115 @Test
0116 public void testBeanWithMapFieldsDeserialization() {
0117
0118 Encoder<MapRecord> encoder = Encoders.bean(MapRecord.class);
0119
0120 Dataset<MapRecord> dataset = spark
0121 .read()
0122 .format("json")
0123 .schema("id int, intervals map<string, struct<startTime: bigint, endTime: bigint>>")
0124 .load("src/test/resources/test-data/with-map-fields.json")
0125 .as(encoder);
0126
0127 List<MapRecord> records = dataset.collectAsList();
0128
0129 Assert.assertEquals(MAP_RECORDS, records);
0130 }
0131
0132 @Test
0133 public void testSpark22000() {
0134 List<Row> inputRows = new ArrayList<>();
0135 List<RecordSpark22000> expectedRecords = new ArrayList<>();
0136
0137 for (long idx = 0 ; idx < 5 ; idx++) {
0138 Row row = createRecordSpark22000Row(idx);
0139 inputRows.add(row);
0140 expectedRecords.add(createRecordSpark22000(row));
0141 }
0142
0143
0144
0145
0146
0147
0148 Encoder<RecordSpark22000> encoder = Encoders.bean(RecordSpark22000.class);
0149
0150 StructType schema = new StructType()
0151 .add("shortField", DataTypes.ShortType)
0152 .add("intField", DataTypes.IntegerType)
0153 .add("longField", DataTypes.LongType)
0154 .add("floatField", DataTypes.FloatType)
0155 .add("doubleField", DataTypes.DoubleType)
0156 .add("stringField", DataTypes.StringType)
0157 .add("booleanField", DataTypes.BooleanType)
0158 .add("timestampField", DataTypes.TimestampType)
0159
0160 .add("nullIntField", DataTypes.IntegerType, true);
0161
0162 Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
0163 Dataset<RecordSpark22000> dataset = dataFrame.as(encoder);
0164
0165 List<RecordSpark22000> records = dataset.collectAsList();
0166
0167 Assert.assertEquals(expectedRecords, records);
0168 }
0169
0170 @Test
0171 public void testSpark22000FailToUpcast() {
0172 List<Row> inputRows = new ArrayList<>();
0173 for (long idx = 0 ; idx < 5 ; idx++) {
0174 Row row = createRecordSpark22000FailToUpcastRow(idx);
0175 inputRows.add(row);
0176 }
0177
0178
0179 Encoder<RecordSpark22000FailToUpcast> encoder =
0180 Encoders.bean(RecordSpark22000FailToUpcast.class);
0181
0182 StructType schema = new StructType().add("id", DataTypes.StringType);
0183
0184 Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
0185
0186 try {
0187 dataFrame.as(encoder).collect();
0188 Assert.fail("Expected AnalysisException, but passed.");
0189 } catch (Throwable e) {
0190
0191
0192 if (e instanceof AnalysisException) {
0193 Assert.assertTrue(e.getMessage().contains("Cannot up cast "));
0194 } else {
0195 throw e;
0196 }
0197 }
0198 }
0199
0200 private static Row createRecordSpark22000Row(Long index) {
0201 Object[] values = new Object[] {
0202 index.shortValue(),
0203 index.intValue(),
0204 index,
0205 index.floatValue(),
0206 index.doubleValue(),
0207 String.valueOf(index),
0208 index % 2 == 0,
0209 new java.sql.Timestamp(System.currentTimeMillis()),
0210 null
0211 };
0212 return new GenericRow(values);
0213 }
0214
0215 private static String timestampToString(Timestamp ts) {
0216 String timestampString = String.valueOf(ts);
0217 String formatted = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(ts);
0218
0219 if (timestampString.length() > 19 && !timestampString.substring(19).equals(".0")) {
0220 return formatted + timestampString.substring(19);
0221 } else {
0222 return formatted;
0223 }
0224 }
0225
0226 private static RecordSpark22000 createRecordSpark22000(Row recordRow) {
0227 RecordSpark22000 record = new RecordSpark22000();
0228 record.setShortField(String.valueOf(recordRow.getShort(0)));
0229 record.setIntField(String.valueOf(recordRow.getInt(1)));
0230 record.setLongField(String.valueOf(recordRow.getLong(2)));
0231 record.setFloatField(String.valueOf(recordRow.getFloat(3)));
0232 record.setDoubleField(String.valueOf(recordRow.getDouble(4)));
0233 record.setStringField(recordRow.getString(5));
0234 record.setBooleanField(String.valueOf(recordRow.getBoolean(6)));
0235 record.setTimestampField(timestampToString(recordRow.getTimestamp(7)));
0236
0237 record.setNullIntField(null);
0238 return record;
0239 }
0240
0241 private static Row createRecordSpark22000FailToUpcastRow(Long index) {
0242 Object[] values = new Object[] { String.valueOf(index) };
0243 return new GenericRow(values);
0244 }
0245
0246 public static class ArrayRecord {
0247
0248 private int id;
0249 private List<Interval> intervals;
0250 private int[] ints;
0251
0252 public ArrayRecord() { }
0253
0254 ArrayRecord(int id, List<Interval> intervals, int[] ints) {
0255 this.id = id;
0256 this.intervals = intervals;
0257 this.ints = ints;
0258 }
0259
0260 public int getId() {
0261 return id;
0262 }
0263
0264 public void setId(int id) {
0265 this.id = id;
0266 }
0267
0268 public List<Interval> getIntervals() {
0269 return intervals;
0270 }
0271
0272 public void setIntervals(List<Interval> intervals) {
0273 this.intervals = intervals;
0274 }
0275
0276 public int[] getInts() {
0277 return ints;
0278 }
0279
0280 public void setInts(int[] ints) {
0281 this.ints = ints;
0282 }
0283
0284 @Override
0285 public int hashCode() {
0286 return id ^ Objects.hashCode(intervals) ^ Objects.hashCode(ints);
0287 }
0288
0289 @Override
0290 public boolean equals(Object obj) {
0291 if (!(obj instanceof ArrayRecord)) return false;
0292 ArrayRecord other = (ArrayRecord) obj;
0293 return (other.id == this.id) && Objects.equals(other.intervals, this.intervals) &&
0294 Arrays.equals(other.ints, ints);
0295 }
0296
0297 @Override
0298 public String toString() {
0299 return String.format("{ id: %d, intervals: %s, ints: %s }", id, intervals,
0300 Arrays.toString(ints));
0301 }
0302 }
0303
0304 public static class MapRecord {
0305
0306 private int id;
0307 private Map<String, Interval> intervals;
0308
0309 public MapRecord() { }
0310
0311 MapRecord(int id, Map<String, Interval> intervals) {
0312 this.id = id;
0313 this.intervals = intervals;
0314 }
0315
0316 public int getId() {
0317 return id;
0318 }
0319
0320 public void setId(int id) {
0321 this.id = id;
0322 }
0323
0324 public Map<String, Interval> getIntervals() {
0325 return intervals;
0326 }
0327
0328 public void setIntervals(Map<String, Interval> intervals) {
0329 this.intervals = intervals;
0330 }
0331
0332 @Override
0333 public int hashCode() {
0334 return id ^ Objects.hashCode(intervals);
0335 }
0336
0337 @Override
0338 public boolean equals(Object obj) {
0339 if (!(obj instanceof MapRecord)) return false;
0340 MapRecord other = (MapRecord) obj;
0341 return (other.id == this.id) && Objects.equals(other.intervals, this.intervals);
0342 }
0343
0344 @Override
0345 public String toString() {
0346 return String.format("{ id: %d, intervals: %s }", id, intervals);
0347 }
0348 }
0349
0350 public static class Interval {
0351
0352 private long startTime;
0353 private long endTime;
0354
0355 public Interval() { }
0356
0357 Interval(long startTime, long endTime) {
0358 this.startTime = startTime;
0359 this.endTime = endTime;
0360 }
0361
0362 public long getStartTime() {
0363 return startTime;
0364 }
0365
0366 public void setStartTime(long startTime) {
0367 this.startTime = startTime;
0368 }
0369
0370 public long getEndTime() {
0371 return endTime;
0372 }
0373
0374 public void setEndTime(long endTime) {
0375 this.endTime = endTime;
0376 }
0377
0378 @Override
0379 public int hashCode() {
0380 return Long.hashCode(startTime) ^ Long.hashCode(endTime);
0381 }
0382
0383 @Override
0384 public boolean equals(Object obj) {
0385 if (!(obj instanceof Interval)) return false;
0386 Interval other = (Interval) obj;
0387 return (other.startTime == this.startTime) && (other.endTime == this.endTime);
0388 }
0389
0390 @Override
0391 public String toString() {
0392 return String.format("[%d,%d]", startTime, endTime);
0393 }
0394 }
0395
0396 public static final class RecordSpark22000 {
0397 private String shortField;
0398 private String intField;
0399 private String longField;
0400 private String floatField;
0401 private String doubleField;
0402 private String stringField;
0403 private String booleanField;
0404 private String timestampField;
0405 private String nullIntField;
0406
0407 public RecordSpark22000() { }
0408
0409 public String getShortField() {
0410 return shortField;
0411 }
0412
0413 public void setShortField(String shortField) {
0414 this.shortField = shortField;
0415 }
0416
0417 public String getIntField() {
0418 return intField;
0419 }
0420
0421 public void setIntField(String intField) {
0422 this.intField = intField;
0423 }
0424
0425 public String getLongField() {
0426 return longField;
0427 }
0428
0429 public void setLongField(String longField) {
0430 this.longField = longField;
0431 }
0432
0433 public String getFloatField() {
0434 return floatField;
0435 }
0436
0437 public void setFloatField(String floatField) {
0438 this.floatField = floatField;
0439 }
0440
0441 public String getDoubleField() {
0442 return doubleField;
0443 }
0444
0445 public void setDoubleField(String doubleField) {
0446 this.doubleField = doubleField;
0447 }
0448
0449 public String getStringField() {
0450 return stringField;
0451 }
0452
0453 public void setStringField(String stringField) {
0454 this.stringField = stringField;
0455 }
0456
0457 public String getBooleanField() {
0458 return booleanField;
0459 }
0460
0461 public void setBooleanField(String booleanField) {
0462 this.booleanField = booleanField;
0463 }
0464
0465 public String getTimestampField() {
0466 return timestampField;
0467 }
0468
0469 public void setTimestampField(String timestampField) {
0470 this.timestampField = timestampField;
0471 }
0472
0473 public String getNullIntField() {
0474 return nullIntField;
0475 }
0476
0477 public void setNullIntField(String nullIntField) {
0478 this.nullIntField = nullIntField;
0479 }
0480
0481 @Override
0482 public boolean equals(Object o) {
0483 if (this == o) return true;
0484 if (o == null || getClass() != o.getClass()) return false;
0485 RecordSpark22000 that = (RecordSpark22000) o;
0486 return Objects.equals(shortField, that.shortField) &&
0487 Objects.equals(intField, that.intField) &&
0488 Objects.equals(longField, that.longField) &&
0489 Objects.equals(floatField, that.floatField) &&
0490 Objects.equals(doubleField, that.doubleField) &&
0491 Objects.equals(stringField, that.stringField) &&
0492 Objects.equals(booleanField, that.booleanField) &&
0493 Objects.equals(timestampField, that.timestampField) &&
0494 Objects.equals(nullIntField, that.nullIntField);
0495 }
0496
0497 @Override
0498 public int hashCode() {
0499 return Objects.hash(shortField, intField, longField, floatField, doubleField, stringField,
0500 booleanField, timestampField, nullIntField);
0501 }
0502
0503 @Override
0504 public String toString() {
0505 return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
0506 .append("shortField", shortField)
0507 .append("intField", intField)
0508 .append("longField", longField)
0509 .append("floatField", floatField)
0510 .append("doubleField", doubleField)
0511 .append("stringField", stringField)
0512 .append("booleanField", booleanField)
0513 .append("timestampField", timestampField)
0514 .append("nullIntField", nullIntField)
0515 .toString();
0516 }
0517 }
0518
0519 public static final class RecordSpark22000FailToUpcast {
0520 private Integer id;
0521
0522 public RecordSpark22000FailToUpcast() {
0523 }
0524
0525 public Integer getId() {
0526 return id;
0527 }
0528
0529 public void setId(Integer id) {
0530 this.id = id;
0531 }
0532 }
0533
0534 @Test
0535 public void testBeanWithLocalDateAndInstant() {
0536 String originConf = spark.conf().get(SQLConf.DATETIME_JAVA8API_ENABLED().key());
0537 try {
0538 spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), "true");
0539 List<Row> inputRows = new ArrayList<>();
0540 List<LocalDateInstantRecord> expectedRecords = new ArrayList<>();
0541
0542 for (long idx = 0 ; idx < 5 ; idx++) {
0543 Row row = createLocalDateInstantRow(idx);
0544 inputRows.add(row);
0545 expectedRecords.add(createLocalDateInstantRecord(row));
0546 }
0547
0548 Encoder<LocalDateInstantRecord> encoder = Encoders.bean(LocalDateInstantRecord.class);
0549
0550 StructType schema = new StructType()
0551 .add("localDateField", DataTypes.DateType)
0552 .add("instantField", DataTypes.TimestampType);
0553
0554 Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
0555 Dataset<LocalDateInstantRecord> dataset = dataFrame.as(encoder);
0556
0557 List<LocalDateInstantRecord> records = dataset.collectAsList();
0558
0559 Assert.assertEquals(expectedRecords, records);
0560 } finally {
0561 spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), originConf);
0562 }
0563 }
0564
0565 public static final class LocalDateInstantRecord {
0566 private String localDateField;
0567 private String instantField;
0568
0569 public LocalDateInstantRecord() { }
0570
0571 public String getLocalDateField() {
0572 return localDateField;
0573 }
0574
0575 public void setLocalDateField(String localDateField) {
0576 this.localDateField = localDateField;
0577 }
0578
0579 public String getInstantField() {
0580 return instantField;
0581 }
0582
0583 public void setInstantField(String instantField) {
0584 this.instantField = instantField;
0585 }
0586
0587 @Override
0588 public boolean equals(Object o) {
0589 if (this == o) return true;
0590 if (o == null || getClass() != o.getClass()) return false;
0591 LocalDateInstantRecord that = (LocalDateInstantRecord) o;
0592 return Objects.equals(localDateField, that.localDateField) &&
0593 Objects.equals(instantField, that.instantField);
0594 }
0595
0596 @Override
0597 public int hashCode() {
0598 return Objects.hash(localDateField, instantField);
0599 }
0600
0601 @Override
0602 public String toString() {
0603 return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
0604 .append("localDateField", localDateField)
0605 .append("instantField", instantField)
0606 .toString();
0607 }
0608
0609 }
0610
0611 private static Row createLocalDateInstantRow(Long index) {
0612 Object[] values = new Object[] { LocalDate.ofEpochDay(42), Instant.ofEpochSecond(42) };
0613 return new GenericRow(values);
0614 }
0615
0616 private static LocalDateInstantRecord createLocalDateInstantRecord(Row recordRow) {
0617 LocalDateInstantRecord record = new LocalDateInstantRecord();
0618 record.setLocalDateField(String.valueOf(recordRow.getLocalDate(0)));
0619 Instant instant = recordRow.getInstant(1);
0620 TimestampFormatter formatter = TimestampFormatter.getFractionFormatter(
0621 DateTimeUtils.getZoneId(SQLConf.get().sessionLocalTimeZone()));
0622 record.setInstantField(formatter.format(DateTimeUtils.instantToMicros(instant)));
0623 return record;
0624 }
0625 }