Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 # mllib_tree.R: Provides methods for MLlib tree-based algorithms integration
0019 
#' S4 class that represents a GBTRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala GBTRegressionModel
#' @note GBTRegressionModel since 2.1.0
setClass("GBTRegressionModel", slots = c(jobj = "jobj"))

#' S4 class that represents a GBTClassificationModel
#'
#' @param jobj a Java object reference to the backing Scala GBTClassificationModel
#' @note GBTClassificationModel since 2.1.0
setClass("GBTClassificationModel", slots = c(jobj = "jobj"))

#' S4 class that represents a RandomForestRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel
#' @note RandomForestRegressionModel since 2.1.0
setClass("RandomForestRegressionModel", slots = c(jobj = "jobj"))

#' S4 class that represents a RandomForestClassificationModel
#'
#' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel
#' @note RandomForestClassificationModel since 2.1.0
setClass("RandomForestClassificationModel", slots = c(jobj = "jobj"))

#' S4 class that represents a DecisionTreeRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel
#' @note DecisionTreeRegressionModel since 2.3.0
setClass("DecisionTreeRegressionModel", slots = c(jobj = "jobj"))

#' S4 class that represents a DecisionTreeClassificationModel
#'
#' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel
#' @note DecisionTreeClassificationModel since 2.3.0
setClass("DecisionTreeClassificationModel", slots = c(jobj = "jobj"))
0055 
# Builds the summary list shared by tree ensemble models (e.g. Random Forest, GBT).
# Every component is read from the backing JVM wrapper via callJMethod; the raw
# jobj is kept in the result so the print helpers can request the JVM summary string.
summary.treeEnsemble <- function(model) {
  jobj <- model@jobj
  # Small accessor so each field reads as a single JVM getter call.
  fetch <- function(method) callJMethod(jobj, method)
  # list() evaluates its arguments left to right, so the JVM getters run in the
  # same order as the individual assignments they replace.
  list(formula = fetch("formula"),
       numFeatures = fetch("numFeatures"),
       features = fetch("features"),
       featureImportances = callJMethod(fetch("featureImportances"), "toString"),
       maxDepth = fetch("maxDepth"),
       numTrees = fetch("numTrees"),
       treeWeights = fetch("treeWeights"),
       jobj = jobj)
}
0075 
# Prints a tree ensemble summary list (as built by summary.treeEnsemble), followed by
# the model summary string obtained from the JVM wrapper. Returns x invisibly.
print.summary.treeEnsemble <- function(x) {
  # Each field is emitted with exactly the same cat() arguments (label, value).
  printField <- function(label, value) cat(label, value)
  printField("Formula: ", x$formula)
  printField("\nNumber of features: ", x$numFeatures)
  printField("\nFeatures: ", unlist(x$features))
  printField("\nFeature importances: ", x$featureImportances)
  printField("\nMax Depth: ", x$maxDepth)
  printField("\nNumber of trees: ", x$numTrees)
  printField("\nTree weights: ", unlist(x$treeWeights))

  cat("\n", callJMethod(x$jobj, "summary"), "\n")
  invisible(x)
}
0091 
# Builds the summary list for a single decision tree model. Every component is read
# from the backing JVM wrapper via callJMethod; the raw jobj is kept in the result
# so the print helper can request the JVM summary string.
summary.decisionTree <- function(model) {
  jobj <- model@jobj
  # Small accessor so each field reads as a single JVM getter call.
  fetch <- function(method) callJMethod(jobj, method)
  # Arguments are evaluated left to right, preserving the original getter order.
  list(formula = fetch("formula"),
       numFeatures = fetch("numFeatures"),
       features = fetch("features"),
       featureImportances = callJMethod(fetch("featureImportances"), "toString"),
       maxDepth = fetch("maxDepth"),
       jobj = jobj)
}
0107 
# Prints a decision tree summary list (as built by summary.decisionTree), followed by
# the model summary string obtained from the JVM wrapper. Returns x invisibly.
print.summary.decisionTree <- function(x) {
  # Each field is emitted with exactly the same cat() arguments (label, value).
  printField <- function(label, value) cat(label, value)
  printField("Formula: ", x$formula)
  printField("\nNumber of features: ", x$numFeatures)
  printField("\nFeatures: ", unlist(x$features))
  printField("\nFeature importances: ", x$featureImportances)
  printField("\nMax Depth: ", x$maxDepth)

  cat("\n", callJMethod(x$jobj, "summary"), "\n")
  invisible(x)
}
0121 
#' Gradient Boosted Tree Model for Regression and Classification
#'
#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a
#' SparkDataFrame. Users can call \code{summary} to get a summary of the fitted
#' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and
#' \code{write.ml}/\code{read.ml} to save/load fitted models.
#' For more details, see
# nolint start
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{
#' GBT Regression} and
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{
#' GBT Classification}
# nolint end
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#'                operators are supported, including '~', ':', '+', '-', '*', and '^'.
#' @param type type of model to fit, one of "regression" or "classification".
#' @param maxDepth Maximum depth of the tree (>= 0).
#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
#'                how to split on features at each node. More bins give higher granularity. Must be
#'                >= 2 and >= number of categories in any categorical feature.
#' @param maxIter Param for maximum number of iterations (>= 0).
#' @param stepSize Param for the step size to be used for each iteration of optimization.
#' @param lossType Loss function which GBT tries to minimize.
#'                 For classification, must be "logistic". For regression, must be one of
#'                 "squared" (L2) or "absolute" (L1); default is "squared".
#' @param seed integer seed for random number generation.
#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in
#'                        range (0, 1].
#' @param minInstancesPerNode Minimum number of instances each child must have after split. If a
#'                            split causes the left or right child to have fewer than
#'                            minInstancesPerNode, the split will be discarded as invalid. Should be
#'                            >= 1.
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
#' @param checkpointInterval Param for setting the checkpoint interval (>= 1) or disabling
#'                           checkpointing (-1). Note: this setting will be ignored if the
#'                           checkpoint directory is not set.
#' @param maxMemoryInMB Maximum memory in MiB allocated to histogram aggregation.
#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
#'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
#'                     can speed up training of deeper trees. Users can set how often the cache
#'                     should be checkpointed, or disable it, by setting checkpointInterval.
#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and
#'                      label column of string type in classification model.
#'                      Supported options: "skip" (filter out rows with invalid data),
#'                                         "error" (throw an error), "keep" (put invalid data in
#'                                         a special additional bucket, at index numLabels). Default
#'                                         is "error".
#' @param ... additional arguments passed to the method.
#' @aliases spark.gbt,SparkDataFrame,formula-method
#' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model.
#' @rdname spark.gbt
#' @name spark.gbt
#' @examples
#' \dontrun{
#' # fit a Gradient Boosted Tree Regression Model
#' df <- createDataFrame(longley)
#' model <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
#'
#' # get the summary of the model
#' summary(model)
#'
#' # make predictions
#' predictions <- predict(model, df)
#'
#' # save and load the model
#' path <- "path/to/model"
#' write.ml(model, path)
#' savedModel <- read.ml(path)
#' summary(savedModel)
#'
#' # fit a Gradient Boosted Tree Classification Model
#' # label must be binary - Only binary classification is supported for GBT.
#' t <- as.data.frame(Titanic)
#' df <- createDataFrame(t)
#' model <- spark.gbt(df, Survived ~ Age + Freq, "classification")
#'
#' # numeric label is also supported
#' t2 <- as.data.frame(Titanic)
#' t2$NumericGender <- ifelse(t2$Sex == "Male", 0, 1)
#' df <- createDataFrame(t2)
#' model <- spark.gbt(df, NumericGender ~ ., type = "classification")
#' }
#' @note spark.gbt since 2.1.0
setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, type = c("regression", "classification"),
                   maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL,
                   seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0,
                   checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE,
                   handleInvalid = c("error", "keep", "skip")) {
            type <- match.arg(type)
            formula <- paste(deparse(formula), collapse = "")
            if (!is.null(seed)) {
              seed <- as.character(as.integer(seed))
            }
            if (type == "classification") {
              handleInvalid <- match.arg(handleInvalid)
              # match.arg() returns the first choice for a NULL argument, so an unset
              # lossType defaults to "logistic".
              lossType <- match.arg(lossType, "logistic")
              wrapper <- "org.apache.spark.ml.r.GBTClassifierWrapper"
              jobj <- callJStatic(wrapper, "fit", data@sdf, formula, as.integer(maxDepth),
                                  as.integer(maxBins), as.integer(maxIter),
                                  as.numeric(stepSize), as.integer(minInstancesPerNode),
                                  as.numeric(minInfoGain), as.integer(checkpointInterval),
                                  lossType, seed, as.numeric(subsamplingRate),
                                  as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
                                  handleInvalid)
              new("GBTClassificationModel", jobj = jobj)
            } else {
              # A NULL lossType makes match.arg() return the first choice, "squared".
              # handleInvalid is not forwarded to the regression wrapper.
              lossType <- match.arg(lossType, c("squared", "absolute"))
              wrapper <- "org.apache.spark.ml.r.GBTRegressorWrapper"
              jobj <- callJStatic(wrapper, "fit", data@sdf, formula, as.integer(maxDepth),
                                  as.integer(maxBins), as.integer(maxIter),
                                  as.numeric(stepSize), as.integer(minInstancesPerNode),
                                  as.numeric(minInfoGain), as.integer(checkpointInterval),
                                  lossType, seed, as.numeric(subsamplingRate),
                                  as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
              new("GBTRegressionModel", jobj = jobj)
            }
          })
0247 
# Get the summary of a Gradient Boosted Tree Regression Model

#' @return \code{summary} returns summary information of the fitted model, which is a list.
#'         The list of components includes \code{formula} (formula),
#'         \code{numFeatures} (number of features), \code{features} (list of features),
#'         \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees),
#'         \code{numTrees} (number of trees), and \code{treeWeights} (tree weights).
#' @rdname spark.gbt
#' @aliases summary,GBTRegressionModel-method
#' @note summary(GBTRegressionModel) since 2.1.0
setMethod("summary", signature(object = "GBTRegressionModel"),
          function(object) {
            # Tag the shared ensemble summary so print() dispatches to the GBT printer.
            structure(summary.treeEnsemble(object), class = "summary.GBTRegressionModel")
          })
0264 
# Prints the summary of a Gradient Boosted Tree Regression Model

#' @param x summary object of Gradient Boosted Tree regression model or classification model
#'          returned by \code{summary}.
#' @rdname spark.gbt
#' @note print.summary.GBTRegressionModel since 2.1.0
print.summary.GBTRegressionModel <- function(x, ...) {
  # Delegate to the shared ensemble printer and hand back x without auto-printing.
  invisible(print.summary.treeEnsemble(x))
}
0274 
# Get the summary of a Gradient Boosted Tree Classification Model

#' @rdname spark.gbt
#' @aliases summary,GBTClassificationModel-method
#' @note summary(GBTClassificationModel) since 2.1.0
setMethod("summary", signature(object = "GBTClassificationModel"),
          function(object) {
            # Tag the shared ensemble summary so print() dispatches to the GBT printer.
            structure(summary.treeEnsemble(object), class = "summary.GBTClassificationModel")
          })
0286 
# Prints the summary of a Gradient Boosted Tree Classification Model

#' @rdname spark.gbt
#' @note print.summary.GBTClassificationModel since 2.1.0
print.summary.GBTClassificationModel <- function(x, ...) {
  # Delegate to the shared ensemble printer and hand back x without auto-printing.
  invisible(print.summary.treeEnsemble(x))
}
0294 
# Makes predictions from a Gradient Boosted Tree Regression model or Classification model

#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
#'         "prediction".
#' @rdname spark.gbt
#' @aliases predict,GBTRegressionModel-method
#' @note predict(GBTRegressionModel) since 2.1.0
setMethod("predict", signature(object = "GBTRegressionModel"),
          function(object, newData) predict_internal(object, newData))
0307 
#' @rdname spark.gbt
#' @aliases predict,GBTClassificationModel-method
#' @note predict(GBTClassificationModel) since 2.1.0
setMethod("predict", signature(object = "GBTClassificationModel"),
          function(object, newData) predict_internal(object, newData))
0315 
# Save the Gradient Boosted Tree Regression or Classification model to the input path.

#' @param object A fitted Gradient Boosted Tree regression model or classification model.
#' @param path The directory where the model is saved.
#' @param overwrite Whether to overwrite the output path if it already exists. Default is FALSE,
#'                  which means an exception is thrown if the output path exists.
#' @aliases write.ml,GBTRegressionModel,character-method
#' @rdname spark.gbt
#' @note write.ml(GBTRegressionModel, character) since 2.1.0
setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"),
          function(object, path, overwrite = FALSE) write_internal(object, path, overwrite))
