0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements. See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License. You may obtain a copy of the License at
0008 #
0009 # http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017
# mllib_regression.R: Provides methods for integrating MLlib regression algorithms
# (except for tree-based algorithms)
0020
#' S4 class that represents an AFTSurvivalRegressionModel
#'
#' @param jobj a reference to the Java object backing this model, i.e. the Scala
#'             AFTSurvivalRegressionWrapper
#' @note AFTSurvivalRegressionModel since 2.0.0
setClass("AFTSurvivalRegressionModel", slots = c(jobj = "jobj"))
0026
#' S4 class that represents a generalized linear model
#'
#' @param jobj a reference to the Java object backing this model, i.e. the Scala
#'             GeneralizedLinearRegressionWrapper
#' @note GeneralizedLinearRegressionModel since 2.0.0
setClass("GeneralizedLinearRegressionModel", slots = c(jobj = "jobj"))
0032
#' S4 class that represents an IsotonicRegressionModel
#'
#' @param jobj a reference to the Java object backing this model, i.e. the Scala
#'             IsotonicRegressionWrapper
#' @note IsotonicRegressionModel since 2.1.0
setClass("IsotonicRegressionModel", slots = c(jobj = "jobj"))
0038
#' Generalized Linear Models
#'
#' Fits generalized linear model against a SparkDataFrame.
#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make
#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#'                operators are supported, including '~', '.', ':', '+', '-', '*', and '^'.
#' @param family a description of the error distribution and link function to be used in the model.
#'               This can be a character string naming a family function, a family function or
#'               the result of a call to a family function. See R's family documentation at
#'               \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#'               Currently these families are supported: \code{binomial}, \code{gaussian},
#'               \code{Gamma}, \code{poisson} and \code{tweedie}.
#'
#'               Note that there are two ways to specify the tweedie family.
#'               \itemize{
#'                \item Set \code{family = "tweedie"} and specify the var.power and link.power;
#'                \item When package \code{statmod} is loaded, the tweedie family is specified
#'                using the family definition therein, i.e., \code{tweedie(var.power, link.power)}.
#'               }
#' @param tol positive convergence tolerance of iterations.
#' @param maxIter integer giving the maximal number of IRLS iterations.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#'                  weights as 1.0.
#' @param regParam regularization parameter for L2 regularization.
#' @param var.power the power in the variance function of the Tweedie distribution which provides
#'                      the relationship between the variance and mean of the distribution. Only
#'                      applicable to the Tweedie family.
#' @param link.power the index in the power link function. Only applicable to the Tweedie family.
#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to
#'                               decide the base level of a string feature as the last category
#'                               after ordering is dropped when encoding strings. Supported options
#'                               are "frequencyDesc", "frequencyAsc", "alphabetDesc", and
#'                               "alphabetAsc". The default value is "frequencyDesc". When the
#'                               ordering is set to "alphabetDesc", this drops the same category
#'                               as R when encoding strings.
#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance
#'                  offsets as 0.0. The feature specified as offset has a constant coefficient of
#'                  1.0.
#' @param ... additional arguments passed to the method.
#' @aliases spark.glm,SparkDataFrame,formula-method
#' @return \code{spark.glm} returns a fitted generalized linear model.
#' @rdname spark.glm
#' @name spark.glm
#' @examples
#' \dontrun{
#' sparkR.session()
#' t <- as.data.frame(Titanic, stringsAsFactors = FALSE)
#' df <- createDataFrame(t)
#' model <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian")
#' summary(model)
#'
#' # fitted values on training data
#' fitted <- predict(model, df)
#' head(select(fitted, "Freq", "prediction"))
#'
#' # save fitted model to input path
#' path <- "path/to/model"
#' write.ml(model, path)
#'
#' # can also read back the saved model and print
#' savedModel <- read.ml(path)
#' summary(savedModel)
#'
#' # note that the default string encoding is different from R's glm
#' model2 <- glm(Freq ~ Sex + Age, family = "gaussian", data = t)
#' summary(model2)
#' # use stringIndexerOrderType = "alphabetDesc" to force string encoding
#' # to be consistent with R
#' model3 <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian",
#'                    stringIndexerOrderType = "alphabetDesc")
#' summary(model3)
#'
#' # fit tweedie model
#' model <- spark.glm(df, Freq ~ Sex + Age, family = "tweedie",
#'                    var.power = 1.2, link.power = 0)
#' summary(model)
#'
#' # use the tweedie family from statmod
#' library(statmod)
#' model <- spark.glm(df, Freq ~ Sex + Age, family = tweedie(1.2, 0))
#' summary(model)
#' }
#' @note spark.glm since 2.0.0
#' @seealso \link{glm}, \link{read.ml}
setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL,
                   regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power,
                   stringIndexerOrderType = c("frequencyDesc", "frequencyAsc",
                                              "alphabetDesc", "alphabetAsc"),
                   offsetCol = NULL) {

            stringIndexerOrderType <- match.arg(stringIndexerOrderType)

            # Normalize `family` into a list exposing $family and $link. It may arrive as a
            # name, a family function, or an already-evaluated family object.
            if (is.character(family)) {
              # "tweedie" is not looked up as a function; its powers come from
              # var.power and link.power instead.
              if (tolower(family) == "tweedie") {
                family <- list(family = "tweedie", link = NULL)
              } else {
                family <- get(family, mode = "function", envir = parent.frame())
              }
            }
            if (is.function(family)) {
              family <- family()
            }
            if (is.null(family$family)) {
              print(family)
              stop("'family' not recognized")
            }
            # Handle family = statmod::tweedie(). This relies on statmod defining
            # variance(mu) = mu^var.power and linkfun(mu) = mu^link.power (log(mu) when
            # link.power = 0). Evaluating at e and taking the log recovers both powers.
            if (tolower(family$family) == "tweedie" && !is.null(family$variance)) {
              var.power <- log(family$variance(exp(1)))
              link.power <- log(family$linkfun(exp(1)))
              family <- list(family = "tweedie", link = NULL)
            }

            # deparse() may split a long formula over several strings; rejoin them.
            formula <- paste(deparse(formula), collapse = "")

            # An empty weight column name is treated the same as NULL.
            if (!is.null(weightCol) && weightCol == "") {
              weightCol <- NULL
            } else if (!is.null(weightCol)) {
              weightCol <- as.character(weightCol)
            }

            # An empty offset column name is treated the same as NULL.
            if (!is.null(offsetCol)) {
              offsetCol <- as.character(offsetCol)
              if (nchar(offsetCol) == 0) {
                offsetCol <- NULL
              }
            }

            # Known families are lower-cased before being sent (R spells "Gamma" upper-case).
            # tol and regParam are coerced to double like var.power/link.power, so integer
            # inputs (e.g. regParam = 1L) are serialized the same way as the defaults.
            jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
                                "fit", formula, data@sdf, tolower(family$family), family$link,
                                as.double(tol), as.integer(maxIter), weightCol,
                                as.double(regParam), as.double(var.power),
                                as.double(link.power), stringIndexerOrderType, offsetCol)
            new("GeneralizedLinearRegressionModel", jobj = jobj)
          })
