Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package test.org.apache.spark.sql;
0019 
0020 import java.util.ArrayList;
0021 import java.util.List;
0022 
0023 import org.apache.spark.sql.Row;
0024 import org.apache.spark.sql.expressions.MutableAggregationBuffer;
0025 import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
0026 import org.apache.spark.sql.types.DataType;
0027 import org.apache.spark.sql.types.DataTypes;
0028 import org.apache.spark.sql.types.StructField;
0029 import org.apache.spark.sql.types.StructType;
0030 
0031 /**
0032  * An example {@link UserDefinedAggregateFunction} to calculate a special average value of a
0033  * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the sum
0034  * of the average value of input values and 100.0.
0035  */
0036 public class MyDoubleAvg extends UserDefinedAggregateFunction {
0037 
0038   private StructType _inputDataType;
0039 
0040   private StructType _bufferSchema;
0041 
0042   private DataType _returnDataType;
0043 
0044   public MyDoubleAvg() {
0045     List<StructField> inputFields = new ArrayList<>();
0046     inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
0047     _inputDataType = DataTypes.createStructType(inputFields);
0048 
0049     // The buffer has two values, bufferSum for storing the current sum and
0050     // bufferCount for storing the number of non-null input values that have been contributed
0051     // to the current sum.
0052     List<StructField> bufferFields = new ArrayList<>();
0053     bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true));
0054     bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true));
0055     _bufferSchema = DataTypes.createStructType(bufferFields);
0056 
0057     _returnDataType = DataTypes.DoubleType;
0058   }
0059 
0060   @Override public StructType inputSchema() {
0061     return _inputDataType;
0062   }
0063 
0064   @Override public StructType bufferSchema() {
0065     return _bufferSchema;
0066   }
0067 
0068   @Override public DataType dataType() {
0069     return _returnDataType;
0070   }
0071 
0072   @Override public boolean deterministic() {
0073     return true;
0074   }
0075 
0076   @Override public void initialize(MutableAggregationBuffer buffer) {
0077     // The initial value of the sum is null.
0078     buffer.update(0, null);
0079     // The initial value of the count is 0.
0080     buffer.update(1, 0L);
0081   }
0082 
0083   @Override public void update(MutableAggregationBuffer buffer, Row input) {
0084     // This input Row only has a single column storing the input value in Double.
0085     // We only update the buffer when the input value is not null.
0086     if (!input.isNullAt(0)) {
0087       // If the buffer value (the intermediate result of the sum) is still null,
0088       // we set the input value to the buffer and set the bufferCount to 1.
0089       if (buffer.isNullAt(0)) {
0090         buffer.update(0, input.getDouble(0));
0091         buffer.update(1, 1L);
0092       } else {
0093         // Otherwise, update the bufferSum and increment bufferCount.
0094         Double newValue = input.getDouble(0) + buffer.getDouble(0);
0095         buffer.update(0, newValue);
0096         buffer.update(1, buffer.getLong(1) + 1L);
0097       }
0098     }
0099   }
0100 
0101   @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
0102     // buffer1 and buffer2 have the same structure.
0103     // We only update the buffer1 when the input buffer2's sum value is not null.
0104     if (!buffer2.isNullAt(0)) {
0105       if (buffer1.isNullAt(0)) {
0106         // If the buffer value (intermediate result of the sum) is still null,
0107         // we set the it as the input buffer's value.
0108         buffer1.update(0, buffer2.getDouble(0));
0109         buffer1.update(1, buffer2.getLong(1));
0110       } else {
0111         // Otherwise, we update the bufferSum and bufferCount.
0112         Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
0113         buffer1.update(0, newValue);
0114         buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
0115       }
0116     }
0117   }
0118 
0119   @Override public Object evaluate(Row buffer) {
0120     if (buffer.isNullAt(0)) {
0121       // If the bufferSum is still null, we return null because this function has not got
0122       // any input row.
0123       return null;
0124     } else {
0125       // Otherwise, we calculate the special average value.
0126       return buffer.getDouble(0) / buffer.getLong(1) + 100.0;
0127     }
0128   }
0129 }