Added a method to enable bulk prediction
This commit is contained in:
parent
0475ca8f81
commit
67f937ec22
|
@ -20,7 +20,9 @@ package org.apache.spark.mllib.recommendation
|
|||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.SparkContext._
|
||||
|
||||
|
||||
import org.jblas._
|
||||
import java.nio.{ByteOrder, ByteBuffer}
|
||||
|
||||
/**
|
||||
* Model representing the result of matrix factorization.
|
||||
|
@ -44,6 +46,26 @@ class MatrixFactorizationModel(
|
|||
userVector.dot(productVector)
|
||||
}
|
||||
|
||||
// TODO: Figure out what good bulk prediction methods would look like.
|
||||
/**
|
||||
* Predict the rating of many users for many products.
|
||||
* The output RDD has an element per each element in the input RDD (including all duplicates)
|
||||
* unless a user or product is missing in the training set.
|
||||
*
|
||||
* @param usersProducts RDD of (user, product) pairs.
|
||||
* @return RDD of Ratings.
|
||||
*/
|
||||
def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = {
|
||||
val users = userFeatures.join(usersProducts).map{
|
||||
case (user, (uFeatures, product)) => (product, (user, uFeatures))
|
||||
}
|
||||
users.join(productFeatures).map {
|
||||
case (product, ((user, uFeatures), pFeatures)) =>
|
||||
val userVector = new DoubleMatrix(uFeatures)
|
||||
val productVector = new DoubleMatrix(pFeatures)
|
||||
Rating(user, product, userVector.dot(productVector))
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Figure out what other good bulk prediction methods would look like.
|
||||
// Probably want a way to get the top users for a product or vice-versa.
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue