0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements. See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License. You may obtain a copy of the License at
0008 #
0009 # http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017
0018 # group.R - GroupedData class and methods implemented in S4 OO classes
0019
0020 #' @include generics.R jobj.R schema.R column.R
0021 NULL
0022
# Register the S3 class "jobj" with the S4 system so it can be used as a slot type below.
setOldClass("jobj")

#' S4 class that represents a GroupedData
#'
#' A GroupedData object is obtained by calling groupBy() on a SparkDataFrame.
#'
#' @rdname GroupedData
#' @seealso groupBy
#'
#' @param sgd A Java object reference to the backing Scala GroupedData
#' @note GroupedData since 1.4.0
setClass(
  "GroupedData",
  slots = c(sgd = "jobj")
)
0036
# Initializer for GroupedData: stores the reference handed to new() in the `sgd` slot
# and returns the populated object.
setMethod("initialize", signature(.Object = "GroupedData"),
          function(.Object, sgd) {
            .Object@sgd <- sgd
            .Object
          })
0041
#' @rdname GroupedData
groupedData <- function(sgd) {
  # Constructor helper: wraps the given Java reference in a GroupedData object.
  new("GroupedData", sgd = sgd)
}
0046
0047
#' @rdname show
#' @aliases show,GroupedData-method
#' @note show(GroupedData) since 1.4.0
setMethod("show", signature(object = "GroupedData"),
          function(object) {
            # Only the class name is printed; the object's contents are not inspected.
            cat("GroupedData", "\n", sep = "")
          })
0055
#' Count
#'
#' When the input is a \code{GroupedData}, counts the number of rows in each group.
#' The resulting SparkDataFrame will also contain the grouping columns.
#'
#' @return A SparkDataFrame.
#' @rdname count
#' @aliases count,GroupedData-method
#' @examples
#' \dontrun{
#'   count(groupBy(df, "name"))
#' }
#' @note count since 1.4.0
setMethod("count",
          signature(x = "GroupedData"),
          function(x) {
            # Delegate to the JVM-side "count" and wrap the returned reference.
            counted <- callJMethod(x@sgd, "count")
            dataFrame(counted)
          })
0074
#' summarize
#'
#' Aggregates on the entire SparkDataFrame without groups or, when given a
#' \code{GroupedData}, computes the aggregates for its groups.
#' The resulting SparkDataFrame will also contain the grouping columns.
#'
#' df2 <- agg(df, <column> = <aggFunction>)
#' df2 <- agg(df, newColName = aggFunction(column))
#'
#' @rdname summarize
#' @aliases agg,GroupedData-method
#' @name agg
#' @family agg_funcs
#' @examples
#' \dontrun{
#'  df2 <- agg(df, age = "sum")  # new column name will be created as 'SUM(age#0)'
#'  df3 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum
#'  df4 <- summarize(df, ageSum = max(df$age))
#' }
#' @note agg since 1.4.0
setMethod("agg",
          signature(x = "GroupedData"),
          function(x, ...) {
            cols <- list(...)
            stopifnot(length(cols) > 0)
            # The type of the first argument selects which calling form is used.
            if (is.character(cols[[1]])) {
              # Form 1: <column> = "<aggFunction>" pairs, handed over as an environment.
              cols <- varargsToEnv(...)
              sdf <- callJMethod(x@sgd, "agg", cols)
            } else if (inherits(cols[[1]], "Column")) {
              # Form 2: Column expressions; a non-empty argument name becomes the alias.
              # inherits() always yields a single TRUE/FALSE, unlike class(x) == "Column".
              ns <- names(cols)
              if (!is.null(ns)) {
                # Walk by position so that a repeated name aliases its own column
                # instead of the first column carrying that name.
                for (i in seq_along(cols)) {
                  if (ns[[i]] != "") {
                    cols[[i]] <- alias(cols[[i]], ns[[i]])
                  }
                }
              }
              jcols <- lapply(cols, function(c) { c@jc })
              # The JVM method receives the first column and the remaining ones separately.
              sdf <- callJMethod(x@sgd, "agg", jcols[[1]], jcols[-1])
            } else {
              stop("agg can only support Column or character")
            }
            dataFrame(sdf)
          })