0329 
#' @aliases write.ml,GBTClassificationModel,character-method
#' @rdname spark.gbt
#' @note write.ml(GBTClassificationModel, character) since 2.1.0
setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"),
          function(object, path, overwrite = FALSE) write_internal(object, path, overwrite))
0337 
#' Random Forest Model for Regression and Classification
#'
#' \code{spark.randomForest} fits a Random Forest Regression model or Classification model on
#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Random Forest
#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
#' save/load fitted models.
#' For more details, see
# nolint start
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{
#' Random Forest Regression} and
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{
#' Random Forest Classification}
# nolint end
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#'                operators are supported, including '~', ':', '+', and '-'.
#' @param type type of model to fit, one of "regression" or "classification".
#' @param maxDepth Maximum depth of the tree (>= 0).
#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
#'                how to split on features at each node. More bins give higher granularity. Must be
#'                >= 2 and >= number of categories in any categorical feature.
#' @param numTrees Number of trees to train (>= 1).
#' @param impurity Criterion used for information gain calculation.
#'                 For regression, must be "variance". For classification, must be one of
#'                 "entropy" or "gini"; default is "gini".
#' @param featureSubsetStrategy The number of features to consider for splits at each tree node.
#'                              Supported options: "auto" (choose automatically for task: If
#'                                                 numTrees == 1, set to "all". If numTrees > 1
#'                                                 (forest), set to "sqrt" for classification and
#'                                                 to "onethird" for regression),
#'                                                 "all" (use all features),
#'                                                 "onethird" (use 1/3 of the features),
#'                                                 "sqrt" (use sqrt(number of features)),
#'                                                 "log2" (use log2(number of features)),
#'                                                 "n": (when n is in the range (0, 1.0], use
#'                                                 n * number of features. When n is in the range
#'                                                 (1, number of features), use n features).
#'                                                 Default is "auto".
#' @param seed integer seed for random number generation.
#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in
#'                        range (0, 1].
#' @param minInstancesPerNode Minimum number of instances each child must have after split.
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
#' @param checkpointInterval Param for setting the checkpoint interval (>= 1) or disabling
#'                           checkpointing (-1). Note: this setting will be ignored if the
#'                           checkpoint directory is not set.
#' @param maxMemoryInMB Maximum memory in MiB allocated to histogram aggregation.
#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
#'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
#'                     can speed up training of deeper trees. Users can set how often the cache
#'                     should be checkpointed, or disable it, by setting checkpointInterval.
#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and
#'                      label column of string type in classification model.
#'                      Supported options: "skip" (filter out rows with invalid data),
#'                                         "error" (throw an error), "keep" (put invalid data in
#'                                         a special additional bucket, at index numLabels). Default
#'                                         is "error".
#' @param bootstrap Whether bootstrap samples are used when building trees.
#' @param ... additional arguments passed to the method.
#' @aliases spark.randomForest,SparkDataFrame,formula-method
#' @return \code{spark.randomForest} returns a fitted Random Forest model.
#' @rdname spark.randomForest
#' @name spark.randomForest
#' @examples
#' \dontrun{
#' # fit a Random Forest Regression Model
#' df <- createDataFrame(longley)
#' model <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
#'
#' # get the summary of the model
#' summary(model)
#'
#' # make predictions
#' predictions <- predict(model, df)
#'
#' # save and load the model
#' path <- "path/to/model"
#' write.ml(model, path)
#' savedModel <- read.ml(path)
#' summary(savedModel)
#'
#' # fit a Random Forest Classification Model
#' t <- as.data.frame(Titanic)
#' df <- createDataFrame(t)
#' model <- spark.randomForest(df, Survived ~ Freq + Age, "classification")
#' }
#' @note spark.randomForest since 2.1.0
setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, type = c("regression", "classification"),
                   maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
                   featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
                   minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
                   maxMemoryInMB = 256, cacheNodeIds = FALSE,
                   handleInvalid = c("error", "keep", "skip"),
                   bootstrap = TRUE) {
            type <- match.arg(type)
            formula <- paste(deparse(formula), collapse = "")
            if (!is.null(seed)) {
              seed <- as.character(as.integer(seed))
            }
            if (type == "classification") {
              handleInvalid <- match.arg(handleInvalid)
              # match.arg() returns the first choice for a NULL argument, so an unset
              # impurity defaults to "gini".
              impurity <- match.arg(impurity, c("gini", "entropy"))
              wrapper <- "org.apache.spark.ml.r.RandomForestClassifierWrapper"
              jobj <- callJStatic(wrapper, "fit", data@sdf, formula, as.integer(maxDepth),
                                  as.integer(maxBins), as.integer(numTrees),
                                  impurity, as.integer(minInstancesPerNode),
                                  as.numeric(minInfoGain), as.integer(checkpointInterval),
                                  as.character(featureSubsetStrategy), seed,
                                  as.numeric(subsamplingRate),
                                  as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
                                  handleInvalid, as.logical(bootstrap))
              new("RandomForestClassificationModel", jobj = jobj)
            } else {
              # A NULL impurity makes match.arg() return its only choice, "variance".
              # handleInvalid is not forwarded to the regression wrapper.
              impurity <- match.arg(impurity, "variance")
              wrapper <- "org.apache.spark.ml.r.RandomForestRegressorWrapper"
              jobj <- callJStatic(wrapper, "fit", data@sdf, formula, as.integer(maxDepth),
                                  as.integer(maxBins), as.integer(numTrees),
                                  impurity, as.integer(minInstancesPerNode),
                                  as.numeric(minInfoGain), as.integer(checkpointInterval),
                                  as.character(featureSubsetStrategy), seed,
                                  as.numeric(subsamplingRate),
                                  as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
                                  as.logical(bootstrap))
              new("RandomForestRegressionModel", jobj = jobj)
            }
          })
