Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 # mllib_recommendation.R: Provides methods for MLlib recommendation algorithms integration
0019 
#' S4 class wrapping a fitted ALS model, as returned by \code{spark.als}
#'
#' @param jobj a Java object reference to the backing Scala ALSWrapper
#' @note ALSModel since 2.1.0
setClass("ALSModel", slots = c(jobj = "jobj"))
0025 
#' Alternating Least Squares (ALS) for Collaborative Filtering
#'
#' \code{spark.als} learns latent factors in collaborative filtering via alternating least
#' squares. Users can call \code{summary} to obtain fitted latent factors, \code{predict}
#' to make predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
#'
#' For more details, see
#' \href{http://spark.apache.org/docs/latest/ml-collaborative-filtering.html}{MLlib:
#' Collaborative Filtering}.
#'
#' @param data a SparkDataFrame for training.
#' @param ratingCol column name for ratings.
#' @param userCol column name for user ids. Ids must be (or can be coerced into) integers.
#' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers.
#' @param rank rank of the matrix factorization (> 0). Passed through \code{as.integer}.
#' @param regParam regularization parameter (>= 0).
#' @param maxIter maximum number of iterations (> 0). Passed through \code{as.integer}.
#' @param nonnegative logical value indicating whether to apply nonnegativity constraints.
#' @param implicitPrefs logical value indicating whether to use implicit preference.
#' @param alpha alpha parameter in the implicit preference formulation (>= 0).
#' @param seed integer seed for random number generation.
#' @param numUserBlocks number of user blocks used to parallelize computation (> 0).
#' @param numItemBlocks number of item blocks used to parallelize computation (> 0).
#' @param checkpointInterval checkpoint interval (>= 1), or -1 to disable checkpointing.
#'                           Note: this setting will be ignored if the checkpoint directory is not
#'                           set.
#' @param ... additional argument(s) passed to the method.
#' @return \code{spark.als} returns a fitted ALS model.
#' @rdname spark.als
#' @aliases spark.als,SparkDataFrame-method
#' @name spark.als
#' @examples
#' \dontrun{
#' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
#'                 list(2, 1, 1.0), list(2, 2, 5.0))
#' df <- createDataFrame(ratings, c("user", "item", "rating"))
#' model <- spark.als(df, "rating", "user", "item")
#'
#' # extract latent factors
#' stats <- summary(model)
#' userFactors <- stats$userFactors
#' itemFactors <- stats$itemFactors
#'
#' # make predictions
#' predicted <- predict(model, df)
#' showDF(predicted)
#'
#' # save and load the model
#' path <- "path/to/model"
#' write.ml(model, path)
#' savedModel <- read.ml(path)
#' summary(savedModel)
#'
#' # set other arguments
#' modelS <- spark.als(df, "rating", "user", "item", rank = 20,
#'                     regParam = 0.1, nonnegative = TRUE)
#' statsS <- summary(modelS)
#' }
#' @note spark.als since 2.1.0
#' @note the input rating dataframe to the ALS implementation should be deterministic.
#'       Nondeterministic data can cause failure during fitting ALS model. For example,
#'       an order-sensitive operation like sampling after a repartition makes dataframe output
#'       nondeterministic, like \code{sample(repartition(df, 2L), FALSE, 0.5, 1618L)}.
#'       Checkpointing sampled dataframe or adding a sort before sampling can help make the
#'       dataframe deterministic.
setMethod("spark.als", signature(data = "SparkDataFrame"),
          function(data, ratingCol = "rating", userCol = "user", itemCol = "item",
                   rank = 10, regParam = 0.1, maxIter = 10, nonnegative = FALSE,
                   implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10,
                   checkpointInterval = 10, seed = 0) {

            # TRUE only for a single, non-missing number. Guarding the range comparisons
            # with this makes NA/NaN, empty and multi-element inputs fail with the
            # intended message below, instead of a cryptic error raised by `if`/`||`.
            isScalarNumber <- function(x) {
              is.numeric(x) && length(x) == 1 && !is.na(x)
            }

            if (!isScalarNumber(rank) || rank <= 0) {
              stop("rank should be a positive number.")
            }
            if (!isScalarNumber(regParam) || regParam < 0) {
              stop("regParam should be a nonnegative number.")
            }
            if (!isScalarNumber(maxIter) || maxIter <= 0) {
              stop("maxIter should be a positive number.")
            }

            # Integer-typed parameters are coerced here because R numerics default to double.
            jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper",
                                "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank),
                                regParam, as.integer(maxIter), implicitPrefs, alpha, nonnegative,
                                as.integer(numUserBlocks), as.integer(numItemBlocks),
                                as.integer(checkpointInterval), as.integer(seed))
            new("ALSModel", jobj = jobj)
          })
0114 
#  Summarizes an ALSModel fitted by spark.als.

#' @param object a fitted ALS model.
#' @return \code{summary} returns summary information of the fitted model, which is a list.
#'         The list includes \code{user} (the name of the user column),
#'         \code{item} (the item column), \code{rating} (the rating column), \code{userFactors}
#'         (the estimated user factors), \code{itemFactors} (the estimated item factors),
#'         and \code{rank} (rank of the matrix factorization model).
#' @rdname spark.als
#' @aliases summary,ALSModel-method
#' @note summary(ALSModel) since 2.1.0
setMethod("summary", signature(object = "ALSModel"),
          function(object) {
            wrapper <- object@jobj
            # Every field is read from the backing JVM wrapper through a getter of that name.
            fetch <- function(getterName) callJMethod(wrapper, getterName)
            # list() evaluates its arguments left to right, so the JVM calls happen in
            # the same order as listed here.
            list(user = fetch("userCol"),
                 item = fetch("itemCol"),
                 rating = fetch("ratingCol"),
                 userFactors = dataFrame(fetch("userFactors")),
                 itemFactors = dataFrame(fetch("itemFactors")),
                 rank = fetch("rank"))
          })
0138 
#  Makes predictions on new data from an ALSModel produced by spark.als.

#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted values.
#' @rdname spark.als
#' @aliases predict,ALSModel-method
#' @note predict(ALSModel) since 2.1.0
setMethod("predict", signature(object = "ALSModel"),
          # Delegates entirely to the shared predict_internal helper.
          function(object, newData) predict_internal(object, newData))
0150 
#  Persists an ALSModel to the given path.

#' @param path the directory where the model is saved.
#' @param overwrite logical value indicating whether to overwrite if the output path
#'                  already exists. Default is FALSE which means throw exception
#'                  if the output path exists.
#'
#' @rdname spark.als
#' @aliases write.ml,ALSModel,character-method
#' @seealso \link{read.ml}
#' @note write.ml(ALSModel, character) since 2.1.0
setMethod("write.ml", signature(object = "ALSModel", path = "character"),
          # Delegates entirely to the shared write_internal helper.
          function(object, path, overwrite = FALSE) write_internal(object, path, overwrite))
0156 #'                  if the output path exists.
0157 #'
0158 #' @rdname spark.als
0159 #' @aliases write.ml,ALSModel,character-method
0160 #' @seealso \link{read.ml}
0161 #' @note write.ml(ALSModel, character) since 2.1.0
0162 setMethod("write.ml", signature(object = "ALSModel", path = "character"),
0163           function(object, path, overwrite = FALSE) {
0164             write_internal(object, path, overwrite)
0165           })