Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package test.org.apache.spark.sql;
0019 
0020 import java.io.Serializable;
0021 import java.sql.Timestamp;
0022 import java.text.SimpleDateFormat;
0023 import java.time.Instant;
0024 import java.time.LocalDate;
0025 import java.util.*;
0026 
0027 import org.apache.commons.lang3.builder.ToStringBuilder;
0028 import org.apache.commons.lang3.builder.ToStringStyle;
0029 import org.junit.*;
0030 
0031 import org.apache.spark.sql.*;
0032 import org.apache.spark.sql.catalyst.expressions.GenericRow;
0033 import org.apache.spark.sql.catalyst.util.DateTimeUtils;
0034 import org.apache.spark.sql.catalyst.util.TimestampFormatter;
0035 import org.apache.spark.sql.internal.SQLConf;
0036 import org.apache.spark.sql.types.DataTypes;
0037 import org.apache.spark.sql.types.StructType;
0038 
0039 import org.apache.spark.sql.test.TestSparkSession;
0040 
0041 public class JavaBeanDeserializationSuite implements Serializable {
0042 
0043   private TestSparkSession spark;
0044 
0045   @Before
0046   public void setUp() {
0047     spark = new TestSparkSession();
0048   }
0049 
0050   @After
0051   public void tearDown() {
0052     spark.stop();
0053     spark = null;
0054   }
0055 
0056   private static final List<ArrayRecord> ARRAY_RECORDS = new ArrayList<>();
0057 
0058   static {
0059     ARRAY_RECORDS.add(
0060       new ArrayRecord(1, Arrays.asList(new Interval(111, 211), new Interval(121, 221)),
0061               new int[] { 11, 12, 13, 14 })
0062     );
0063     ARRAY_RECORDS.add(
0064       new ArrayRecord(2, Arrays.asList(new Interval(112, 212), new Interval(122, 222)),
0065               new int[] { 21, 22, 23, 24 })
0066     );
0067     ARRAY_RECORDS.add(
0068       new ArrayRecord(3, Arrays.asList(new Interval(113, 213), new Interval(123, 223)),
0069               new int[] { 31, 32, 33, 34 })
0070     );
0071   }
0072 
0073   @Test
0074   public void testBeanWithArrayFieldDeserialization() {
0075     Encoder<ArrayRecord> encoder = Encoders.bean(ArrayRecord.class);
0076 
0077     Dataset<ArrayRecord> dataset = spark
0078       .read()
0079       .format("json")
0080       .schema("id int, intervals array<struct<startTime: bigint, endTime: bigint>>, " +
0081           "ints array<int>")
0082       .load("src/test/resources/test-data/with-array-fields.json")
0083       .as(encoder);
0084 
0085     List<ArrayRecord> records = dataset.collectAsList();
0086     Assert.assertEquals(ARRAY_RECORDS, records);
0087   }
0088 
0089   private static final List<MapRecord> MAP_RECORDS = new ArrayList<>();
0090 
0091   static {
0092     MAP_RECORDS.add(new MapRecord(1,
0093       toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(111, 211), new Interval(121, 221)))
0094     ));
0095     MAP_RECORDS.add(new MapRecord(2,
0096       toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(112, 212), new Interval(122, 222)))
0097     ));
0098     MAP_RECORDS.add(new MapRecord(3,
0099       toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(113, 213), new Interval(123, 223)))
0100     ));
0101     MAP_RECORDS.add(new MapRecord(4, new HashMap<>()));
0102     MAP_RECORDS.add(new MapRecord(5, null));
0103   }
0104 
0105   private static <K, V> Map<K, V> toMap(Collection<K> keys, Collection<V> values) {
0106     Map<K, V> map = new HashMap<>();
0107     Iterator<K> keyI = keys.iterator();
0108     Iterator<V> valueI = values.iterator();
0109     while (keyI.hasNext() && valueI.hasNext()) {
0110       map.put(keyI.next(), valueI.next());
0111     }
0112     return map;
0113   }
0114 
0115   @Test
0116   public void testBeanWithMapFieldsDeserialization() {
0117 
0118     Encoder<MapRecord> encoder = Encoders.bean(MapRecord.class);
0119 
0120     Dataset<MapRecord> dataset = spark
0121       .read()
0122       .format("json")
0123       .schema("id int, intervals map<string, struct<startTime: bigint, endTime: bigint>>")
0124       .load("src/test/resources/test-data/with-map-fields.json")
0125       .as(encoder);
0126 
0127     List<MapRecord> records = dataset.collectAsList();
0128 
0129     Assert.assertEquals(MAP_RECORDS, records);
0130   }
0131 
0132   @Test
0133   public void testSpark22000() {
0134     List<Row> inputRows = new ArrayList<>();
0135     List<RecordSpark22000> expectedRecords = new ArrayList<>();
0136 
0137     for (long idx = 0 ; idx < 5 ; idx++) {
0138       Row row = createRecordSpark22000Row(idx);
0139       inputRows.add(row);
0140       expectedRecords.add(createRecordSpark22000(row));
0141     }
0142 
0143     // Here we try to convert the fields, from any types to string.
0144     // Before applying SPARK-22000, Spark called toString() against variable which type might
0145     // be primitive.
0146     // SPARK-22000 it calls String.valueOf() which finally calls toString() but handles boxing
0147     // if the type is primitive.
0148     Encoder<RecordSpark22000> encoder = Encoders.bean(RecordSpark22000.class);
0149 
0150     StructType schema = new StructType()
0151       .add("shortField", DataTypes.ShortType)
0152       .add("intField", DataTypes.IntegerType)
0153       .add("longField", DataTypes.LongType)
0154       .add("floatField", DataTypes.FloatType)
0155       .add("doubleField", DataTypes.DoubleType)
0156       .add("stringField", DataTypes.StringType)
0157       .add("booleanField", DataTypes.BooleanType)
0158       .add("timestampField", DataTypes.TimestampType)
0159       // explicitly setting nullable = true to make clear the intention
0160       .add("nullIntField", DataTypes.IntegerType, true);
0161 
0162     Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
0163     Dataset<RecordSpark22000> dataset = dataFrame.as(encoder);
0164 
0165     List<RecordSpark22000> records = dataset.collectAsList();
0166 
0167     Assert.assertEquals(expectedRecords, records);
0168   }
0169 
0170   @Test
0171   public void testSpark22000FailToUpcast() {
0172     List<Row> inputRows = new ArrayList<>();
0173     for (long idx = 0 ; idx < 5 ; idx++) {
0174       Row row = createRecordSpark22000FailToUpcastRow(idx);
0175       inputRows.add(row);
0176     }
0177 
0178     // Here we try to convert the fields, from string type to int, which upcast doesn't help.
0179     Encoder<RecordSpark22000FailToUpcast> encoder =
0180             Encoders.bean(RecordSpark22000FailToUpcast.class);
0181 
0182     StructType schema = new StructType().add("id", DataTypes.StringType);
0183 
0184     Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
0185 
0186     try {
0187       dataFrame.as(encoder).collect();
0188       Assert.fail("Expected AnalysisException, but passed.");
0189     } catch (Throwable e) {
0190       // Here we need to handle weird case: compiler complains AnalysisException never be thrown
0191       // in try statement, but it can be thrown actually. Maybe Scala-Java interop issue?
0192       if (e instanceof AnalysisException) {
0193         Assert.assertTrue(e.getMessage().contains("Cannot up cast "));
0194       } else {
0195         throw e;
0196       }
0197     }
0198   }
0199 
0200   private static Row createRecordSpark22000Row(Long index) {
0201     Object[] values = new Object[] {
0202             index.shortValue(),
0203             index.intValue(),
0204             index,
0205             index.floatValue(),
0206             index.doubleValue(),
0207             String.valueOf(index),
0208             index % 2 == 0,
0209             new java.sql.Timestamp(System.currentTimeMillis()),
0210             null
0211     };
0212     return new GenericRow(values);
0213   }
0214 
0215   private static String timestampToString(Timestamp ts) {
0216     String timestampString = String.valueOf(ts);
0217     String formatted = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(ts);
0218 
0219     if (timestampString.length() > 19 && !timestampString.substring(19).equals(".0")) {
0220       return formatted + timestampString.substring(19);
0221     } else {
0222       return formatted;
0223     }
0224   }
0225 
0226   private static RecordSpark22000 createRecordSpark22000(Row recordRow) {
0227     RecordSpark22000 record = new RecordSpark22000();
0228     record.setShortField(String.valueOf(recordRow.getShort(0)));
0229     record.setIntField(String.valueOf(recordRow.getInt(1)));
0230     record.setLongField(String.valueOf(recordRow.getLong(2)));
0231     record.setFloatField(String.valueOf(recordRow.getFloat(3)));
0232     record.setDoubleField(String.valueOf(recordRow.getDouble(4)));
0233     record.setStringField(recordRow.getString(5));
0234     record.setBooleanField(String.valueOf(recordRow.getBoolean(6)));
0235     record.setTimestampField(timestampToString(recordRow.getTimestamp(7)));
0236     // This would figure out that null value will not become "null".
0237     record.setNullIntField(null);
0238     return record;
0239   }
0240 
0241   private static Row createRecordSpark22000FailToUpcastRow(Long index) {
0242     Object[] values = new Object[] { String.valueOf(index) };
0243     return new GenericRow(values);
0244   }
0245 
0246   public static class ArrayRecord {
0247 
0248     private int id;
0249     private List<Interval> intervals;
0250     private int[] ints;
0251 
0252     public ArrayRecord() { }
0253 
0254     ArrayRecord(int id, List<Interval> intervals, int[] ints) {
0255       this.id = id;
0256       this.intervals = intervals;
0257       this.ints = ints;
0258     }
0259 
0260     public int getId() {
0261       return id;
0262     }
0263 
0264     public void setId(int id) {
0265       this.id = id;
0266     }
0267 
0268     public List<Interval> getIntervals() {
0269       return intervals;
0270     }
0271 
0272     public void setIntervals(List<Interval> intervals) {
0273       this.intervals = intervals;
0274     }
0275 
0276     public int[] getInts() {
0277       return ints;
0278     }
0279 
0280     public void setInts(int[] ints) {
0281       this.ints = ints;
0282     }
0283 
0284     @Override
0285     public int hashCode() {
0286       return id ^ Objects.hashCode(intervals) ^ Objects.hashCode(ints);
0287     }
0288 
0289     @Override
0290     public boolean equals(Object obj) {
0291       if (!(obj instanceof ArrayRecord)) return false;
0292       ArrayRecord other = (ArrayRecord) obj;
0293       return (other.id == this.id) && Objects.equals(other.intervals, this.intervals) &&
0294               Arrays.equals(other.ints, ints);
0295     }
0296 
0297     @Override
0298     public String toString() {
0299       return String.format("{ id: %d, intervals: %s, ints: %s }", id, intervals,
0300               Arrays.toString(ints));
0301     }
0302   }
0303 
0304   public static class MapRecord {
0305 
0306     private int id;
0307     private Map<String, Interval> intervals;
0308 
0309     public MapRecord() { }
0310 
0311     MapRecord(int id, Map<String, Interval> intervals) {
0312       this.id = id;
0313       this.intervals = intervals;
0314     }
0315 
0316     public int getId() {
0317       return id;
0318     }
0319 
0320     public void setId(int id) {
0321       this.id = id;
0322     }
0323 
0324     public Map<String, Interval> getIntervals() {
0325       return intervals;
0326     }
0327 
0328     public void setIntervals(Map<String, Interval> intervals) {
0329       this.intervals = intervals;
0330     }
0331 
0332     @Override
0333     public int hashCode() {
0334       return id ^ Objects.hashCode(intervals);
0335     }
0336 
0337     @Override
0338     public boolean equals(Object obj) {
0339       if (!(obj instanceof MapRecord)) return false;
0340       MapRecord other = (MapRecord) obj;
0341       return (other.id == this.id) && Objects.equals(other.intervals, this.intervals);
0342     }
0343 
0344     @Override
0345     public String toString() {
0346       return String.format("{ id: %d, intervals: %s }", id, intervals);
0347     }
0348   }
0349 
0350   public static class Interval {
0351 
0352     private long startTime;
0353     private long endTime;
0354 
0355     public Interval() { }
0356 
0357     Interval(long startTime, long endTime) {
0358       this.startTime = startTime;
0359       this.endTime = endTime;
0360     }
0361 
0362     public long getStartTime() {
0363       return startTime;
0364     }
0365 
0366     public void setStartTime(long startTime) {
0367       this.startTime = startTime;
0368     }
0369 
0370     public long getEndTime() {
0371       return endTime;
0372     }
0373 
0374     public void setEndTime(long endTime) {
0375       this.endTime = endTime;
0376     }
0377 
0378     @Override
0379     public int hashCode() {
0380       return Long.hashCode(startTime) ^ Long.hashCode(endTime);
0381     }
0382 
0383     @Override
0384     public boolean equals(Object obj) {
0385       if (!(obj instanceof Interval)) return false;
0386       Interval other = (Interval) obj;
0387       return (other.startTime == this.startTime) && (other.endTime == this.endTime);
0388     }
0389 
0390     @Override
0391     public String toString() {
0392       return String.format("[%d,%d]", startTime, endTime);
0393     }
0394   }
0395 
0396   public static final class RecordSpark22000 {
0397     private String shortField;
0398     private String intField;
0399     private String longField;
0400     private String floatField;
0401     private String doubleField;
0402     private String stringField;
0403     private String booleanField;
0404     private String timestampField;
0405     private String nullIntField;
0406 
0407     public RecordSpark22000() { }
0408 
0409     public String getShortField() {
0410       return shortField;
0411     }
0412 
0413     public void setShortField(String shortField) {
0414       this.shortField = shortField;
0415     }
0416 
0417     public String getIntField() {
0418       return intField;
0419     }
0420 
0421     public void setIntField(String intField) {
0422       this.intField = intField;
0423     }
0424 
0425     public String getLongField() {
0426       return longField;
0427     }
0428 
0429     public void setLongField(String longField) {
0430       this.longField = longField;
0431     }
0432 
0433     public String getFloatField() {
0434       return floatField;
0435     }
0436 
0437     public void setFloatField(String floatField) {
0438       this.floatField = floatField;
0439     }
0440 
0441     public String getDoubleField() {
0442       return doubleField;
0443     }
0444 
0445     public void setDoubleField(String doubleField) {
0446       this.doubleField = doubleField;
0447     }
0448 
0449     public String getStringField() {
0450       return stringField;
0451     }
0452 
0453     public void setStringField(String stringField) {
0454       this.stringField = stringField;
0455     }
0456 
0457     public String getBooleanField() {
0458       return booleanField;
0459     }
0460 
0461     public void setBooleanField(String booleanField) {
0462       this.booleanField = booleanField;
0463     }
0464 
0465     public String getTimestampField() {
0466       return timestampField;
0467     }
0468 
0469     public void setTimestampField(String timestampField) {
0470       this.timestampField = timestampField;
0471     }
0472 
0473     public String getNullIntField() {
0474       return nullIntField;
0475     }
0476 
0477     public void setNullIntField(String nullIntField) {
0478       this.nullIntField = nullIntField;
0479     }
0480 
0481     @Override
0482     public boolean equals(Object o) {
0483       if (this == o) return true;
0484       if (o == null || getClass() != o.getClass()) return false;
0485       RecordSpark22000 that = (RecordSpark22000) o;
0486       return Objects.equals(shortField, that.shortField) &&
0487               Objects.equals(intField, that.intField) &&
0488               Objects.equals(longField, that.longField) &&
0489               Objects.equals(floatField, that.floatField) &&
0490               Objects.equals(doubleField, that.doubleField) &&
0491               Objects.equals(stringField, that.stringField) &&
0492               Objects.equals(booleanField, that.booleanField) &&
0493               Objects.equals(timestampField, that.timestampField) &&
0494               Objects.equals(nullIntField, that.nullIntField);
0495     }
0496 
0497     @Override
0498     public int hashCode() {
0499       return Objects.hash(shortField, intField, longField, floatField, doubleField, stringField,
0500               booleanField, timestampField, nullIntField);
0501     }
0502 
0503     @Override
0504     public String toString() {
0505       return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
0506           .append("shortField", shortField)
0507           .append("intField", intField)
0508           .append("longField", longField)
0509           .append("floatField", floatField)
0510           .append("doubleField", doubleField)
0511           .append("stringField", stringField)
0512           .append("booleanField", booleanField)
0513           .append("timestampField", timestampField)
0514           .append("nullIntField", nullIntField)
0515           .toString();
0516     }
0517   }
0518 
0519   public static final class RecordSpark22000FailToUpcast {
0520     private Integer id;
0521 
0522     public RecordSpark22000FailToUpcast() {
0523     }
0524 
0525     public Integer getId() {
0526       return id;
0527     }
0528 
0529     public void setId(Integer id) {
0530       this.id = id;
0531     }
0532   }
0533 
0534   @Test
0535   public void testBeanWithLocalDateAndInstant() {
0536     String originConf = spark.conf().get(SQLConf.DATETIME_JAVA8API_ENABLED().key());
0537     try {
0538       spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), "true");
0539       List<Row> inputRows = new ArrayList<>();
0540       List<LocalDateInstantRecord> expectedRecords = new ArrayList<>();
0541 
0542       for (long idx = 0 ; idx < 5 ; idx++) {
0543         Row row = createLocalDateInstantRow(idx);
0544         inputRows.add(row);
0545         expectedRecords.add(createLocalDateInstantRecord(row));
0546       }
0547 
0548       Encoder<LocalDateInstantRecord> encoder = Encoders.bean(LocalDateInstantRecord.class);
0549 
0550       StructType schema = new StructType()
0551         .add("localDateField", DataTypes.DateType)
0552         .add("instantField", DataTypes.TimestampType);
0553 
0554       Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
0555       Dataset<LocalDateInstantRecord> dataset = dataFrame.as(encoder);
0556 
0557       List<LocalDateInstantRecord> records = dataset.collectAsList();
0558 
0559       Assert.assertEquals(expectedRecords, records);
0560     } finally {
0561         spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), originConf);
0562     }
0563   }
0564 
0565   public static final class LocalDateInstantRecord {
0566     private String localDateField;
0567     private String instantField;
0568 
0569     public LocalDateInstantRecord() { }
0570 
0571     public String getLocalDateField() {
0572       return localDateField;
0573     }
0574 
0575     public void setLocalDateField(String localDateField) {
0576       this.localDateField = localDateField;
0577     }
0578 
0579     public String getInstantField() {
0580       return instantField;
0581     }
0582 
0583     public void setInstantField(String instantField) {
0584       this.instantField = instantField;
0585     }
0586 
0587     @Override
0588     public boolean equals(Object o) {
0589       if (this == o) return true;
0590       if (o == null || getClass() != o.getClass()) return false;
0591       LocalDateInstantRecord that = (LocalDateInstantRecord) o;
0592       return Objects.equals(localDateField, that.localDateField) &&
0593         Objects.equals(instantField, that.instantField);
0594     }
0595 
0596     @Override
0597     public int hashCode() {
0598       return Objects.hash(localDateField, instantField);
0599     }
0600 
0601     @Override
0602     public String toString() {
0603       return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
0604           .append("localDateField", localDateField)
0605           .append("instantField", instantField)
0606           .toString();
0607     }
0608 
0609   }
0610 
0611   private static Row createLocalDateInstantRow(Long index) {
0612     Object[] values = new Object[] { LocalDate.ofEpochDay(42), Instant.ofEpochSecond(42) };
0613     return new GenericRow(values);
0614   }
0615 
0616   private static LocalDateInstantRecord createLocalDateInstantRecord(Row recordRow) {
0617     LocalDateInstantRecord record = new LocalDateInstantRecord();
0618     record.setLocalDateField(String.valueOf(recordRow.getLocalDate(0)));
0619     Instant instant = recordRow.getInstant(1);
0620     TimestampFormatter formatter = TimestampFormatter.getFractionFormatter(
0621       DateTimeUtils.getZoneId(SQLConf.get().sessionLocalTimeZone()));
0622     record.setInstantField(formatter.format(DateTimeUtils.instantToMicros(instant)));
0623     return record;
0624   }
0625 }