Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package test.org.apache.spark.sql;
0019 
0020 import java.util.ArrayList;
0021 import java.util.List;
0022 
0023 import org.apache.spark.sql.Row;
0024 import org.apache.spark.sql.expressions.MutableAggregationBuffer;
0025 import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
0026 import org.apache.spark.sql.types.DataType;
0027 import org.apache.spark.sql.types.DataTypes;
0028 import org.apache.spark.sql.types.StructField;
0029 import org.apache.spark.sql.types.StructType;
0030 
0031 /**
0032  * An example {@link UserDefinedAggregateFunction} to calculate the sum of a
0033  * {@link org.apache.spark.sql.types.DoubleType} column.
0034  */
0035 public class MyDoubleSum extends UserDefinedAggregateFunction {
0036 
0037   private StructType _inputDataType;
0038 
0039   private StructType _bufferSchema;
0040 
0041   private DataType _returnDataType;
0042 
0043   public MyDoubleSum() {
0044     List<StructField> inputFields = new ArrayList<>();
0045     inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
0046     _inputDataType = DataTypes.createStructType(inputFields);
0047 
0048     List<StructField> bufferFields = new ArrayList<>();
0049     bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
0050     _bufferSchema = DataTypes.createStructType(bufferFields);
0051 
0052     _returnDataType = DataTypes.DoubleType;
0053   }
0054 
0055   @Override public StructType inputSchema() {
0056     return _inputDataType;
0057   }
0058 
0059   @Override public StructType bufferSchema() {
0060     return _bufferSchema;
0061   }
0062 
0063   @Override public DataType dataType() {
0064     return _returnDataType;
0065   }
0066 
0067   @Override public boolean deterministic() {
0068     return true;
0069   }
0070 
0071   @Override public void initialize(MutableAggregationBuffer buffer) {
0072     // The initial value of the sum is null.
0073     buffer.update(0, null);
0074   }
0075 
0076   @Override public void update(MutableAggregationBuffer buffer, Row input) {
0077     // This input Row only has a single column storing the input value in Double.
0078     // We only update the buffer when the input value is not null.
0079     if (!input.isNullAt(0)) {
0080       if (buffer.isNullAt(0)) {
0081         // If the buffer value (the intermediate result of the sum) is still null,
0082         // we set the input value to the buffer.
0083         buffer.update(0, input.getDouble(0));
0084       } else {
0085         // Otherwise, we add the input value to the buffer value.
0086         Double newValue = input.getDouble(0) + buffer.getDouble(0);
0087         buffer.update(0, newValue);
0088       }
0089     }
0090   }
0091 
0092   @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
0093     // buffer1 and buffer2 have the same structure.
0094     // We only update the buffer1 when the input buffer2's value is not null.
0095     if (!buffer2.isNullAt(0)) {
0096       if (buffer1.isNullAt(0)) {
0097         // If the buffer value (intermediate result of the sum) is still null,
0098         // we set the it as the input buffer's value.
0099         buffer1.update(0, buffer2.getDouble(0));
0100       } else {
0101         // Otherwise, we add the input buffer's value (buffer1) to the mutable
0102         // buffer's value (buffer2).
0103         Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
0104         buffer1.update(0, newValue);
0105       }
0106     }
0107   }
0108 
0109   @Override public Object evaluate(Row buffer) {
0110     if (buffer.isNullAt(0)) {
0111       // If the buffer value is still null, we return null.
0112       return null;
0113     } else {
0114       // Otherwise, the intermediate sum is the final result.
0115       return buffer.getDouble(0);
0116     }
0117   }
0118 }