0178
#' Generalized Linear Models (R-compliant)
#'
#' Fits a generalized linear model, similarly to R's glm().
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#'                operators are supported, including '~', '.', ':', '+', '-', '*', and '^'.
#' @param data a SparkDataFrame for training (or, for R's own \code{glm}, a data.frame).
#' @param family a description of the error distribution and link function to be used in the model.
#'               This can be a character string naming a family function, a family function or
#'               the result of a call to a family function. See R's family documentation at
#'               \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#'               Currently these families are supported: \code{binomial}, \code{gaussian},
#'               \code{poisson}, \code{Gamma}, and \code{tweedie}.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#'                  weights as 1.0.
#' @param epsilon positive convergence tolerance of iterations.
#' @param maxit integer giving the maximal number of IRLS iterations.
#' @param var.power the power in the variance function of the Tweedie distribution. Only
#'                  applicable to the Tweedie family.
#' @param link.power the index of the power link function. Only applicable to the Tweedie family.
#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to
#'                               decide the base level of a string feature as the last category
#'                               after ordering is dropped when encoding strings. Supported options
#'                               are "frequencyDesc", "frequencyAsc", "alphabetDesc", and
#'                               "alphabetAsc". The default value is "frequencyDesc". When the
#'                               ordering is set to "alphabetDesc", this drops the same category
#'                               as R when encoding strings.
#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance
#'                  offsets as 0.0. The feature specified as offset has a constant coefficient of
#'                  1.0.
#' @return \code{glm} returns a fitted generalized linear model.
#' @rdname glm
#' @aliases glm
#' @examples
#' \dontrun{
#' sparkR.session()
#' t <- as.data.frame(Titanic)
#' df <- createDataFrame(t)
#' model <- glm(Freq ~ Sex + Age, df, family = "gaussian")
#' summary(model)
#' }
#' @note glm since 1.5.0
#' @seealso \link{spark.glm}
setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"),
          function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL,
                   var.power = 0.0, link.power = 1.0 - var.power,
                   stringIndexerOrderType = c("frequencyDesc", "frequencyAsc",
                                              "alphabetDesc", "alphabetAsc"),
                   offsetCol = NULL) {
            # Translate R glm()-style argument names (epsilon, maxit) into spark.glm's
            # (tol, maxIter). Validation and fitting are delegated entirely to spark.glm.
            spark.glm(data, formula, family = family, tol = epsilon, maxIter = maxit,
                      weightCol = weightCol, var.power = var.power, link.power = link.power,
                      stringIndexerOrderType = stringIndexerOrderType,
                      offsetCol = offsetCol)
          })
