0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 package org.apache.spark.examples.sql;
0018
0019
0020 import java.io.Serializable;
0021
0022 import org.apache.spark.sql.Dataset;
0023 import org.apache.spark.sql.Encoder;
0024 import org.apache.spark.sql.Encoders;
0025 import org.apache.spark.sql.SparkSession;
0026 import org.apache.spark.sql.TypedColumn;
0027 import org.apache.spark.sql.expressions.Aggregator;
0028
0029
0030 public class JavaUserDefinedTypedAggregation {
0031
0032
0033 public static class Employee implements Serializable {
0034 private String name;
0035 private long salary;
0036
0037
0038
0039 public String getName() {
0040 return name;
0041 }
0042
0043 public void setName(String name) {
0044 this.name = name;
0045 }
0046
0047 public long getSalary() {
0048 return salary;
0049 }
0050
0051 public void setSalary(long salary) {
0052 this.salary = salary;
0053 }
0054
0055 }
0056
0057 public static class Average implements Serializable {
0058 private long sum;
0059 private long count;
0060
0061
0062
0063 public Average() {
0064 }
0065
0066 public Average(long sum, long count) {
0067 this.sum = sum;
0068 this.count = count;
0069 }
0070
0071 public long getSum() {
0072 return sum;
0073 }
0074
0075 public void setSum(long sum) {
0076 this.sum = sum;
0077 }
0078
0079 public long getCount() {
0080 return count;
0081 }
0082
0083 public void setCount(long count) {
0084 this.count = count;
0085 }
0086
0087 }
0088
0089 public static class MyAverage extends Aggregator<Employee, Average, Double> {
0090
0091 public Average zero() {
0092 return new Average(0L, 0L);
0093 }
0094
0095
0096 public Average reduce(Average buffer, Employee employee) {
0097 long newSum = buffer.getSum() + employee.getSalary();
0098 long newCount = buffer.getCount() + 1;
0099 buffer.setSum(newSum);
0100 buffer.setCount(newCount);
0101 return buffer;
0102 }
0103
0104 public Average merge(Average b1, Average b2) {
0105 long mergedSum = b1.getSum() + b2.getSum();
0106 long mergedCount = b1.getCount() + b2.getCount();
0107 b1.setSum(mergedSum);
0108 b1.setCount(mergedCount);
0109 return b1;
0110 }
0111
0112 public Double finish(Average reduction) {
0113 return ((double) reduction.getSum()) / reduction.getCount();
0114 }
0115
0116 public Encoder<Average> bufferEncoder() {
0117 return Encoders.bean(Average.class);
0118 }
0119
0120 public Encoder<Double> outputEncoder() {
0121 return Encoders.DOUBLE();
0122 }
0123 }
0124
0125
0126 public static void main(String[] args) {
0127 SparkSession spark = SparkSession
0128 .builder()
0129 .appName("Java Spark SQL user-defined Datasets aggregation example")
0130 .getOrCreate();
0131
0132
0133 Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
0134 String path = "examples/src/main/resources/employees.json";
0135 Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
0136 ds.show();
0137
0138
0139
0140
0141
0142
0143
0144
0145
0146 MyAverage myAverage = new MyAverage();
0147
0148 TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
0149 Dataset<Double> result = ds.select(averageSalary);
0150 result.show();
0151
0152
0153
0154
0155
0156
0157 spark.stop();
0158 }
0159
0160 }