0471 
# Get the summary of a Random Forest Regression Model

#' @return \code{summary} returns summary information of the fitted model, which is a list.
#'         The list of components includes \code{formula} (formula),
#'         \code{numFeatures} (number of features), \code{features} (list of features),
#'         \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees),
#'         \code{numTrees} (number of trees), and \code{treeWeights} (tree weights).
#' @rdname spark.randomForest
#' @aliases summary,RandomForestRegressionModel-method
#' @note summary(RandomForestRegressionModel) since 2.1.0
setMethod("summary", signature(object = "RandomForestRegressionModel"),
          function(object) {
            # Tag the shared ensemble summary so print() dispatches to the RF printer.
            structure(summary.treeEnsemble(object),
                      class = "summary.RandomForestRegressionModel")
          })
0488 
# Prints the summary of a Random Forest Regression Model

#' @param x summary object of Random Forest regression model or classification model
#'          returned by \code{summary}.
#' @rdname spark.randomForest
#' @note print.summary.RandomForestRegressionModel since 2.1.0
print.summary.RandomForestRegressionModel <- function(x, ...) {
  # Delegate to the shared ensemble printer and hand back x without auto-printing.
  invisible(print.summary.treeEnsemble(x))
}
0498 
# Get the summary of a Random Forest Classification Model