0231
# Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary().

#' @param object a fitted generalized linear model.
#' @return \code{summary} returns summary information of the fitted model, which is a list.
#'         The list of components includes at least the \code{coefficients} (coefficients matrix,
#'         which includes coefficients, standard error of coefficients, t value and p value),
#'         \code{null.deviance} (null deviance), \code{deviance} (residual deviance),
#'         \code{df.null} and \code{df.residual} (null/residual degrees of freedom),
#'         \code{dispersion} (dispersion parameter), \code{aic} (AIC) and \code{iter}
#'         (number of iterations IRLS takes). If there are collinear columns in the data,
#'         the coefficients matrix only provides coefficients.
#' @rdname spark.glm
#' @note summary(GeneralizedLinearRegressionModel) since 2.0.0
setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
          function(object) {
            jobj <- object@jobj
            is.loaded <- callJMethod(jobj, "isLoaded")
            features <- callJMethod(jobj, "rFeatures")
            coefficients <- callJMethod(jobj, "rCoefficients")
            dispersion <- callJMethod(jobj, "rDispersion")
            null.deviance <- callJMethod(jobj, "rNullDeviance")
            deviance <- callJMethod(jobj, "rDeviance")
            df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull")
            df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom")
            iter <- callJMethod(jobj, "rNumIterations")
            family <- callJMethod(jobj, "rFamily")
            aic <- callJMethod(jobj, "rAic")
            # For the tweedie family an AIC of 0 is treated as unavailable and reported as NA.
            if (family == "tweedie" && aic == 0) aic <- NA
            # Deviance residuals are not fetched for a loaded model; the print method
            # reports them as unsupported in that case.
            deviance.resid <- if (is.loaded) {
              NULL
            } else {
              dataFrame(callJMethod(jobj, "rDevianceResiduals"))
            }
            # If the underlying WeightedLeastSquares used the "normal" solver, the backend
            # returns coefficients, their standard errors, t values and p values, stored
            # column by column. Otherwise the model was fitted by local "l-bfgs" and only
            # the coefficients are available (one value per feature).
            coef.names <- if (length(features) == length(coefficients)) {
              c("Estimate")
            } else {
              c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
            }
            coefficients <- matrix(unlist(coefficients), ncol = length(coef.names),
                                   dimnames = list(unlist(features), coef.names))
            ans <- list(deviance.resid = deviance.resid, coefficients = coefficients,
                        dispersion = dispersion, null.deviance = null.deviance,
                        deviance = deviance, df.null = df.null, df.residual = df.residual,
                        aic = aic, iter = iter, family = family, is.loaded = is.loaded)
            class(ans) <- "summary.GeneralizedLinearRegressionModel"
            ans
          })
