0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package test.org.apache.spark.sql;
0019
0020 import java.util.ArrayList;
0021 import java.util.List;
0022
0023 import org.apache.spark.sql.Row;
0024 import org.apache.spark.sql.expressions.MutableAggregationBuffer;
0025 import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
0026 import org.apache.spark.sql.types.DataType;
0027 import org.apache.spark.sql.types.DataTypes;
0028 import org.apache.spark.sql.types.StructField;
0029 import org.apache.spark.sql.types.StructType;
0030
0031
0032
0033
0034
0035
0036 public class MyDoubleAvg extends UserDefinedAggregateFunction {
0037
0038 private StructType _inputDataType;
0039
0040 private StructType _bufferSchema;
0041
0042 private DataType _returnDataType;
0043
0044 public MyDoubleAvg() {
0045 List<StructField> inputFields = new ArrayList<>();
0046 inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
0047 _inputDataType = DataTypes.createStructType(inputFields);
0048
0049
0050
0051
0052 List<StructField> bufferFields = new ArrayList<>();
0053 bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true));
0054 bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true));
0055 _bufferSchema = DataTypes.createStructType(bufferFields);
0056
0057 _returnDataType = DataTypes.DoubleType;
0058 }
0059
0060 @Override public StructType inputSchema() {
0061 return _inputDataType;
0062 }
0063
0064 @Override public StructType bufferSchema() {
0065 return _bufferSchema;
0066 }
0067
0068 @Override public DataType dataType() {
0069 return _returnDataType;
0070 }
0071
0072 @Override public boolean deterministic() {
0073 return true;
0074 }
0075
0076 @Override public void initialize(MutableAggregationBuffer buffer) {
0077
0078 buffer.update(0, null);
0079
0080 buffer.update(1, 0L);
0081 }
0082
0083 @Override public void update(MutableAggregationBuffer buffer, Row input) {
0084
0085
0086 if (!input.isNullAt(0)) {
0087
0088
0089 if (buffer.isNullAt(0)) {
0090 buffer.update(0, input.getDouble(0));
0091 buffer.update(1, 1L);
0092 } else {
0093
0094 Double newValue = input.getDouble(0) + buffer.getDouble(0);
0095 buffer.update(0, newValue);
0096 buffer.update(1, buffer.getLong(1) + 1L);
0097 }
0098 }
0099 }
0100
0101 @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
0102
0103
0104 if (!buffer2.isNullAt(0)) {
0105 if (buffer1.isNullAt(0)) {
0106
0107
0108 buffer1.update(0, buffer2.getDouble(0));
0109 buffer1.update(1, buffer2.getLong(1));
0110 } else {
0111
0112 Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
0113 buffer1.update(0, newValue);
0114 buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
0115 }
0116 }
0117 }
0118
0119 @Override public Object evaluate(Row buffer) {
0120 if (buffer.isNullAt(0)) {
0121
0122
0123 return null;
0124 } else {
0125
0126 return buffer.getDouble(0) / buffer.getLong(1) + 100.0;
0127 }
0128 }
0129 }