This is sample R code for collaborative filtering by matrix factorization.
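First, we define the function matrix_factorization, which approximately decomposes a rating matrix R (m × n, users × movies) into the product of two matrices P (m × K) and Q (K × n), where K is the number of latent features. The Q argument is passed as an n × K matrix and transposed inside the function. The factors are found by stochastic gradient descent, which updates P and Q once per observed rating to minimize the regularized squared error: alpha is the learning rate (step size) and beta is the regularization coefficient that penalizes large values in P and Q. steps is the maximum number of iterations; the loop stops earlier if the error drops below 0.001.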
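A sketch of the objective and the per-rating update rules, read off from the code of matrix_factorization (with p_ik = P[i, k], q_kj = Q[k, j] after Q has been transposed, and the outer sum taken only over observed ratings r_ij > 0). E_ij denotes the bracketed summand for a single rating:

```latex
\begin{aligned}
e_{ij} &= r_{ij} - \sum_{k=1}^{K} p_{ik}\, q_{kj} \\
E &= \sum_{(i,j):\, r_{ij} > 0} \underbrace{\Big[ e_{ij}^2 + \frac{\beta}{2} \sum_{k=1}^{K} \big(p_{ik}^2 + q_{kj}^2\big) \Big]}_{E_{ij}} \\
\frac{\partial E_{ij}}{\partial p_{ik}} &= -2\, e_{ij}\, q_{kj} + \beta\, p_{ik},
\qquad
\frac{\partial E_{ij}}{\partial q_{kj}} = -2\, e_{ij}\, p_{ik} + \beta\, q_{kj} \\
p_{ik} &\leftarrow p_{ik} + \alpha \big(2\, e_{ij}\, q_{kj} - \beta\, p_{ik}\big) \\
q_{kj} &\leftarrow q_{kj} + \alpha \big(2\, e_{ij}\, p_{ik} - \beta\, q_{kj}\big)
\end{aligned}
```

Each update is one step of size α against the gradient of E_ij, and E is the quantity the function compares with 0.001 after each pass.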
matrix_factorization <- function(R, P, Q, K, steps = 5000, alpha = 0.0002, beta = 0.02) {
  # Q arrives as an n x K matrix; work with its K x n transpose
  Q <- t(Q)
  for (step in 1:steps) {
    # one gradient-descent pass over every observed rating (R[i, j] > 0)
    for (i in 1:nrow(R)) {
      for (j in 1:ncol(R)) {
        if (R[i, j] > 0) {
          # error between the true rating and the current approximation
          eij <- R[i, j] - sum(P[i, ] * Q[, j])
          for (k in 1:K) {
            # gradient step: learning rate alpha, regularization beta
            # (the Q update below uses the P[i, k] value just updated)
            P[i, k] <- P[i, k] + alpha * (2 * eij * Q[k, j] - beta * P[i, k])
            Q[k, j] <- Q[k, j] + alpha * (2 * eij * P[i, k] - beta * Q[k, j])
          }
        }
      }
    }
    # regularized squared error over the observed ratings
    e <- 0
    for (i in 1:nrow(R)) {
      for (j in 1:ncol(R)) {
        if (R[i, j] > 0) {
          e <- e + (R[i, j] - sum(P[i, ] * Q[, j]))^2
          for (k in 1:K) {
            e <- e + (beta / 2) * (P[i, k]^2 + Q[k, j]^2)
          }
        }
      }
    }
    # stop early once the error is small enough
    if (e < 0.001) break
  }
  # return Q in its original n x K orientation
  return(list(P = P, Q = t(Q)))
}
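In the inner loop of matrix_factorization, Q[k, j] is updated using the value of P[i, k] that was overwritten on the line before, so the two updates are not quite simultaneous. The textbook stochastic gradient step uses the old values for both. A minimal sketch of one such step, assuming Q is stored as an n × K matrix (as it is passed in); sgd_step and the toy matrices are hypothetical, not part of the code above:

```r
# Hypothetical helper: one SGD step for a single observed rating R[i, j],
# updating row i of P (m x K) and row j of Q (n x K) simultaneously.
sgd_step <- function(R, P, Q, i, j, alpha = 0.0002, beta = 0.02) {
  eij <- R[i, j] - sum(P[i, ] * Q[j, ])  # prediction error for this rating
  p_old <- P[i, ]                        # keep the old user factors
  P[i, ] <- P[i, ] + alpha * (2 * eij * Q[j, ] - beta * P[i, ])
  Q[j, ] <- Q[j, ] + alpha * (2 * eij * p_old - beta * Q[j, ])
  list(P = P, Q = Q)
}

# Toy check: one rating r = 2, all factors 1, K = 1, no regularization
R_toy <- matrix(c(2, 0, 0, 0), nrow = 2)
P_toy <- matrix(1, nrow = 2, ncol = 1)
Q_toy <- matrix(1, nrow = 2, ncol = 1)
out <- sgd_step(R_toy, P_toy, Q_toy, i = 1, j = 1, alpha = 0.1, beta = 0)
print(c(out$P[1, 1], out$Q[1, 1]))  # 1.2 1.2
```

With the in-place update used in the loop, the same toy example would instead give Q[1, 1] = 1 + 0.1 × 2 × 1.2 = 1.24. For a small learning rate such as the default alpha = 0.0002, the two variants differ only slightly per step.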
Fix the seed of the random-number generator so that the random starting matrices, and hence the results below, are reproducible:
set.seed(123)
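A quick illustration of what the seed does: resetting it to the same value makes runif() return the same draws again. The snippet re-seeds at the end so that running it does not change the random matrices generated later in this example.

```r
# Re-seeding with the same value reproduces the same random draws.
set.seed(123)
a <- runif(3)
set.seed(123)
b <- runif(3)
print(identical(a, b))  # TRUE

# restore the seed so the code that follows draws the same numbers as before
set.seed(123)
```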
R is the rating matrix to be decomposed, with one row per user and one column per movie; 0 means no rating.
R <- matrix(c(5, 3, 0, 1,
4, 0, 0, 1,
1, 1, 0, 5,
1, 0, 0, 4,
0, 1, 5, 4,
2, 1, 3, 0), nrow = 6, ncol = 4, byrow = TRUE)
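The factorization is fitted only to the non-zero entries; the zeros are the cells to be predicted. A small sketch that counts the observed ratings and lists the unrated (user, movie) pairs with base R's which() and arr.ind = TRUE (the names observed and unrated are illustrative):

```r
# Same rating matrix as above (rows = users, columns = movies, 0 = unrated)
R <- matrix(c(5, 3, 0, 1,
              4, 0, 0, 1,
              1, 1, 0, 5,
              1, 0, 0, 4,
              0, 1, 5, 4,
              2, 1, 3, 0), nrow = 6, ncol = 4, byrow = TRUE)

observed <- R > 0                         # TRUE where a rating exists
print(sum(observed))                      # 16 observed ratings
unrated <- which(R == 0, arr.ind = TRUE)  # (row, col) of each unrated cell
print(nrow(unrated))                      # 8 cells to predict
print(unrated)
```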
# N: number of users (rows of R)
N <- nrow(R)
# M: number of movies (columns of R)
M <- ncol(R)
# K: number of latent features
K <- 3
Start from two random matrices, P (N × K) and Q (M × K), with entries drawn uniformly between 0 and 1:
P <- matrix(runif(N * K), nrow = N, ncol = K)
Q <- matrix(runif(M * K), nrow = M, ncol = K)
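Because Q is stored as M × K (one row per movie), the reconstruction of R is P times the transpose of Q. A shape check with zero placeholder matrices of the same sizes (P_demo and Q_demo are illustrative names, chosen so the real P and Q are not overwritten):

```r
# Shape check with placeholder matrices of the same sizes as P and Q
N <- 6; M <- 4; K <- 3
P_demo <- matrix(0, nrow = N, ncol = K)  # users x features
Q_demo <- matrix(0, nrow = M, ncol = K)  # movies x features
print(dim(P_demo %*% t(Q_demo)))         # 6 4: same shape as R
```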
Run the algorithm to approximate R by the product of P and the transpose of Q:
result <- matrix_factorization(R, P, Q, K)
nP <- result$P
nQ <- result$Q
nR <- nP %*% t(nQ)
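One way to summarize how well nR fits is the root-mean-square error over the observed cells only. A minimal sketch; observed_rmse and the toy matrices are hypothetical, not part of the original code:

```r
# Hypothetical helper: root-mean-square error between the observed ratings
# in R (entries > 0) and the corresponding entries of a reconstruction nR.
observed_rmse <- function(R, nR) {
  mask <- R > 0
  sqrt(mean((R[mask] - nR[mask])^2))
}

# Toy check: errors of 0.5 on both observed cells give an RMSE of 0.5
R_toy  <- matrix(c(4, 0, 0, 2), nrow = 2)
nR_toy <- matrix(c(4.5, 3, 3, 1.5), nrow = 2)
print(observed_rmse(R_toy, nR_toy))  # 0.5
```

Applied after the run above, observed_rmse(R, nR) measures how closely nR reproduces the ratings that were actually given; the unrated cells are excluded because there is nothing to compare them with.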
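Printing P, Q, and the predicted rating matrix nR. The entries that were rated in R are reproduced closely, and the cells that were 0 (unrated) now hold predicted ratings: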
print(nP)
## [,1] [,2] [,3]
## [1,] 0.2463714 1.2785584 1.884240882
## [2,] 0.2058249 1.1525078 1.385007010
## [3,] 1.8185968 1.1550282 -0.255033789
## [4,] 1.7291866 0.5349949 0.223180524
## [5,] 1.3502813 1.2750629 -0.002350487
## [6,] 0.5163619 1.0115940 0.310546768
print(nQ)
## [,1] [,2] [,3]
## [1,] -0.09132514 1.4309696 1.6931080
## [2,] 0.27092979 0.5848768 1.1271429
## [3,] 1.93218953 1.8271619 0.5202290
## [4,] 2.04541001 0.9883406 -0.4065921
print(nR)
## [,1] [,2] [,3] [,4]
## [1,] 4.9973016 2.9383573 3.792406 1.0014643
## [2,] 3.9753730 2.2909401 3.224032 0.9969337
## [3,] 1.0549269 0.8808018 5.491621 4.9650321
## [4,] 0.9855119 1.0329506 4.434744 3.9749094
## [5,] 1.6972819 1.1089368 4.937523 4.0230309
## [6,] 1.9261927 1.0815863 3.007611 1.9297053
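To turn the predictions into recommendations, one simple approach (an illustration, not part of the original code) is to suggest, for each user, the unrated movie with the highest predicted rating. The sketch copies R and the printed nR (as nR_printed) so that it runs on its own:

```r
# Ratings and predictions copied from R and the printed nR above
R <- matrix(c(5, 3, 0, 1,
              4, 0, 0, 1,
              1, 1, 0, 5,
              1, 0, 0, 4,
              0, 1, 5, 4,
              2, 1, 3, 0), nrow = 6, ncol = 4, byrow = TRUE)
nR_printed <- matrix(c(4.9973016, 2.9383573, 3.792406, 1.0014643,
                       3.9753730, 2.2909401, 3.224032, 0.9969337,
                       1.0549269, 0.8808018, 5.491621, 4.9650321,
                       0.9855119, 1.0329506, 4.434744, 3.9749094,
                       1.6972819, 1.1089368, 4.937523, 4.0230309,
                       1.9261927, 1.0815863, 3.007611, 1.9297053),
                     nrow = 6, ncol = 4, byrow = TRUE)

# Hide movies each user has already rated, then pick the best remaining one
candidates <- nR_printed
candidates[R > 0] <- -Inf
recommend <- apply(candidates, 1, which.max)
print(recommend)  # 3 3 3 3 1 4
```

For example, user 2 has not rated movies 2 and 3; their predicted ratings are about 2.29 and 3.22, so movie 3 is suggested.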