0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 package org.apache.spark.examples.sql;
0018
0019
0020 import java.io.Serializable;
0021
0022 import org.apache.spark.sql.Dataset;
0023 import org.apache.spark.sql.Encoder;
0024 import org.apache.spark.sql.Encoders;
0025 import org.apache.spark.sql.Row;
0026 import org.apache.spark.sql.SparkSession;
0027 import org.apache.spark.sql.expressions.Aggregator;
0028 import org.apache.spark.sql.functions;
0029
0030
0031 public class JavaUserDefinedUntypedAggregation {
0032
0033
0034 public static class Average implements Serializable {
0035 private long sum;
0036 private long count;
0037
0038
0039
0040 public Average() {
0041 }
0042
0043 public Average(long sum, long count) {
0044 this.sum = sum;
0045 this.count = count;
0046 }
0047
0048 public long getSum() {
0049 return sum;
0050 }
0051
0052 public void setSum(long sum) {
0053 this.sum = sum;
0054 }
0055
0056 public long getCount() {
0057 return count;
0058 }
0059
0060 public void setCount(long count) {
0061 this.count = count;
0062 }
0063
0064 }
0065
0066 public static class MyAverage extends Aggregator<Long, Average, Double> {
0067
0068 public Average zero() {
0069 return new Average(0L, 0L);
0070 }
0071
0072
0073 public Average reduce(Average buffer, Long data) {
0074 long newSum = buffer.getSum() + data;
0075 long newCount = buffer.getCount() + 1;
0076 buffer.setSum(newSum);
0077 buffer.setCount(newCount);
0078 return buffer;
0079 }
0080
0081 public Average merge(Average b1, Average b2) {
0082 long mergedSum = b1.getSum() + b2.getSum();
0083 long mergedCount = b1.getCount() + b2.getCount();
0084 b1.setSum(mergedSum);
0085 b1.setCount(mergedCount);
0086 return b1;
0087 }
0088
0089 public Double finish(Average reduction) {
0090 return ((double) reduction.getSum()) / reduction.getCount();
0091 }
0092
0093 public Encoder<Average> bufferEncoder() {
0094 return Encoders.bean(Average.class);
0095 }
0096
0097 public Encoder<Double> outputEncoder() {
0098 return Encoders.DOUBLE();
0099 }
0100 }
0101
0102
0103 public static void main(String[] args) {
0104 SparkSession spark = SparkSession
0105 .builder()
0106 .appName("Java Spark SQL user-defined DataFrames aggregation example")
0107 .getOrCreate();
0108
0109
0110
0111 spark.udf().register("myAverage", functions.udaf(new MyAverage(), Encoders.LONG()));
0112
0113 Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
0114 df.createOrReplaceTempView("employees");
0115 df.show();
0116
0117
0118
0119
0120
0121
0122
0123
0124
0125 Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
0126 result.show();
0127
0128
0129
0130
0131
0132
0133
0134 spark.stop();
0135 }
0136 }