#' @rdname spark.randomForest
#' @aliases summary,RandomForestClassificationModel-method
#' @note summary(RandomForestClassificationModel) since 2.1.0
setMethod("summary", signature(object = "RandomForestClassificationModel"),
          function(object) {
            # Tag the shared ensemble summary so print() dispatches to the RF printer.
            structure(summary.treeEnsemble(object),
                      class = "summary.RandomForestClassificationModel")
          })
0510 
# Prints the summary of a Random Forest Classification Model

#' @rdname spark.randomForest
#' @note print.summary.RandomForestClassificationModel since 2.1.0
print.summary.RandomForestClassificationModel <- function(x, ...) {
  # Delegate to the shared ensemble printer and hand back x without auto-printing.
  invisible(print.summary.treeEnsemble(x))
}
0518 
# Makes predictions from a Random Forest Regression model or Classification model

#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
#'         "prediction".
#' @rdname spark.randomForest
#' @aliases predict,RandomForestRegressionModel-method
#' @note predict(RandomForestRegressionModel) since 2.1.0
setMethod("predict", signature(object = "RandomForestRegressionModel"),
          function(object, newData) predict_internal(object, newData))
0531 
#' @rdname spark.randomForest
#' @aliases predict,RandomForestClassificationModel-method
#' @note predict(RandomForestClassificationModel) since 2.1.0
setMethod("predict", signature(object = "RandomForestClassificationModel"),
          function(object, newData) predict_internal(object, newData))
