0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements. See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License. You may obtain a copy of the License at
0008 #
0009 # http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017
0018 # mllib_recommendation.R: Provides methods for MLlib recommendation algorithms integration
0019
0020 #' S4 class that represents an ALSModel
0021 #'
0022 #' @param jobj a Java object reference to the backing Scala ALSWrapper
0023 #' @note ALSModel since 2.1.0
# S4 wrapper class holding a reference to the backing Scala ALSWrapper object.
setClass("ALSModel", slots = c(jobj = "jobj"))
0025
0026 #' Alternating Least Squares (ALS) for Collaborative Filtering
0027 #'
0028 #' \code{spark.als} learns latent factors in collaborative filtering via alternating least
0029 #' squares. Users can call \code{summary} to obtain fitted latent factors, \code{predict}
0030 #' to make predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
0031 #'
0032 #' For more details, see
0033 #' \href{http://spark.apache.org/docs/latest/ml-collaborative-filtering.html}{MLlib:
0034 #' Collaborative Filtering}.
0035 #'
0036 #' @param data a SparkDataFrame for training.
0037 #' @param ratingCol column name for ratings.
0038 #' @param userCol column name for user ids. Ids must be (or can be coerced into) integers.
0039 #' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers.
0040 #' @param rank rank of the matrix factorization (> 0).
0041 #' @param regParam regularization parameter (>= 0).
#' @param maxIter maximum number of iterations (> 0).
0043 #' @param nonnegative logical value indicating whether to apply nonnegativity constraints.
0044 #' @param implicitPrefs logical value indicating whether to use implicit preference.
0045 #' @param alpha alpha parameter in the implicit preference formulation (>= 0).
0046 #' @param seed integer seed for random number generation.
0047 #' @param numUserBlocks number of user blocks used to parallelize computation (> 0).
0048 #' @param numItemBlocks number of item blocks used to parallelize computation (> 0).
0049 #' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1).
0050 #' Note: this setting will be ignored if the checkpoint directory is not
0051 #' set.
0052 #' @param ... additional argument(s) passed to the method.
0053 #' @return \code{spark.als} returns a fitted ALS model.
0054 #' @rdname spark.als
0055 #' @aliases spark.als,SparkDataFrame-method
0056 #' @name spark.als
0057 #' @examples
0058 #' \dontrun{
0059 #' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
0060 #' list(2, 1, 1.0), list(2, 2, 5.0))
0061 #' df <- createDataFrame(ratings, c("user", "item", "rating"))
0062 #' model <- spark.als(df, "rating", "user", "item")
0063 #'
0064 #' # extract latent factors
0065 #' stats <- summary(model)
0066 #' userFactors <- stats$userFactors
0067 #' itemFactors <- stats$itemFactors
0068 #'
0069 #' # make predictions
0070 #' predicted <- predict(model, df)
0071 #' showDF(predicted)
0072 #'
0073 #' # save and load the model
0074 #' path <- "path/to/model"
0075 #' write.ml(model, path)
0076 #' savedModel <- read.ml(path)
0077 #' summary(savedModel)
0078 #'
0079 #' # set other arguments
0080 #' modelS <- spark.als(df, "rating", "user", "item", rank = 20,
0081 #' regParam = 0.1, nonnegative = TRUE)
0082 #' statsS <- summary(modelS)
0083 #' }
0084 #' @note spark.als since 2.1.0
0085 #' @note the input rating dataframe to the ALS implementation should be deterministic.
0086 #' Nondeterministic data can cause failure during fitting ALS model. For example,
0087 #' an order-sensitive operation like sampling after a repartition makes dataframe output
0088 #' nondeterministic, like \code{sample(repartition(df, 2L), FALSE, 0.5, 1618L)}.
0089 #' Checkpointing sampled dataframe or adding a sort before sampling can help make the
0090 #' dataframe deterministic.
setMethod("spark.als", signature(data = "SparkDataFrame"),
          function(data, ratingCol = "rating", userCol = "user", itemCol = "item",
                   rank = 10, regParam = 0.1, maxIter = 10, nonnegative = FALSE,
                   implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10,
                   checkpointInterval = 10, seed = 0) {

            # Validate numeric hyperparameters on the R side so the user gets a
            # clear error message instead of an opaque JVM exception. The bounds
            # mirror the documented constraints above.
            if (!is.numeric(rank) || rank <= 0) {
              stop("rank should be a positive number.")
            }
            if (!is.numeric(regParam) || regParam < 0) {
              stop("regParam should be a nonnegative number.")
            }
            if (!is.numeric(maxIter) || maxIter <= 0) {
              stop("maxIter should be a positive number.")
            }
            if (!is.numeric(numUserBlocks) || numUserBlocks <= 0) {
              stop("numUserBlocks should be a positive number.")
            }
            if (!is.numeric(numItemBlocks) || numItemBlocks <= 0) {
              stop("numItemBlocks should be a positive number.")
            }

            # Delegate fitting to the Scala-side ALSWrapper; integer-valued
            # parameters are coerced explicitly since R numerics are doubles.
            jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper",
                                "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank),
                                regParam, as.integer(maxIter), implicitPrefs, alpha, nonnegative,
                                as.integer(numUserBlocks), as.integer(numItemBlocks),
                                as.integer(checkpointInterval), as.integer(seed))
            new("ALSModel", jobj = jobj)
          })