0282
# Prints the summary of GeneralizedLinearRegressionModel, in a layout modeled on
# R's print.summary.glm(). Returns the (possibly modified) summary invisibly.

#' @rdname spark.glm
#' @param x summary object of fitted generalized linear model returned by \code{summary} function.
#' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0
print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
  # Models obtained via a save/load round trip carry no deviance residuals, so only
  # a notice is printed in that case.
  if (x$is.loaded) {
    cat("\nSaved-loaded model does not support output 'Deviance Residuals'.\n")
  } else {
    # Replace the residual SparkDataFrame with a named five-number summary built from
    # approximate quantiles (relative error 0.01) of its "devianceResiduals" column.
    x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals",
      c(0.0, 0.25, 0.5, 0.75, 1.0), 0.01)), c("Min", "1Q", "Median", "3Q", "Max"))
    x$deviance.resid <- zapsmall(x$deviance.resid, 5L)
    cat("\nDeviance Residuals: \n")
    cat("(Note: These are approximate quantiles with relative error <= 0.01)\n")
    print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L)
  }

  cat("\nCoefficients:\n")
  print.default(x$coefficients, digits = 5L, na.print = "", print.gap = 2L)

  # Dispersion line, followed by two right-aligned rows built column-wise by cbind():
  # "Null deviance: <value>  on <df>  degrees of freedom" and the same for "Residual".
  cat("\n(Dispersion parameter for ", x$family, " family taken to be ", format(x$dispersion),
      ")\n\n", apply(cbind(paste(format(c("Null", "Residual"), justify = "right"), "deviance:"),
      format(unlist(x[c("null.deviance", "deviance")]), digits = 5L),
      " on", format(unlist(x[c("df.null", "df.residual")])), " degrees of freedom\n"),
      1L, paste, collapse = " "), sep = "")
  cat("AIC: ", format(x$aic, digits = 4L), "\n\n",
      "Number of Fisher Scoring iterations: ", x$iter, "\n\n", sep = "")
  # Returned invisibly; for a non-loaded model, deviance.resid now holds the quantiles.
  invisible(x)
}
0312
# Makes predictions with a generalized linear model fitted by glm() or spark.glm(),
# in the spirit of R's predict().

#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
#'         "prediction".
#' @rdname spark.glm
#' @note predict(GeneralizedLinearRegressionModel) since 1.5.0
setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
          function(object, newData) predict_internal(object, newData))
0325
# Persists a fitted generalized linear model under the given path.

#' @param path the directory where the model is saved.
#' @param overwrite whether to overwrite the output path if it already exists. Default is FALSE,
#'                  which means an exception is thrown if the output path exists.
#'
#' @rdname spark.glm
#' @note write.ml(GeneralizedLinearRegressionModel, character) since 2.0.0
setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"),
          function(object, path, overwrite = FALSE) write_internal(object, path, overwrite))