0539 
# Save the Random Forest Regression or Classification model to the input path.

#' @param object A fitted Random Forest regression model or classification model.
#' @param path The directory where the model is saved.
#' @param overwrite Whether to overwrite the output path if it already exists. Default is FALSE,
#'                  which means an exception is thrown if the output path exists.
#'
#' @aliases write.ml,RandomForestRegressionModel,character-method
#' @rdname spark.randomForest
#' @note write.ml(RandomForestRegressionModel, character) since 2.1.0
setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"),
          function(object, path, overwrite = FALSE) write_internal(object, path, overwrite))
0554 
#' @aliases write.ml,RandomForestClassificationModel,character-method
#' @rdname spark.randomForest
#' @note write.ml(RandomForestClassificationModel, character) since 2.1.0
setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"),
          function(object, path, overwrite = FALSE) write_internal(object, path, overwrite))
0562 
#' Decision Tree Model for Regression and Classification
#'
#' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on
#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree
#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
#' save/load fitted models.
#' For more details, see
# nolint start
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{
#' Decision Tree Regression} and
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{
#' Decision Tree Classification}
# nolint end
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#'                operators are supported, including '~', ':', '+', and '-'.
#' @param type type of model to fit, either "regression" or "classification".
#' @param maxDepth Maximum depth of the tree (>= 0).
#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
#'                how to split on features at each node. More bins give higher granularity. Must be
#'                >= 2 and >= number of categories in any categorical feature.
#' @param impurity Criterion used for information gain calculation.
#'                 For regression, must be "variance". For classification, must be either
#'                 "entropy" or "gini"; the default is "gini".
#' @param seed integer seed for random number generation.
#' @param minInstancesPerNode Minimum number of instances each child must have after split.
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
#' @param checkpointInterval Checkpoint interval (>= 1), or -1 to disable checkpointing.
#'                           Note: this setting will be ignored if the checkpoint directory is not
#'                           set.
#' @param maxMemoryInMB Maximum memory in MiB allocated to histogram aggregation.
#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
#'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
#'                     can speed up training of deeper trees. Users can control how often the
#'                     cache is checkpointed, or disable checkpointing, via checkpointInterval.
#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and
#'                      the label column of string type; used by the classification model only.
#'                      Supported options: "skip" (filter out rows with invalid data),
#'                                         "error" (throw an error), "keep" (put invalid data in
#'                                         a special additional bucket, at index numLabels). Default
#'                                         is "error".
#' @param ... additional arguments passed to the method.
#' @aliases spark.decisionTree,SparkDataFrame,formula-method
#' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
#' @rdname spark.decisionTree
#' @name spark.decisionTree
#' @examples
#' \dontrun{
#' # fit a Decision Tree Regression Model
#' df <- createDataFrame(longley)
#' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
#'
#' # get the summary of the model
#' summary(model)
#'
#' # make predictions
#' predictions <- predict(model, df)
#'
#' # save and load the model
#' path <- "path/to/model"
#' write.ml(model, path)
#' savedModel <- read.ml(path)
#' summary(savedModel)
#'
#' # fit a Decision Tree Classification Model
#' t <- as.data.frame(Titanic)
#' df <- createDataFrame(t)
#' model <- spark.decisionTree(df, Survived ~ Freq + Age, "classification")
#' }
#' @note spark.decisionTree since 2.3.0
setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, type = c("regression", "classification"),
                   maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL,
                   minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
                   maxMemoryInMB = 256, cacheNodeIds = FALSE,
                   handleInvalid = c("error", "keep", "skip")) {
            type <- match.arg(type)
            # The JVM side receives the formula as one deparsed string.
            formula <- paste(deparse(formula), collapse = "")
            # A user-supplied seed is coerced to integer and sent as a string; NULL is
            # passed through unchanged.
            if (!is.null(seed)) {
              seed <- as.character(as.integer(seed))
            }

            # Invokes the static `fit` of the given JVM wrapper class with the
            # hyper-parameters common to both tree types, in the order the wrappers
            # take them. Any trailing type-specific arguments travel through `...`.
            # `formula` and `seed` are looked up here after the conversions above.
            fitTree <- function(wrapperClass, treeImpurity, ...) {
              callJStatic(wrapperClass, "fit", data@sdf, formula, as.integer(maxDepth),
                          as.integer(maxBins), treeImpurity,
                          as.integer(minInstancesPerNode), as.numeric(minInfoGain),
                          as.integer(checkpointInterval), seed,
                          as.integer(maxMemoryInMB), as.logical(cacheNodeIds), ...)
            }

            if (type == "regression") {
              # Regression supports a single impurity measure.
              if (is.null(impurity)) {
                impurity <- "variance"
              }
              impurity <- match.arg(impurity, "variance")
              jobj <- fitTree("org.apache.spark.ml.r.DecisionTreeRegressorWrapper", impurity)
              new("DecisionTreeRegressionModel", jobj = jobj)
            } else {
              # Classification additionally forwards handleInvalid as the last argument.
              handleInvalid <- match.arg(handleInvalid)
              if (is.null(impurity)) {
                impurity <- "gini"
              }
              impurity <- match.arg(impurity, c("gini", "entropy"))
              jobj <- fitTree("org.apache.spark.ml.r.DecisionTreeClassifierWrapper", impurity,
                              handleInvalid)
              new("DecisionTreeClassificationModel", jobj = jobj)
            }
          })
