Added a method to enable bulk prediction

This commit is contained in:
Hossein Falaki 2014-01-03 15:34:16 -08:00
parent 0475ca8f81
commit 67f937ec22

View file

@ -20,7 +20,9 @@ package org.apache.spark.mllib.recommendation
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.jblas._
import java.nio.{ByteOrder, ByteBuffer}
/**
* Model representing the result of matrix factorization.
@ -44,6 +46,26 @@ class MatrixFactorizationModel(
userVector.dot(productVector)
}
// TODO: Figure out what good bulk prediction methods would look like.
/**
* Predict the rating of many users for many products.
* The output RDD has an element per each element in the input RDD (including all duplicates)
* unless a user or product is missing in the training set.
*
* @param usersProducts RDD of (user, product) pairs.
* @return RDD of Ratings.
*/
def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = {
val users = userFeatures.join(usersProducts).map{
case (user, (uFeatures, product)) => (product, (user, uFeatures))
}
users.join(productFeatures).map {
case (product, ((user, uFeatures), pFeatures)) =>
val userVector = new DoubleMatrix(uFeatures)
val productVector = new DoubleMatrix(pFeatures)
Rating(user, product, userVector.dot(productVector))
}
}
// TODO: Figure out what other good bulk prediction methods would look like.
// Probably want a way to get the top users for a product or vice-versa.
}