0338
#' Isotonic Regression Model
#'
#' Fits an Isotonic Regression model against a SparkDataFrame, similarly to R's isoreg().
#' Users can print, make predictions on the produced model and save the model to the input path.
#'
#' @param data SparkDataFrame for training.
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#'                operators are supported, including '~', '.', ':', '+', and '-'.
#' @param isotonic Whether the output sequence should be isotonic/increasing (TRUE) or
#'                 antitonic/decreasing (FALSE).
#' @param featureIndex The index of the feature to use if the features column is a vector column
#'                     (default: 0); no effect otherwise.
#' @param weightCol The weight column name. If this is not set, \code{NULL} or empty, all
#'                  instance weights are treated as 1.0.
#' @param ... additional arguments passed to the method.
#' @return \code{spark.isoreg} returns a fitted Isotonic Regression model.
#' @rdname spark.isoreg
#' @aliases spark.isoreg,SparkDataFrame,formula-method
#' @name spark.isoreg
#' @examples
#' \dontrun{
#' sparkR.session()
#' data <- list(list(7.0, 0.0), list(5.0, 1.0), list(3.0, 2.0),
#'         list(5.0, 3.0), list(1.0, 4.0))
#' df <- createDataFrame(data, c("label", "feature"))
#' model <- spark.isoreg(df, label ~ feature, isotonic = FALSE)
#' # return model boundaries and prediction as lists
#' result <- summary(model)
#' # prediction based on fitted model
#' predict_data <- list(list(-2.0), list(-1.0), list(0.5),
#'                 list(0.75), list(1.0), list(2.0), list(9.0))
#' predict_df <- createDataFrame(predict_data, c("feature"))
#' # get prediction column
#' predict_result <- collect(select(predict(model, predict_df), "prediction"))
#'
#' # save fitted model to input path
#' path <- "path/to/model"
#' write.ml(model, path)
#'
#' # can also read back the saved model and print
#' savedModel <- read.ml(path)
#' summary(savedModel)
#' }
#' @note spark.isoreg since 2.1.0
setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) {
            # deparse() may split a long formula over several strings; rejoin them.
            formula <- paste(deparse(formula), collapse = "")

            # An empty weight column name is treated the same as NULL.
            if (!is.null(weightCol) && weightCol == "") {
              weightCol <- NULL
            } else if (!is.null(weightCol)) {
              weightCol <- as.character(weightCol)
            }

            jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit",
                                data@sdf, formula, as.logical(isotonic), as.integer(featureIndex),
                                weightCol)
            new("IsotonicRegressionModel", jobj = jobj)
          })
0397
# Returns the summary of an IsotonicRegressionModel.

#' @return \code{summary} returns summary information of the fitted model, which is a list.
#'         The list holds the model's \code{boundaries} (in increasing order) and its
#'         \code{predictions} (the prediction associated with the boundary at the same index).
#' @rdname spark.isoreg
#' @aliases summary,IsotonicRegressionModel-method
#' @note summary(IsotonicRegressionModel) since 2.1.0
setMethod("summary", signature(object = "IsotonicRegressionModel"),
          function(object) {
            model <- object@jobj
            # Both values come straight from the JVM wrapper; boundaries are fetched first.
            list(boundaries = callJMethod(model, "boundaries"),
                 predictions = callJMethod(model, "predictions"))
          })
0413
# Computes predicted values from a fitted isotonic regression model.

#' @param object a fitted IsotonicRegressionModel.
#' @param newData SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted values.
#' @rdname spark.isoreg
#' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method
#' @note predict(IsotonicRegressionModel) since 2.1.0
setMethod("predict", signature(object = "IsotonicRegressionModel"),
          function(object, newData) predict_internal(object, newData))
0426
# Saves a fitted IsotonicRegressionModel to the input path.

#' @param path The directory where the model is saved.
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
#'                  which means throw exception if the output path exists.
#'
#' @rdname spark.isoreg
#' @aliases write.ml,IsotonicRegressionModel,character-method
#' @note write.ml(IsotonicRegressionModel, character) since 2.1.0
setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"),
          function(object, path, overwrite = FALSE) {
            # Delegates to write_internal, like the other write.ml methods in this file.
            write_internal(object, path, overwrite)
          })