0114
0115 # Returns a summary of the ALS model produced by spark.als.
0116
0117 #' @param object a fitted ALS model.
0118 #' @return \code{summary} returns summary information of the fitted model, which is a list.
0119 #' The list includes \code{user} (the names of the user column),
0120 #' \code{item} (the item column), \code{rating} (the rating column), \code{userFactors}
0121 #' (the estimated user factors), \code{itemFactors} (the estimated item factors),
0122 #' and \code{rank} (rank of the matrix factorization model).
0123 #' @rdname spark.als
0124 #' @aliases summary,ALSModel-method
0125 #' @note summary(ALSModel) since 2.1.0
setMethod("summary", signature(object = "ALSModel"),
          function(object) {
            jobj <- object@jobj
            # Pull the column-name strings from the Scala wrapper in one pass.
            colNames <- lapply(c("userCol", "itemCol", "ratingCol"),
                               function(accessor) callJMethod(jobj, accessor))
            # The factor matrices come back as Java DataFrame handles and are
            # wrapped into SparkDataFrames before being returned to the caller.
            list(user = colNames[[1]],
                 item = colNames[[2]],
                 rating = colNames[[3]],
                 userFactors = dataFrame(callJMethod(jobj, "userFactors")),
                 itemFactors = dataFrame(callJMethod(jobj, "itemFactors")),
                 rank = callJMethod(jobj, "rank"))
          })
0138
# Makes predictions from an ALS model produced by spark.als.
0140
0141 #' @param newData a SparkDataFrame for testing.
0142 #' @return \code{predict} returns a SparkDataFrame containing predicted values.
0143 #' @rdname spark.als
0144 #' @aliases predict,ALSModel-method
0145 #' @note predict(ALSModel) since 2.1.0
setMethod("predict", signature(object = "ALSModel"),
          function(object, newData) {
            # Delegate to the shared MLlib prediction helper used by all
            # SparkR model wrappers.
            predict_internal(object = object, newData = newData)
          })
0150
0151 # Saves the ALS model to the input path.
0152
0153 #' @param path the directory where the model is saved.
0154 #' @param overwrite logical value indicating whether to overwrite if the output path
0155 #' already exists. Default is FALSE which means throw exception
0156 #' if the output path exists.
0157 #'
0158 #' @rdname spark.als
0159 #' @aliases write.ml,ALSModel,character-method
0160 #' @seealso \link{read.ml}
0161 #' @note write.ml(ALSModel, character) since 2.1.0
setMethod("write.ml", signature(object = "ALSModel", path = "character"),
          function(object, path, overwrite = FALSE) {
            # Delegate to the shared MLlib persistence helper used by all
            # SparkR model wrappers.
            write_internal(object = object, path = path, overwrite = overwrite)
          })