0672 
0673 #  Get the summary of a Decision Tree Regression Model
0674 
#' @return \code{summary} returns summary information of the fitted model as a list.
#'         Its components include \code{formula} (formula),
#'         \code{numFeatures} (number of features), \code{features} (list of features),
#'         \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of
#'         trees).
#' @rdname spark.decisionTree
#' @aliases summary,DecisionTreeRegressionModel-method
#' @note summary(DecisionTreeRegressionModel) since 2.3.0
setMethod("summary",
          signature(object = "DecisionTreeRegressionModel"),
          function(object) {
            # Build the summary shared by all decision tree models, then tag it with
            # a regression-specific class so print() dispatches to
            # print.summary.DecisionTreeRegressionModel.
            treeSummary <- summary.decisionTree(object)
            class(treeSummary) <- "summary.DecisionTreeRegressionModel"
            treeSummary
          })
0689 
0690 #  Prints the summary of Decision Tree Regression Model
0691 
#' @param x summary object of Decision Tree regression model or classification model
#'          returned by \code{summary}.
#' @rdname spark.decisionTree
#' @note print.summary.DecisionTreeRegressionModel since 2.3.0
print.summary.DecisionTreeRegressionModel <- function(x, ...) {
  # S3 print method for the "summary.DecisionTreeRegressionModel" class assigned by
  # summary(). It delegates to the shared print.summary.decisionTree helper and returns
  # whatever that helper returns; `...` matches the print() generic but is not forwarded.
  print.summary.decisionTree(x)
}
0699 
0700 #  Get the summary of a Decision Tree Classification Model
0701 
#' @rdname spark.decisionTree
#' @aliases summary,DecisionTreeClassificationModel-method
#' @note summary(DecisionTreeClassificationModel) since 2.3.0
setMethod("summary",
          signature(object = "DecisionTreeClassificationModel"),
          function(object) {
            # Reuse the common decision tree summary and re-class it so print()
            # dispatches to print.summary.DecisionTreeClassificationModel.
            treeSummary <- summary.decisionTree(object)
            class(treeSummary) <- "summary.DecisionTreeClassificationModel"
            treeSummary
          })