0440
#' Accelerated Failure Time (AFT) Survival Regression Model
#'
#' \code{spark.survreg} fits an accelerated failure time (AFT) survival regression model on
#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted AFT model,
#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
#' save/load fitted models.
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#'                operators are supported, including '~', ':', '+', and '-'.
#'                Note that operator '.' is not supported currently.
#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the
#'                         dimensions of features or the number of partitions are large, this
#'                         param could be adjusted to a larger size. This is an expert parameter.
#'                         Default value should be good for most cases.
#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to
#'                               decide the base level of a string feature as the last category
#'                               after ordering is dropped when encoding strings. Supported options
#'                               are "frequencyDesc", "frequencyAsc", "alphabetDesc", and
#'                               "alphabetAsc". The default value is "frequencyDesc". When the
#'                               ordering is set to "alphabetDesc", this drops the same category
#'                               as R when encoding strings.
#' @param ... additional arguments passed to the method.
#' @return \code{spark.survreg} returns a fitted AFT survival regression model.
#' @rdname spark.survreg
#' @seealso survival: \url{https://cran.r-project.org/package=survival}
#' @examples
#' \dontrun{
#' sparkR.session()
#' # the ovarian dataset ships with the survival package
#' library(survival)
#' df <- createDataFrame(ovarian)
#' model <- spark.survreg(df, Surv(futime, fustat) ~ ecog_ps + rx)
#'
#' # get a summary of the model
#' summary(model)
#'
#' # make predictions
#' predicted <- predict(model, df)
#' showDF(predicted)
#'
#' # save and load the model
#' path <- "path/to/model"
#' write.ml(model, path)
#' savedModel <- read.ml(path)
#' summary(savedModel)
#' }
#' @note spark.survreg since 2.0.0
setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, aggregationDepth = 2,
                   stringIndexerOrderType = c("frequencyDesc", "frequencyAsc",
                                              "alphabetDesc", "alphabetAsc")) {
            stringIndexerOrderType <- match.arg(stringIndexerOrderType)
            # deparse() may split a long formula over several strings; rejoin them.
            formula <- paste(deparse(formula), collapse = "")
            jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
                                "fit", formula, data@sdf, as.integer(aggregationDepth),
                                stringIndexerOrderType)
            new("AFTSurvivalRegressionModel", jobj = jobj)
          })
0497
# Returns a summary of an AFT survival regression model fitted by spark.survreg,
# in the spirit of R's summary().

#' @param object a fitted AFT survival regression model.
#' @return \code{summary} returns summary information of the fitted model, which is a list.
#'         The list holds the model's \code{coefficients}: a one-column matrix ("Value")
#'         with one row per entry of features, coefficients, intercept and log(scale).
#' @rdname spark.survreg
#' @note summary(AFTSurvivalRegressionModel) since 2.0.0
setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
          function(object) {
            model <- object@jobj
            row.labels <- unlist(callJMethod(model, "rFeatures"))
            values <- unlist(callJMethod(model, "rCoefficients"))
            list(coefficients = matrix(values, ncol = 1L,
                                       dimnames = list(row.labels, c("Value"))))
          })
0517
# Produces predictions from an AFT survival regression model fitted by spark.survreg,
# in the spirit of the survival package's predict().

#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted values
#'         on the original scale of the data (mean predicted value at scale = 1.0).
#' @rdname spark.survreg
#' @note predict(AFTSurvivalRegressionModel) since 2.0.0
setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
          function(object, newData) predict_internal(object, newData))
0530
# Persists a fitted AFT survival regression model under the given path.

#' @param path the directory where the model is saved.
#' @param overwrite whether to overwrite the output path if it already exists. Default is FALSE,
#'                  which means an exception is thrown if the output path exists.
#' @rdname spark.survreg
#' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0
#' @seealso \link{write.ml}
setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"),
          function(object, path, overwrite = FALSE) write_internal(object, path, overwrite))