0118
#' @rdname summarize
#' @name summarize
#' @aliases summarize,GroupedData-method
#' @note summarize since 1.4.0
setMethod("summarize", signature(x = "GroupedData"), function(x, ...) {
  # summarize is an alias here: every argument is forwarded unchanged to agg().
  agg(x, ...)
})
0128
# Names of the aggregate functions that are generated as methods on GroupedData.
#
# Not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp",
# "stddev_pop", "variance", "var_samp", "var_pop"
methods <- c(
  "avg",
  "max",
  "mean",
  "min",
  "sum"
)
0134
#' Pivot a column of the GroupedData and perform the specified aggregation.
#'
#' Pivot a column of the GroupedData and perform the specified aggregation.
#' The pivot function comes in two versions: one where the caller supplies the list of
#' distinct values to pivot on, and one where it is omitted. The latter is more concise but
#' less efficient, because Spark needs to first compute the list of distinct values internally.
#'
#' @param x a GroupedData object
#' @param colname A column name
#' @param values A value or a list/vector of distinct values for the output columns.
#' @return GroupedData object
#' @rdname pivot
#' @aliases pivot,GroupedData,character-method
#' @name pivot
#' @examples
#' \dontrun{
#' df <- createDataFrame(data.frame(
#'     earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000),
#'     course = c("R", "Python", "R", "Python", "R", "Python", "R", "Python"),
#'     period = c("1H", "1H", "2H", "2H", "1H", "1H", "2H", "2H"),
#'     year = c(2015, 2015, 2015, 2015, 2016, 2016, 2016, 2016)
#' ))
#' group_sum <- sum(pivot(groupBy(df, "year"), "course"), "earnings")
#' group_min <- min(pivot(groupBy(df, "year"), "course", "R"), "earnings")
#' group_max <- max(pivot(groupBy(df, "year"), "course", c("Python", "R")), "earnings")
#' group_mean <- mean(pivot(groupBy(df, "year"), "course", list("Python", "R")), "earnings")
#' }
#' @note pivot since 2.0.0
setMethod("pivot",
          signature(x = "GroupedData", colname = "character"),
          function(x, colname, values = list()) {
            stopifnot(length(colname) == 1)

            # No explicit values: call the single-argument form of the JVM "pivot".
            if (length(values) == 0) {
              return(groupedData(callJMethod(x@sgd, "pivot", colname)))
            }

            # Explicit values must not contain repeats.
            if (anyDuplicated(values) > 0) {
              stop("Values are not unique")
            }
            pivoted <- callJMethod(x@sgd, "pivot", colname, as.list(values))
            groupedData(pivoted)
          })
0177
# Defines a method called `name` for GroupedData. The generated method packs its extra
# arguments into a list, invokes the JVM method of the same name on the backing object,
# and wraps the returned reference with dataFrame().
createMethod <- function(name) {
  setMethod(name,
            signature(x = "GroupedData"),
            function(x, ...) {
              jargs <- list(...)
              dataFrame(callJMethod(x@sgd, name, jargs))
            })
}
0186
# Generates one GroupedData method for every entry of `methods`.
# Returns NULL invisibly.
createMethods <- function() {
  lapply(methods, createMethod)
  invisible(NULL)
}

# Run once when this file is evaluated so the generated methods exist.
createMethods()
0194
#' gapply
#'
#' @rdname gapply
#' @aliases gapply,GroupedData-method
#' @name gapply
#' @note gapply(GroupedData) since 2.0.0
setMethod("gapply",
          signature(x = "GroupedData"),
          function(x, func, schema) {
            # An output schema is mandatory for this entry point.
            if (is.null(schema)) {
              stop("schema cannot be NULL")
            }
            gapplyInternal(x, func, schema)
          })
0207
#' gapplyCollect
#'
#' @rdname gapplyCollect
#' @aliases gapplyCollect,GroupedData-method
#' @name gapplyCollect
#' @note gapplyCollect(GroupedData) since 2.0.0
setMethod("gapplyCollect",
          signature(x = "GroupedData"),
          function(x, func) {
            # No schema is passed, so gapplyInternal receives NULL for it.
            gdf <- gapplyInternal(x, func, NULL)
            rows <- callJMethod(gdf@sdf, "collect")
            # Every collected item is a struct with a single field; that field holds
            # the serialized data.frame computed for one group of the SparkDataFrame.
            frames <- lapply(rows, function(row) { unserialize(row[[1]]) })
            # Stack the per-group data.frames and drop the row names.
            combined <- do.call(rbind, frames)
            row.names(combined) <- NULL
            combined
          })
0227
# Shared implementation behind gapply() and gapplyCollect() on GroupedData.
#
# @param x a GroupedData whose `sgd` slot is passed to the JVM
# @param func the R function to apply; it is closure-cleaned and serialized before the call
# @param schema a structType, a DDL-formatted string (converted with structType()),
#               or NULL (as passed by gapplyCollect)
# @return the result of SQLUtils.gapply wrapped with dataFrame()
gapplyInternal <- function(x, func, schema) {
  if (is.character(schema)) {
    schema <- structType(schema)
  }
  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.sparkr.enabled")[[1]] == "true"
  if (arrowEnabled) {
    # With Arrow enabled a structType schema is required and is checked up front.
    if (inherits(schema, "structType")) {
      checkSchemaInArrow(schema)
    } else if (is.null(schema)) {
      stop("Arrow optimization does not support 'gapplyCollect' yet. Please disable ",
           "Arrow optimization or use 'collect' and 'gapply' APIs instead.")
    } else {
      stop("'schema' should be DDL-formatted string or structType.")
    }
  }

  packageNamesArr <- serialize(.sparkREnv[[".packages"]],
                               connection = NULL)
  broadcastArr <- lapply(ls(.broadcastNames),
                         function(name) { get(name, .broadcastNames) })
  # Use inherits(), matching the Arrow check above: it always yields a single
  # TRUE/FALSE, whereas class(schema) == "structType" yields one value per class
  # and is not a valid `if` condition for an object with several classes.
  jschema <- if (inherits(schema, "structType")) {
    schema$jobj
  } else {
    NULL
  }
  sdf <- callJStatic(
           "org.apache.spark.sql.api.r.SQLUtils",
           "gapply",
           x@sgd,
           serialize(cleanClosure(func), connection = NULL),
           packageNamesArr,
           broadcastArr,
           jschema)
  dataFrame(sdf)
}