Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 package org.apache.spark.examples.sql;
0018 
0019 // $example on:typed_custom_aggregation$
0020 import java.io.Serializable;
0021 
0022 import org.apache.spark.sql.Dataset;
0023 import org.apache.spark.sql.Encoder;
0024 import org.apache.spark.sql.Encoders;
0025 import org.apache.spark.sql.SparkSession;
0026 import org.apache.spark.sql.TypedColumn;
0027 import org.apache.spark.sql.expressions.Aggregator;
0028 // $example off:typed_custom_aggregation$
0029 
0030 public class JavaUserDefinedTypedAggregation {
0031 
0032   // $example on:typed_custom_aggregation$
0033   public static class Employee implements Serializable {
0034     private String name;
0035     private long salary;
0036 
0037     // Constructors, getters, setters...
0038     // $example off:typed_custom_aggregation$
0039     public String getName() {
0040       return name;
0041     }
0042 
0043     public void setName(String name) {
0044       this.name = name;
0045     }
0046 
0047     public long getSalary() {
0048       return salary;
0049     }
0050 
0051     public void setSalary(long salary) {
0052       this.salary = salary;
0053     }
0054     // $example on:typed_custom_aggregation$
0055   }
0056 
0057   public static class Average implements Serializable  {
0058     private long sum;
0059     private long count;
0060 
0061     // Constructors, getters, setters...
0062     // $example off:typed_custom_aggregation$
0063     public Average() {
0064     }
0065 
0066     public Average(long sum, long count) {
0067       this.sum = sum;
0068       this.count = count;
0069     }
0070 
0071     public long getSum() {
0072       return sum;
0073     }
0074 
0075     public void setSum(long sum) {
0076       this.sum = sum;
0077     }
0078 
0079     public long getCount() {
0080       return count;
0081     }
0082 
0083     public void setCount(long count) {
0084       this.count = count;
0085     }
0086     // $example on:typed_custom_aggregation$
0087   }
0088 
0089   public static class MyAverage extends Aggregator<Employee, Average, Double> {
0090     // A zero value for this aggregation. Should satisfy the property that any b + zero = b
0091     public Average zero() {
0092       return new Average(0L, 0L);
0093     }
0094     // Combine two values to produce a new value. For performance, the function may modify `buffer`
0095     // and return it instead of constructing a new object
0096     public Average reduce(Average buffer, Employee employee) {
0097       long newSum = buffer.getSum() + employee.getSalary();
0098       long newCount = buffer.getCount() + 1;
0099       buffer.setSum(newSum);
0100       buffer.setCount(newCount);
0101       return buffer;
0102     }
0103     // Merge two intermediate values
0104     public Average merge(Average b1, Average b2) {
0105       long mergedSum = b1.getSum() + b2.getSum();
0106       long mergedCount = b1.getCount() + b2.getCount();
0107       b1.setSum(mergedSum);
0108       b1.setCount(mergedCount);
0109       return b1;
0110     }
0111     // Transform the output of the reduction
0112     public Double finish(Average reduction) {
0113       return ((double) reduction.getSum()) / reduction.getCount();
0114     }
0115     // Specifies the Encoder for the intermediate value type
0116     public Encoder<Average> bufferEncoder() {
0117       return Encoders.bean(Average.class);
0118     }
0119     // Specifies the Encoder for the final output value type
0120     public Encoder<Double> outputEncoder() {
0121       return Encoders.DOUBLE();
0122     }
0123   }
0124   // $example off:typed_custom_aggregation$
0125 
0126   public static void main(String[] args) {
0127     SparkSession spark = SparkSession
0128       .builder()
0129       .appName("Java Spark SQL user-defined Datasets aggregation example")
0130       .getOrCreate();
0131 
0132     // $example on:typed_custom_aggregation$
0133     Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
0134     String path = "examples/src/main/resources/employees.json";
0135     Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
0136     ds.show();
0137     // +-------+------+
0138     // |   name|salary|
0139     // +-------+------+
0140     // |Michael|  3000|
0141     // |   Andy|  4500|
0142     // | Justin|  3500|
0143     // |  Berta|  4000|
0144     // +-------+------+
0145 
0146     MyAverage myAverage = new MyAverage();
0147     // Convert the function to a `TypedColumn` and give it a name
0148     TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
0149     Dataset<Double> result = ds.select(averageSalary);
0150     result.show();
0151     // +--------------+
0152     // |average_salary|
0153     // +--------------+
0154     // |        3750.0|
0155     // +--------------+
0156     // $example off:typed_custom_aggregation$
0157     spark.stop();
0158   }
0159 
0160 }