0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package test.org.apache.spark.sql;
0019
0020 import java.util.ArrayList;
0021 import java.util.List;
0022
0023 import org.apache.spark.sql.Row;
0024 import org.apache.spark.sql.expressions.MutableAggregationBuffer;
0025 import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
0026 import org.apache.spark.sql.types.DataType;
0027 import org.apache.spark.sql.types.DataTypes;
0028 import org.apache.spark.sql.types.StructField;
0029 import org.apache.spark.sql.types.StructType;
0030
0031
0032
0033
0034
0035 public class MyDoubleSum extends UserDefinedAggregateFunction {
0036
0037 private StructType _inputDataType;
0038
0039 private StructType _bufferSchema;
0040
0041 private DataType _returnDataType;
0042
0043 public MyDoubleSum() {
0044 List<StructField> inputFields = new ArrayList<>();
0045 inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
0046 _inputDataType = DataTypes.createStructType(inputFields);
0047
0048 List<StructField> bufferFields = new ArrayList<>();
0049 bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
0050 _bufferSchema = DataTypes.createStructType(bufferFields);
0051
0052 _returnDataType = DataTypes.DoubleType;
0053 }
0054
0055 @Override public StructType inputSchema() {
0056 return _inputDataType;
0057 }
0058
0059 @Override public StructType bufferSchema() {
0060 return _bufferSchema;
0061 }
0062
0063 @Override public DataType dataType() {
0064 return _returnDataType;
0065 }
0066
0067 @Override public boolean deterministic() {
0068 return true;
0069 }
0070
0071 @Override public void initialize(MutableAggregationBuffer buffer) {
0072
0073 buffer.update(0, null);
0074 }
0075
0076 @Override public void update(MutableAggregationBuffer buffer, Row input) {
0077
0078
0079 if (!input.isNullAt(0)) {
0080 if (buffer.isNullAt(0)) {
0081
0082
0083 buffer.update(0, input.getDouble(0));
0084 } else {
0085
0086 Double newValue = input.getDouble(0) + buffer.getDouble(0);
0087 buffer.update(0, newValue);
0088 }
0089 }
0090 }
0091
0092 @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
0093
0094
0095 if (!buffer2.isNullAt(0)) {
0096 if (buffer1.isNullAt(0)) {
0097
0098
0099 buffer1.update(0, buffer2.getDouble(0));
0100 } else {
0101
0102
0103 Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
0104 buffer1.update(0, newValue);
0105 }
0106 }
0107 }
0108
0109 @Override public Object evaluate(Row buffer) {
0110 if (buffer.isNullAt(0)) {
0111
0112 return null;
0113 } else {
0114
0115 return buffer.getDouble(0);
0116 }
0117 }
0118 }