Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 # mllib_utils.R: Utilities for MLlib integration
0019 
# Integration with R's standard functions.
# Most of MLlib's algorithms are provided in two flavours:
# - a specialization of the default R method (e.g. glm). These methods try to match the
#   inputs and outputs of the corresponding R method as closely as possible, but some small
#   differences may exist.
# - a set of methods whose arguments mirror the APIs Spark offers in its other supported
#   languages. These methods carry the `spark.` prefix: spark.glm, spark.kmeans, etc.
0027 
#' Saves the MLlib model to the input path
#'
#' Saves the MLlib model to the input path. For more information, see the specific
#' MLlib model below.
#' @rdname write.ml
#' @name write.ml
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree},
#' @seealso \link{spark.fpGrowth}, \link{spark.gaussianMixture}, \link{spark.gbt},
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg},
#' @seealso \link{spark.kmeans},
#' @seealso \link{spark.lda}, \link{spark.logit},
#' @seealso \link{spark.mlp}, \link{spark.naiveBayes},
#' @seealso \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear},
#' @seealso \link{read.ml}
NULL # Anchor object for the shared roxygen topic above; evaluates to nothing at runtime.
0043 
#' Makes predictions from an MLlib model
#'
#' Makes predictions from an MLlib model. For more information, see the specific
#' MLlib model below.
#' @rdname predict
#' @name predict
#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree},
#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt},
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg},
#' @seealso \link{spark.kmeans},
#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
#' @seealso \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear}
NULL # Anchor object for the shared roxygen topic above; evaluates to nothing at runtime.
0057 
# Persists the model held in `object@jobj` to `path` through its JVM writer.
#
# object:    an MLlib model object whose `jobj` slot provides a `write` method.
# path:      destination handed unchanged to the writer's `save` method.
# overwrite: when TRUE, the writer is switched to overwrite mode before saving.
# Returns whatever `save` returns, invisibly.
write_internal <- function(object, path, overwrite = FALSE) {
  jwriter <- callJMethod(object@jobj, "write")
  # Only request overwrite mode when asked; otherwise keep the writer as-is.
  jwriter <- if (overwrite) callJMethod(jwriter, "overwrite") else jwriter
  saved <- callJMethod(jwriter, "save", path)
  invisible(saved)
}
0065 
# Runs the model's JVM-side `transform` on `newData` and wraps the result with dataFrame().
#
# object:  an MLlib model object whose `jobj` slot provides a `transform` method.
# newData: a SparkDataFrame whose `sdf` slot is passed to `transform`.
predict_internal <- function(object, newData) {
  transformed <- callJMethod(object@jobj, "transform", newData@sdf)
  dataFrame(transformed)
}
0069 
#' Load a fitted MLlib model from the input path.
#'
#' @param path path of the model to read.
#' @return A fitted MLlib model.
#' @details The loaded JVM object is matched against the known SparkR model
#'   wrappers; if none matches, an error naming the object's Java class is raised.
#' @rdname read.ml
#' @name read.ml
#' @seealso \link{write.ml}
#' @examples
#' \dontrun{
#' path <- "path/to/model"
#' model <- read.ml(path)
#' }
#' @note read.ml since 2.0.0
read.ml <- function(path) {
  # normalizePath() may warn (e.g. for a path that does not exist locally);
  # the warning is suppressed and the path is handed to the JVM loader anyway.
  path <- suppressWarnings(normalizePath(path))
  sparkSession <- getSparkSession()
  callJStatic("org.apache.spark.ml.r.RWrappers", "session", sparkSession)
  jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path)

  # JVM wrapper class (in package org.apache.spark.ml.r) -> R S4 model class.
  # Entries are tried in order and the first match wins.
  modelClasses <- c(
    NaiveBayesWrapper = "NaiveBayesModel",
    AFTSurvivalRegressionWrapper = "AFTSurvivalRegressionModel",
    GeneralizedLinearRegressionWrapper = "GeneralizedLinearRegressionModel",
    KMeansWrapper = "KMeansModel",
    LDAWrapper = "LDAModel",
    MultilayerPerceptronClassifierWrapper = "MultilayerPerceptronClassificationModel",
    IsotonicRegressionWrapper = "IsotonicRegressionModel",
    GaussianMixtureWrapper = "GaussianMixtureModel",
    ALSWrapper = "ALSModel",
    LogisticRegressionWrapper = "LogisticRegressionModel",
    RandomForestRegressorWrapper = "RandomForestRegressionModel",
    RandomForestClassifierWrapper = "RandomForestClassificationModel",
    DecisionTreeRegressorWrapper = "DecisionTreeRegressionModel",
    DecisionTreeClassifierWrapper = "DecisionTreeClassificationModel",
    GBTRegressorWrapper = "GBTRegressionModel",
    GBTClassifierWrapper = "GBTClassificationModel",
    BisectingKMeansWrapper = "BisectingKMeansModel",
    LinearSVCWrapper = "LinearSVCModel",
    FPGrowthWrapper = "FPGrowthModel"
  )
  for (wrapper in names(modelClasses)) {
    if (isInstanceOf(jobj, paste0("org.apache.spark.ml.r.", wrapper))) {
      return(new(modelClasses[[wrapper]], jobj = jobj))
    }
  }

  # `jobj` is a JVM object reference, not a string (in SparkR it is an
  # environment-backed handle). Passing it to stop() directly makes stop() fail
  # while coercing it to character, which hides this message, so report the
  # object's Java class name instead.
  className <- callJMethod(callJMethod(jobj, "getClass"), "getName")
  stop("Unsupported model: ", className)
}
0120   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.BisectingKMeansWrapper")) {
0121     new("BisectingKMeansModel", jobj = jobj)
0122   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LinearSVCWrapper")) {
0123     new("LinearSVCModel", jobj = jobj)
0124   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FPGrowthWrapper")) {
0125     new("FPGrowthModel", jobj = jobj)
0126   } else {
0127     stop("Unsupported model: ", jobj)
0128   }
0129 }