#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# mllib_regression.R: Provides methods for MLlib regression algorithms
#                     (except for tree-based algorithms) integration

#' S4 class that represents an AFTSurvivalRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper
#' @note AFTSurvivalRegressionModel since 2.0.0
setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj"))

#' S4 class that represents a generalized linear model
#'
#' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper
#' @note GeneralizedLinearRegressionModel since 2.0.0
setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj"))

#' S4 class that represents an IsotonicRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala IsotonicRegressionWrapper
#' @note IsotonicRegressionModel since 2.1.0
setClass("IsotonicRegressionModel", representation(jobj = "jobj"))

#' Generalized Linear Models
#'
#' Fits a generalized linear model against a SparkDataFrame.
#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make
#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#'                operators are supported, including '~', '.', ':', '+', '-', '*', and '^'.
#' @param family a description of the error distribution and link function to be used in the model.
#'               This can be a character string naming a family function, a family function or
#'               the result of a call to a family function. Refer to the R family documentation at
#'               \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#'               Currently these families are supported: \code{binomial}, \code{gaussian},
#'               \code{Gamma}, \code{poisson} and \code{tweedie}.
#'
#'               Note that there are two ways to specify the tweedie family.
#'               \itemize{
#'                \item Set \code{family = "tweedie"} and specify the var.power and link.power;
#'                \item When package \code{statmod} is loaded, the tweedie family is specified
#'                using the family definition therein, i.e., \code{tweedie(var.power, link.power)}.
#'               }
#' @param tol positive convergence tolerance of iterations.
#' @param maxIter integer giving the maximal number of IRLS iterations.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#'                  weights as 1.0.
#' @param regParam regularization parameter for L2 regularization.
#' @param var.power the power in the variance function of the Tweedie distribution which provides
#'                      the relationship between the variance and mean of the distribution. Only
#'                      applicable to the Tweedie family.
#' @param link.power the index in the power link function. Only applicable to the Tweedie family.
#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to
#'                               decide the base level of a string feature as the last category
#'                               after ordering is dropped when encoding strings. Supported options
#'                               are "frequencyDesc", "frequencyAsc", "alphabetDesc", and
#'                               "alphabetAsc". The default value is "frequencyDesc". When the
#'                               ordering is set to "alphabetDesc", this drops the same category
#'                               as R when encoding strings.
#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance
#'                  offsets as 0.0. The feature specified as offset has a constant coefficient of
#'                  1.0.
#' @param ... additional arguments passed to the method.
#' @aliases spark.glm,SparkDataFrame,formula-method
#' @return \code{spark.glm} returns a fitted generalized linear model.
#' @rdname spark.glm
#' @name spark.glm
#' @examples
#' \dontrun{
#' sparkR.session()
#' t <- as.data.frame(Titanic, stringsAsFactors = FALSE)
#' df <- createDataFrame(t)
#' model <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian")
#' summary(model)
#'
#' # fitted values on training data
#' fitted <- predict(model, df)
#' head(select(fitted, "Freq", "prediction"))
#'
#' # save fitted model to input path
#' path <- "path/to/model"
#' write.ml(model, path)
#'
#' # can also read back the saved model and print
#' savedModel <- read.ml(path)
#' summary(savedModel)
#'
#' # note that the default string encoding is different from R's glm
#' model2 <- glm(Freq ~ Sex + Age, family = "gaussian", data = t)
#' summary(model2)
#' # use stringIndexerOrderType = "alphabetDesc" to force string encoding
#' # to be consistent with R
#' model3 <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian",
#'                     stringIndexerOrderType = "alphabetDesc")
#' summary(model3)
#'
#' # fit tweedie model
#' model <- spark.glm(df, Freq ~ Sex + Age, family = "tweedie",
#'                    var.power = 1.2, link.power = 0)
#' summary(model)
#'
#' # use the tweedie family from statmod
#' library(statmod)
#' model <- spark.glm(df, Freq ~ Sex + Age, family = tweedie(1.2, 0))
#' summary(model)
#' }
#' @note spark.glm since 2.0.0
#' @seealso \link{glm}, \link{read.ml}
setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL,
                   regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power,
                   stringIndexerOrderType = c("frequencyDesc", "frequencyAsc",
                                              "alphabetDesc", "alphabetAsc"),
                   offsetCol = NULL) {

            stringIndexerOrderType <- match.arg(stringIndexerOrderType)
            if (is.character(family)) {
              # Handle when family = "tweedie"
              if (tolower(family) == "tweedie") {
                family <- list(family = "tweedie", link = NULL)
              } else {
                family <- get(family, mode = "function", envir = parent.frame())
              }
            }
            if (is.function(family)) {
              family <- family()
            }
            if (is.null(family$family)) {
              print(family)
              stop("'family' not recognized")
            }
            # Handle when family = statmod::tweedie()
            if (tolower(family$family) == "tweedie" && !is.null(family$variance)) {
              var.power <- log(family$variance(exp(1)))
              link.power <- log(family$linkfun(exp(1)))
              family <- list(family = "tweedie", link = NULL)
            }
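            # For statmod's tweedie family, variance(mu) = mu^var.power and the power
            # link is linkfun(mu) = mu^link.power (log(mu) when link.power = 0), so
            # evaluating both at e and taking the log recovers the exponents:
            # log(e^p) = p, and log(linkfun(e)) = log(1) = 0 in the log-link case.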

            formula <- paste(deparse(formula), collapse = "")
            if (!is.null(weightCol) && weightCol == "") {
              weightCol <- NULL
            } else if (!is.null(weightCol)) {
              weightCol <- as.character(weightCol)
            }

            if (!is.null(offsetCol)) {
              offsetCol <- as.character(offsetCol)
              if (nchar(offsetCol) == 0) {
                offsetCol <- NULL
              }
            }

            # For known families, Gamma is upper-cased
            jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
                                "fit", formula, data@sdf, tolower(family$family), family$link,
                                tol, as.integer(maxIter), weightCol, regParam,
                                as.double(var.power), as.double(link.power),
                                stringIndexerOrderType, offsetCol)
            new("GeneralizedLinearRegressionModel", jobj = jobj)
          })
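
# A minimal sketch of the two equivalent tweedie specifications documented above
# (assumes a running SparkR session and a SparkDataFrame `df` with illustrative
# columns `Freq`, `Sex` and `Age`):
#
#   # 1. name the family and pass the powers explicitly
#   m1 <- spark.glm(df, Freq ~ Sex + Age, family = "tweedie",
#                   var.power = 1.5, link.power = 0)
#   # 2. let statmod's tweedie() family object carry the powers
#   library(statmod)
#   m2 <- spark.glm(df, Freq ~ Sex + Age,
#                   family = tweedie(var.power = 1.5, link.power = 0))
#   # both calls fit the same model, since the powers are recovered from the
#   # family object as shown in the method body above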

#' Generalized Linear Models (R-compliant)
#'
#' Fits a generalized linear model, similarly to R's glm().
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#'                operators are supported, including '~', '.', ':', '+', and '-'.
#' @param data a SparkDataFrame or R's glm data for training.
#' @param family a description of the error distribution and link function to be used in the model.
#'               This can be a character string naming a family function, a family function or
#'               the result of a call to a family function. Refer to the R family documentation at
#'               \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#'               Currently these families are supported: \code{binomial}, \code{gaussian},
#'               \code{poisson}, \code{Gamma}, and \code{tweedie}.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#'                  weights as 1.0.
#' @param epsilon positive convergence tolerance of iterations.
#' @param maxit integer giving the maximal number of IRLS iterations.
#' @param var.power the index of the power variance function in the Tweedie family.
#' @param link.power the index of the power link function in the Tweedie family.
#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to
#'                               decide the base level of a string feature as the last category
#'                               after ordering is dropped when encoding strings. Supported options
#'                               are "frequencyDesc", "frequencyAsc", "alphabetDesc", and
#'                               "alphabetAsc". The default value is "frequencyDesc". When the
#'                               ordering is set to "alphabetDesc", this drops the same category
#'                               as R when encoding strings.
#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance
#'                  offsets as 0.0. The feature specified as offset has a constant coefficient of
#'                  1.0.
#' @return \code{glm} returns a fitted generalized linear model.
#' @rdname glm
#' @aliases glm
#' @examples
#' \dontrun{
#' sparkR.session()
#' t <- as.data.frame(Titanic)
#' df <- createDataFrame(t)
#' model <- glm(Freq ~ Sex + Age, df, family = "gaussian")
#' summary(model)
#' }
#' @note glm since 1.5.0
#' @seealso \link{spark.glm}
setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"),
          function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL,
                   var.power = 0.0, link.power = 1.0 - var.power,
                   stringIndexerOrderType = c("frequencyDesc", "frequencyAsc",
                                              "alphabetDesc", "alphabetAsc"),
                   offsetCol = NULL) {
            spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol,
                      var.power = var.power, link.power = link.power,
                      stringIndexerOrderType = stringIndexerOrderType,
                      offsetCol = offsetCol)
          })
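
# A minimal sketch of how the R-compliant entry point above maps onto spark.glm:
# `epsilon`/`maxit` become `tol`/`maxIter` (assumes a running SparkR session; the
# Titanic columns are illustrative):
#
#   df <- createDataFrame(as.data.frame(Titanic))
#   m1 <- glm(Freq ~ Sex + Age, "gaussian", df, epsilon = 1e-8, maxit = 50)
#   m2 <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian",
#                   tol = 1e-8, maxIter = 50)
#   # m1 and m2 are fitted identically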

#  Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary().

#' @param object a fitted generalized linear model.
#' @return \code{summary} returns summary information of the fitted model, which is a list.
#'         The list of components includes at least the \code{coefficients} (coefficients matrix,
#'         which includes coefficients, standard error of coefficients, t value and p value),
#'         \code{null.deviance} (null deviance), \code{aic} (AIC)
#'         and \code{iter} (number of iterations IRLS takes). If there are collinear columns in
#'         the data, the coefficients matrix only provides coefficients.
#' @rdname spark.glm
#' @note summary(GeneralizedLinearRegressionModel) since 2.0.0
setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
          function(object) {
            jobj <- object@jobj
            is.loaded <- callJMethod(jobj, "isLoaded")
            features <- callJMethod(jobj, "rFeatures")
            coefficients <- callJMethod(jobj, "rCoefficients")
            dispersion <- callJMethod(jobj, "rDispersion")
            null.deviance <- callJMethod(jobj, "rNullDeviance")
            deviance <- callJMethod(jobj, "rDeviance")
            df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull")
            df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom")
            iter <- callJMethod(jobj, "rNumIterations")
            family <- callJMethod(jobj, "rFamily")
            aic <- callJMethod(jobj, "rAic")
            if (family == "tweedie" && aic == 0) aic <- NA
            deviance.resid <- if (is.loaded) {
              NULL
            } else {
              dataFrame(callJMethod(jobj, "rDevianceResiduals"))
            }
            # If the underlying WeightedLeastSquares uses the "normal" solver, we can provide
            # coefficients, standard errors of coefficients, t values and p values. Otherwise,
            # the model is fitted by local "l-bfgs" and we can only provide coefficients.
            if (length(features) == length(coefficients)) {
              coefficients <- matrix(unlist(coefficients), ncol = 1)
              colnames(coefficients) <- c("Estimate")
              rownames(coefficients) <- unlist(features)
            } else {
              coefficients <- matrix(unlist(coefficients), ncol = 4)
              colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
              rownames(coefficients) <- unlist(features)
            }
            ans <- list(deviance.resid = deviance.resid, coefficients = coefficients,
                        dispersion = dispersion, null.deviance = null.deviance,
                        deviance = deviance, df.null = df.null, df.residual = df.residual,
                        aic = aic, iter = iter, family = family, is.loaded = is.loaded)
            class(ans) <- "summary.GeneralizedLinearRegressionModel"
            ans
          })
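
# A minimal sketch of working with the summary list built above (assumes a fitted
# `model` from spark.glm; the component names follow the `ans` list in the method):
#
#   s <- summary(model)
#   s$coefficients                  # matrix: Estimate, Std. Error, t value, Pr(>|t|)
#   c(s$null.deviance, s$deviance)  # null and residual deviance
#   s$aic                           # AIC; set to NA for tweedie fits reporting 0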

#  Prints the summary of GeneralizedLinearRegressionModel

#' @rdname spark.glm
#' @param x summary object of fitted generalized linear model returned by \code{summary} function.
#' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0
print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
  if (x$is.loaded) {
    cat("\nA model loaded from disk does not support the output 'Deviance Residuals'.\n")
  } else {
    x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals",
                                 c(0.0, 0.25, 0.5, 0.75, 1.0), 0.01)),
                                 c("Min", "1Q", "Median", "3Q", "Max"))
    x$deviance.resid <- zapsmall(x$deviance.resid, 5L)
    cat("\nDeviance Residuals: \n")
    cat("(Note: These are approximate quantiles with relative error <= 0.01)\n")
    print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L)
  }

  cat("\nCoefficients:\n")
  print.default(x$coefficients, digits = 5L, na.print = "", print.gap = 2L)

  cat("\n(Dispersion parameter for ", x$family, " family taken to be ", format(x$dispersion),
    ")\n\n", apply(cbind(paste(format(c("Null", "Residual"), justify = "right"), "deviance:"),
    format(unlist(x[c("null.deviance", "deviance")]), digits = 5L),
    " on", format(unlist(x[c("df.null", "df.residual")])), " degrees of freedom\n"),
    1L, paste, collapse = " "), sep = "")
  cat("AIC: ", format(x$aic, digits = 4L), "\n\n",
    "Number of Fisher Scoring iterations: ", x$iter, "\n\n", sep = "")
  invisible(x)
}
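
# A minimal sketch of the printed layout produced above (the values are
# placeholders, not output captured from a real run):
#
#   print(summary(model))
#   # Deviance Residuals:
#   # (Note: These are approximate quantiles with relative error <= 0.01)
#   #     Min      1Q  Median      3Q     Max
#   # Coefficients:
#   #             Estimate  Std. Error  t value  Pr(>|t|)
#   # (Dispersion parameter for gaussian family taken to be ...)
#   # AIC: ...
#   # Number of Fisher Scoring iterations: ...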

#  Makes predictions from a generalized linear model produced by glm() or spark.glm(),
#  similarly to R's predict().

#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
#'         "prediction".
#' @rdname spark.glm
#' @note predict(GeneralizedLinearRegressionModel) since 1.5.0
setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
          function(object, newData) {
            predict_internal(object, newData)
          })

#  Saves the generalized linear model to the input path.

#' @param path the directory where the model is saved.
#' @param overwrite whether to overwrite if the output path already exists. Default is FALSE
#'                  which means throw an exception if the output path exists.
#'
#' @rdname spark.glm
#' @note write.ml(GeneralizedLinearRegressionModel, character) since 2.0.0
setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"),
          function(object, path, overwrite = FALSE) {
            write_internal(object, path, overwrite)
          })

#' Isotonic Regression Model
#'
#' Fits an Isotonic Regression model against a SparkDataFrame, similarly to R's isoreg().
#' Users can print the model summary, make predictions with the produced model, and save
#' the model to the input path.
#'
#' @param data SparkDataFrame for training.
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#'                operators are supported, including '~', '.', ':', '+', and '-'.
#' @param isotonic Whether the output sequence should be isotonic/increasing (TRUE) or
#'                 antitonic/decreasing (FALSE).
#' @param featureIndex The index of the feature if \code{featuresCol} is a vector column
#'                     (default: 0), no effect otherwise.
#' @param weightCol The weight column name.
#' @param ... additional arguments passed to the method.
#' @return \code{spark.isoreg} returns a fitted Isotonic Regression model.
#' @rdname spark.isoreg
#' @aliases spark.isoreg,SparkDataFrame,formula-method
#' @name spark.isoreg
#' @examples
#' \dontrun{
#' sparkR.session()
#' data <- list(list(7.0, 0.0), list(5.0, 1.0), list(3.0, 2.0),
#'              list(5.0, 3.0), list(1.0, 4.0))
#' df <- createDataFrame(data, c("label", "feature"))
#' model <- spark.isoreg(df, label ~ feature, isotonic = FALSE)
#' # return model boundaries and prediction as lists
#' result <- summary(model)
#' # prediction based on fitted model
#' predict_data <- list(list(-2.0), list(-1.0), list(0.5),
#'                      list(0.75), list(1.0), list(2.0), list(9.0))
#' predict_df <- createDataFrame(predict_data, c("feature"))
#' # get prediction column
#' predict_result <- collect(select(predict(model, predict_df), "prediction"))
#'
#' # save fitted model to input path
#' path <- "path/to/model"
#' write.ml(model, path)
#'
#' # can also read back the saved model and print
#' savedModel <- read.ml(path)
#' summary(savedModel)
#' }
#' @note spark.isoreg since 2.1.0
setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) {
            formula <- paste(deparse(formula), collapse = "")

            if (!is.null(weightCol) && weightCol == "") {
              weightCol <- NULL
            } else if (!is.null(weightCol)) {
              weightCol <- as.character(weightCol)
            }

            jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit",
                                data@sdf, formula, as.logical(isotonic), as.integer(featureIndex),
                                weightCol)
            new("IsotonicRegressionModel", jobj = jobj)
          })

#  Get the summary of an IsotonicRegressionModel

#' @return \code{summary} returns summary information of the fitted model, which is a list.
#'         The list includes model's \code{boundaries} (boundaries in increasing order)
#'         and \code{predictions} (predictions associated with the boundaries at the same index).
#' @rdname spark.isoreg
#' @aliases summary,IsotonicRegressionModel-method
#' @note summary(IsotonicRegressionModel) since 2.1.0
setMethod("summary", signature(object = "IsotonicRegressionModel"),
          function(object) {
            jobj <- object@jobj
            boundaries <- callJMethod(jobj, "boundaries")
            predictions <- callJMethod(jobj, "predictions")
            list(boundaries = boundaries, predictions = predictions)
          })
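
# A minimal sketch of reading the fit out of the summary list above (assumes a
# fitted `model` from spark.isoreg):
#
#   s <- summary(model)
#   # boundaries and predictions are parallel lists: predictions[[i]] is the
#   # fitted value at boundaries[[i]]; points between boundaries are interpolated
#   data.frame(boundary = unlist(s$boundaries), prediction = unlist(s$predictions))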

#  Predicted values based on an IsotonicRegressionModel

#' @param object a fitted IsotonicRegressionModel.
#' @param newData SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted values.
#' @rdname spark.isoreg
#' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method
#' @note predict(IsotonicRegressionModel) since 2.1.0
setMethod("predict", signature(object = "IsotonicRegressionModel"),
          function(object, newData) {
            predict_internal(object, newData)
          })

#  Save fitted IsotonicRegressionModel to the input path

#' @param path The directory where the model is saved.
#' @param overwrite Whether to overwrite if the output path already exists. Default is FALSE
#'                  which means throw an exception if the output path exists.
#'
#' @rdname spark.isoreg
#' @aliases write.ml,IsotonicRegressionModel,character-method
#' @note write.ml(IsotonicRegressionModel, character) since 2.1.0
setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"),
          function(object, path, overwrite = FALSE) {
            write_internal(object, path, overwrite)
          })

#' Accelerated Failure Time (AFT) Survival Regression Model
#'
#' \code{spark.survreg} fits an accelerated failure time (AFT) survival regression model on
#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted AFT model,
#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
#' save/load fitted models.
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#'                operators are supported, including '~', ':', '+', and '-'.
#'                Note that operator '.' is not supported currently.
#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the
#'                         dimensions of features or the number of partitions are large, this
#'                         param could be adjusted to a larger size. This is an expert parameter.
#'                         Default value should be good for most cases.
#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to
#'                               decide the base level of a string feature as the last category
#'                               after ordering is dropped when encoding strings. Supported options
#'                               are "frequencyDesc", "frequencyAsc", "alphabetDesc", and
#'                               "alphabetAsc". The default value is "frequencyDesc". When the
#'                               ordering is set to "alphabetDesc", this drops the same category
#'                               as R when encoding strings.
#' @param ... additional arguments passed to the method.
#' @return \code{spark.survreg} returns a fitted AFT survival regression model.
#' @rdname spark.survreg
#' @seealso survival: \url{https://cran.r-project.org/package=survival}
#' @examples
#' \dontrun{
#' df <- createDataFrame(ovarian)
#' model <- spark.survreg(df, Surv(futime, fustat) ~ ecog_ps + rx)
#'
#' # get a summary of the model
#' summary(model)
#'
#' # make predictions
#' predicted <- predict(model, df)
#' showDF(predicted)
#'
#' # save and load the model
#' path <- "path/to/model"
#' write.ml(model, path)
#' savedModel <- read.ml(path)
#' summary(savedModel)
#' }
#' @note spark.survreg since 2.0.0
setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, aggregationDepth = 2,
                   stringIndexerOrderType = c("frequencyDesc", "frequencyAsc",
                                              "alphabetDesc", "alphabetAsc")) {
            stringIndexerOrderType <- match.arg(stringIndexerOrderType)
            formula <- paste(deparse(formula), collapse = "")
            jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
                                "fit", formula, data@sdf, as.integer(aggregationDepth),
                                stringIndexerOrderType)
            new("AFTSurvivalRegressionModel", jobj = jobj)
          })

#  Returns a summary of the AFT survival regression model produced by spark.survreg,
#  similarly to R's summary().

#' @param object a fitted AFT survival regression model.
#' @return \code{summary} returns summary information of the fitted model, which is a list.
#'         The list includes the model's \code{coefficients} (features, coefficients,
#'         intercept and log(scale)).
#' @rdname spark.survreg
#' @note summary(AFTSurvivalRegressionModel) since 2.0.0
setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
          function(object) {
            jobj <- object@jobj
            features <- callJMethod(jobj, "rFeatures")
            coefficients <- callJMethod(jobj, "rCoefficients")
            coefficients <- as.matrix(unlist(coefficients))
            colnames(coefficients) <- c("Value")
            rownames(coefficients) <- unlist(features)
            list(coefficients = coefficients)
          })
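
# A minimal sketch of reading the AFT summary above (assumes a fitted `model`
# from spark.survreg; the row names are illustrative):
#
#   s <- summary(model)
#   s$coefficients              # one-column matrix "Value" with rows such as
#                               # (Intercept), ecog_ps, rx, Log(scale)
#   # AFT coefficients are on the log-time scale, so exp(coef) acts as a
#   # multiplicative factor on expected survival time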

#  Makes predictions from an AFT survival regression model or a model produced by
#  spark.survreg, similarly to R package survival's predict.

#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted values
#'         on the original scale of the data (mean predicted value at scale = 1.0).
#' @rdname spark.survreg
#' @note predict(AFTSurvivalRegressionModel) since 2.0.0
setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
          function(object, newData) {
            predict_internal(object, newData)
          })

#  Saves the AFT survival regression model to the input path.

#' @param path the directory where the model is saved.
#' @param overwrite whether to overwrite if the output path already exists. Default is FALSE
#'                  which means throw an exception if the output path exists.
#' @rdname spark.survreg
#' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0
#' @seealso \link{write.ml}
setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"),
          function(object, path, overwrite = FALSE) {
            write_internal(object, path, overwrite)
          })