0711 
0712 #  Prints the summary of Decision Tree Classification Model
0713 
#' @rdname spark.decisionTree
#' @note print.summary.DecisionTreeClassificationModel since 2.3.0
print.summary.DecisionTreeClassificationModel <- function(x, ...) {
  # S3 print method for the "summary.DecisionTreeClassificationModel" class assigned by
  # summary(). It delegates to the shared print.summary.decisionTree helper and returns
  # whatever that helper returns; `...` matches the print() generic but is not forwarded.
  print.summary.decisionTree(x)
}
0719 
0720 #  Makes predictions from a Decision Tree Regression model or Classification model
0721 
#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
#'         "prediction".
#' @rdname spark.decisionTree
#' @aliases predict,DecisionTreeRegressionModel-method
#' @note predict(DecisionTreeRegressionModel) since 2.3.0
setMethod("predict",
          signature(object = "DecisionTreeRegressionModel"),
          function(object, newData) {
            # Scoring is delegated to the shared predict_internal helper.
            predict_internal(object, newData)
          })
0732 
#' @rdname spark.decisionTree
#' @aliases predict,DecisionTreeClassificationModel-method
#' @note predict(DecisionTreeClassificationModel) since 2.3.0
setMethod("predict",
          signature(object = "DecisionTreeClassificationModel"),
          function(object, newData) {
            # Same scoring path as the regression variant: hand off to predict_internal.
            predict_internal(object, newData)
          })
0740 
0741 #  Save the Decision Tree Regression or Classification model to the input path.
0742 
#' @param object A fitted Decision Tree regression model or classification model.
#' @param path The directory where the model is saved.
#' @param overwrite Whether to overwrite the output path if it already exists. Default is FALSE,
#'                  which means an exception is thrown if the output path exists.
#'
#' @aliases write.ml,DecisionTreeRegressionModel,character-method
#' @rdname spark.decisionTree
#' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0
setMethod("write.ml",
          signature(object = "DecisionTreeRegressionModel", path = "character"),
          function(object, path, overwrite = FALSE) {
            # Persisting is handled by the shared write_internal helper.
            write_internal(object, path, overwrite)
          })
0755 
#' @aliases write.ml,DecisionTreeClassificationModel,character-method
#' @rdname spark.decisionTree
#' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0
setMethod("write.ml",
          signature(object = "DecisionTreeClassificationModel", path = "character"),
          function(object, path, overwrite = FALSE) {
            # Forward the model, target directory and overwrite flag to write_internal.
            write_internal(object, path, overwrite)
          })