#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# group.R - GroupedData class and methods implemented in S4 OO classes

#' @include generics.R jobj.R schema.R column.R
NULL

setOldClass("jobj")
#' S4 class that represents a GroupedData
#'
#' A GroupedData can be created by calling groupBy() on a SparkDataFrame.
#'
#' @rdname GroupedData
#' @seealso groupBy
#'
#' @param sgd A Java object reference to the backing Scala GroupedData
#' @note GroupedData since 1.4.0
setClass("GroupedData",
         slots = list(sgd = "jobj"))

setMethod("initialize", "GroupedData", function(.Object, sgd) {
  .Object@sgd <- sgd
  .Object
})

#' @rdname GroupedData
groupedData <- function(sgd) {
  new("GroupedData", sgd)
}
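
# A minimal usage sketch (not part of the API surface above): a GroupedData is
# normally obtained by grouping a SparkDataFrame, assuming an active SparkR
# session and a SparkDataFrame `df` with a `name` column (illustrative names):
#   gd <- groupBy(df, "name")   # returns a GroupedData backed by a Scala object
#   class(gd)                   # "GroupedData"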

#' @rdname show
#' @aliases show,GroupedData-method
#' @note show(GroupedData) since 1.4.0
setMethod("show", "GroupedData",
          function(object) {
            cat("GroupedData\n")
          })

#' Count
#'
#' Count the number of rows for each group of a \code{GroupedData}.
#' The resulting SparkDataFrame will also contain the grouping columns.
#'
#' @return A SparkDataFrame.
#' @rdname count
#' @aliases count,GroupedData-method
#' @examples
#' \dontrun{
#'   count(groupBy(df, "name"))
#' }
#' @note count since 1.4.0
setMethod("count",
          signature(x = "GroupedData"),
          function(x) {
            dataFrame(callJMethod(x@sgd, "count"))
          })

#' summarize
#'
#' Compute aggregations by specifying a list of columns. When applied to a
#' GroupedData, the resulting SparkDataFrame will also contain the grouping
#' columns.
#'
#' df2 <- agg(df, <column> = <aggFunction>)
#' df2 <- agg(df, newColName = aggFunction(column))
#'
#' @rdname summarize
#' @aliases agg,GroupedData-method
#' @name agg
#' @family agg_funcs
#' @examples
#' \dontrun{
#'  df2 <- agg(df, age = "sum")  # new column name will be created as 'SUM(age#0)'
#'  df3 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum
#'  df4 <- summarize(df, ageSum = max(df$age))
#' }
#' @note agg since 1.4.0
setMethod("agg",
          signature(x = "GroupedData"),
          function(x, ...) {
            cols <- list(...)
            stopifnot(length(cols) > 0)
            if (is.character(cols[[1]])) {
              # Character form, e.g. agg(gd, age = "sum"): pass the named
              # arguments to the JVM as an environment of column -> function name.
              cols <- varargsToEnv(...)
              sdf <- callJMethod(x@sgd, "agg", cols)
            } else if (class(cols[[1]]) == "Column") {
              # Column form, e.g. agg(gd, ageSum = sum(df$age)): alias each
              # named Column so the output keeps the user-supplied name.
              ns <- names(cols)
              if (!is.null(ns)) {
                for (n in ns) {
                  if (n != "") {
                    cols[[n]] <- alias(cols[[n]], n)
                  }
                }
              }
              jcols <- lapply(cols, function(c) { c@jc })
              sdf <- callJMethod(x@sgd, "agg", jcols[[1]], jcols[-1])
            } else {
              stop("agg can only support Column or character")
            }
            dataFrame(sdf)
          })

#' @rdname summarize
#' @name summarize
#' @aliases summarize,GroupedData-method
#' @note summarize since 1.4.0
setMethod("summarize",
          signature(x = "GroupedData"),
          function(x, ...) {
            agg(x, ...)
          })
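
# A minimal grouped-aggregation sketch (assumes an active SparkR session and a
# SparkDataFrame `df` with `name` and `age` columns; names are illustrative):
#   gd <- groupBy(df, "name")
#   agg(gd, age = "sum")                 # character form: column -> function name
#   agg(gd, totalAge = sum(df$age))      # Column form, aliased to "totalAge"
#   summarize(gd, maxAge = max(df$age))  # summarize is a synonym for agg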

# Aggregate Functions by name
methods <- c("avg", "max", "mean", "min", "sum")

# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp",
# "stddev_pop", "variance", "var_samp", "var_pop"

#' Pivot a column of the GroupedData and perform the specified aggregation.
#'
#' Pivot a column of the GroupedData and perform the specified aggregation.
#' There are two versions of the pivot function: one that requires the caller to
#' specify the list of distinct values to pivot on, and one that does not. The
#' latter is more concise but less efficient, because Spark needs to first
#' compute the list of distinct values internally.
#'
#' @param x a GroupedData object
#' @param colname A column name
#' @param values A value or a list/vector of distinct values for the output columns.
#' @return GroupedData object
#' @rdname pivot
#' @aliases pivot,GroupedData,character-method
#' @name pivot
#' @examples
#' \dontrun{
#' df <- createDataFrame(data.frame(
#'     earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000),
#'     course = c("R", "Python", "R", "Python", "R", "Python", "R", "Python"),
#'     period = c("1H", "1H", "2H", "2H", "1H", "1H", "2H", "2H"),
#'     year = c(2015, 2015, 2015, 2015, 2016, 2016, 2016, 2016)
#' ))
#' group_sum <- sum(pivot(groupBy(df, "year"), "course"), "earnings")
#' group_min <- min(pivot(groupBy(df, "year"), "course", "R"), "earnings")
#' group_max <- max(pivot(groupBy(df, "year"), "course", c("Python", "R")), "earnings")
#' group_mean <- mean(pivot(groupBy(df, "year"), "course", list("Python", "R")), "earnings")
#' }
#' @note pivot since 2.0.0
setMethod("pivot",
          signature(x = "GroupedData", colname = "character"),
          function(x, colname, values = list()) {
            stopifnot(length(colname) == 1)
            if (length(values) == 0) {
              # No values supplied: let Spark compute the distinct values.
              result <- callJMethod(x@sgd, "pivot", colname)
            } else {
              if (length(values) > length(unique(values))) {
                stop("Values are not unique")
              }
              result <- callJMethod(x@sgd, "pivot", colname, as.list(values))
            }
            groupedData(result)
          })

# Define a GroupedData method for the given aggregate function name; the
# generated method forwards the column names in ... to the JVM side.
createMethod <- function(name) {
  setMethod(name,
            signature(x = "GroupedData"),
            function(x, ...) {
              sdf <- callJMethod(x@sgd, name, list(...))
              dataFrame(sdf)
            })
}

# Register a GroupedData method for each aggregate function listed in `methods`.
createMethods <- function() {
  for (name in methods) {
    createMethod(name)
  }
}

createMethods()
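
# A minimal sketch of the methods generated above (assumes `df` has `name` and
# `age` columns; names are illustrative):
#   avg(groupBy(df, "name"), "age")   # same result as agg(gd, age = "avg")
#   sum(groupBy(df, "name"), "age")   # per-group sum of the age column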

#' gapply
#'
#' @rdname gapply
#' @aliases gapply,GroupedData-method
#' @name gapply
#' @note gapply(GroupedData) since 2.0.0
setMethod("gapply",
          signature(x = "GroupedData"),
          function(x, func, schema) {
            if (is.null(schema)) stop("schema cannot be NULL")
            gapplyInternal(x, func, schema)
          })
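
# A minimal usage sketch (illustrative names; assumes an active SparkR session):
# `func` receives the grouping key and the group's rows as a local data.frame,
# and `schema` may be a structType or a DDL-formatted string (see
# gapplyInternal below):
#   result <- gapply(
#     groupBy(df, "name"),
#     function(key, g) { data.frame(name = key[[1]], avgAge = mean(g$age)) },
#     "name string, avgAge double")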

#' gapplyCollect
#'
#' @rdname gapplyCollect
#' @aliases gapplyCollect,GroupedData-method
#' @name gapplyCollect
#' @note gapplyCollect(GroupedData) since 2.0.0
setMethod("gapplyCollect",
          signature(x = "GroupedData"),
          function(x, func) {
            gdf <- gapplyInternal(x, func, NULL)
            content <- callJMethod(gdf@sdf, "collect")
            # content is a list of items of struct type. Each item has a single
            # field, which is a serialized data.frame corresponding to one group
            # of the SparkDataFrame.
            ldfs <- lapply(content, function(x) { unserialize(x[[1]]) })
            ldf <- do.call(rbind, ldfs)
            row.names(ldf) <- NULL
            ldf
          })
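
# A minimal usage sketch (illustrative names; assumes an active SparkR session):
# unlike gapply, no schema is required because the result is collected locally:
#   ldf <- gapplyCollect(
#     groupBy(df, "name"),
#     function(key, g) { data.frame(name = key[[1]], maxAge = max(g$age)) })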

gapplyInternal <- function(x, func, schema) {
  # Accept a DDL-formatted string as a shorthand for a structType.
  if (is.character(schema)) {
    schema <- structType(schema)
  }
  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.sparkr.enabled")[[1]] == "true"
  if (arrowEnabled) {
    if (inherits(schema, "structType")) {
      checkSchemaInArrow(schema)
    } else if (is.null(schema)) {
      stop("Arrow optimization does not support 'gapplyCollect' yet. Please disable ",
           "Arrow optimization or use 'collect' and 'gapply' APIs instead.")
    } else {
      stop("'schema' should be DDL-formatted string or structType.")
    }
  }

  packageNamesArr <- serialize(.sparkREnv[[".packages"]],
                               connection = NULL)
  broadcastArr <- lapply(ls(.broadcastNames),
                         function(name) { get(name, .broadcastNames) })
  sdf <- callJStatic(
           "org.apache.spark.sql.api.r.SQLUtils",
           "gapply",
           x@sgd,
           serialize(cleanClosure(func), connection = NULL),
           packageNamesArr,
           broadcastArr,
           if (inherits(schema, "structType")) { schema$jobj } else { NULL })
  dataFrame(sdf)
}
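
# Hedged note: the Arrow path above is gated on the
# spark.sql.execution.arrow.sparkr.enabled conf; a session could opt in like
# this (illustrative sketch):
#   sparkR.session(sparkConfig = list(
#     "spark.sql.execution.arrow.sparkr.enabled" = "true"))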