Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 package org.apache.spark.examples.sql;
0018 
0019 // $example on:untyped_custom_aggregation$
0020 import java.io.Serializable;
0021 
0022 import org.apache.spark.sql.Dataset;
0023 import org.apache.spark.sql.Encoder;
0024 import org.apache.spark.sql.Encoders;
0025 import org.apache.spark.sql.Row;
0026 import org.apache.spark.sql.SparkSession;
0027 import org.apache.spark.sql.expressions.Aggregator;
0028 import org.apache.spark.sql.functions;
0029 // $example off:untyped_custom_aggregation$
0030 
0031 public class JavaUserDefinedUntypedAggregation {
0032 
0033   // $example on:untyped_custom_aggregation$
0034   public static class Average implements Serializable  {
0035     private long sum;
0036     private long count;
0037 
0038     // Constructors, getters, setters...
0039     // $example off:typed_custom_aggregation$
0040     public Average() {
0041     }
0042 
0043     public Average(long sum, long count) {
0044       this.sum = sum;
0045       this.count = count;
0046     }
0047 
0048     public long getSum() {
0049       return sum;
0050     }
0051 
0052     public void setSum(long sum) {
0053       this.sum = sum;
0054     }
0055 
0056     public long getCount() {
0057       return count;
0058     }
0059 
0060     public void setCount(long count) {
0061       this.count = count;
0062     }
0063     // $example on:typed_custom_aggregation$
0064   }
0065 
0066   public static class MyAverage extends Aggregator<Long, Average, Double> {
0067     // A zero value for this aggregation. Should satisfy the property that any b + zero = b
0068     public Average zero() {
0069       return new Average(0L, 0L);
0070     }
0071     // Combine two values to produce a new value. For performance, the function may modify `buffer`
0072     // and return it instead of constructing a new object
0073     public Average reduce(Average buffer, Long data) {
0074       long newSum = buffer.getSum() + data;
0075       long newCount = buffer.getCount() + 1;
0076       buffer.setSum(newSum);
0077       buffer.setCount(newCount);
0078       return buffer;
0079     }
0080     // Merge two intermediate values
0081     public Average merge(Average b1, Average b2) {
0082       long mergedSum = b1.getSum() + b2.getSum();
0083       long mergedCount = b1.getCount() + b2.getCount();
0084       b1.setSum(mergedSum);
0085       b1.setCount(mergedCount);
0086       return b1;
0087     }
0088     // Transform the output of the reduction
0089     public Double finish(Average reduction) {
0090       return ((double) reduction.getSum()) / reduction.getCount();
0091     }
0092     // Specifies the Encoder for the intermediate value type
0093     public Encoder<Average> bufferEncoder() {
0094       return Encoders.bean(Average.class);
0095     }
0096     // Specifies the Encoder for the final output value type
0097     public Encoder<Double> outputEncoder() {
0098       return Encoders.DOUBLE();
0099     }
0100   }
0101   // $example off:untyped_custom_aggregation$
0102 
0103   public static void main(String[] args) {
0104     SparkSession spark = SparkSession
0105       .builder()
0106       .appName("Java Spark SQL user-defined DataFrames aggregation example")
0107       .getOrCreate();
0108 
0109     // $example on:untyped_custom_aggregation$
0110     // Register the function to access it
0111     spark.udf().register("myAverage", functions.udaf(new MyAverage(), Encoders.LONG()));
0112 
0113     Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
0114     df.createOrReplaceTempView("employees");
0115     df.show();
0116     // +-------+------+
0117     // |   name|salary|
0118     // +-------+------+
0119     // |Michael|  3000|
0120     // |   Andy|  4500|
0121     // | Justin|  3500|
0122     // |  Berta|  4000|
0123     // +-------+------+
0124 
0125     Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
0126     result.show();
0127     // +--------------+
0128     // |average_salary|
0129     // +--------------+
0130     // |        3750.0|
0131     // +--------------+
0132     // $example off:untyped_custom_aggregation$
0133 
0134     spark.stop();
0135   }
0136 }