Merge pull request #116 from jianpingjwang/master

remove unused variables and fix a bug

Commit 44e4205ac5
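The renames in `mapF2` make the fix easy to see: besides dropping the unused `org.apache.spark._` import and the dead `itmBias`/`usrBias` locals, the commit corrects the implicit-feedback update, whose regularization term shrank the user's accumulated vector `usr.v2` instead of the item's own y vector `itm.v2`:

```scala
// before: the shrinkage term wrongly acts on the user's aggregate usr.v2
val y = (q.mapMultiply(err*usr.norm)).subtract((usr.v2).mapMultiply(gamma7))
// after: shrink the item's own y vector itm.v2
val updateY = (q.mapMultiply(err*usr.norm)).subtract((itm.v2).mapMultiply(gamma7))
```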
@@ -1,6 +1,5 @@
 package org.apache.spark.graph.algorithms
 
-import org.apache.spark._
 import org.apache.spark.rdd._
 import org.apache.spark.graph._
 import scala.util.Random
@@ -10,7 +9,7 @@ class VT ( // vertex type
   var v1: RealVector, // v1: p for user node, q for item node
   var v2: RealVector, // v2: pu + |N(u)|^(-0.5)*sum(y) for user node, y for item node
   var bias: Double,
-  var norm: Double // only for user node
+  var norm: Double // |N(u)|^(-0.5) for user node
 ) extends Serializable
 
 class Msg ( // message
@@ -20,7 +19,15 @@ class Msg ( // message
 ) extends Serializable
 
 object Svdpp {
-  // implement SVD++ based on http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf
+  /**
+   * Implement SVD++ based on "Factorization Meets the Neighborhood: a Multifaceted Collaborative Filtering Model",
+   * paper is available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]].
+   * The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^(-0.5)*sum(y)), see the details on page 6.
+   *
+   * @param edges edges for constructing the graph
+   *
+   * @return a graph with vertex attributes containing the trained model
+   */
 
   def run(edges: RDD[Edge[Double]]): Graph[VT, Double] = {
     // default parameters
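To make the prediction rule concrete, here is a toy evaluation of rui at rank 2, assuming Apache Commons Math (which this file already uses for `RealVector`); all numbers are invented for illustration:

```scala
import org.apache.commons.math.linear.ArrayRealVector

val u = 3.5                                         // global mean rating
val (usrBias, itmBias) = (0.1, -0.2)                // bu and bi
val q = new ArrayRealVector(Array(0.3, 0.5))        // qi, the item factors
val v2 = new ArrayRealVector(Array(0.2, 0.4))       // pu + |N(u)|^(-0.5)*sum(y), i.e. usr.v2
val pred = u + usrBias + itmBias + q.dotProduct(v2) // 3.5 + 0.1 - 0.2 + 0.26 = 3.66
```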
@@ -33,6 +40,7 @@ object Svdpp {
     val gamma6 = 0.005
     val gamma7 = 0.015
 
+    // generate default vertex attribute
     def defaultF(rank: Int) = {
       val v1 = new ArrayRealVector(rank)
       val v2 = new ArrayRealVector(rank)
@@ -44,7 +52,7 @@ object Svdpp {
       vd
     }
 
-    // calculate initial norm and bias
+    // calculate initial bias and norm
     def mapF0(et: EdgeTriplet[VT, Double]): Iterator[(Vid, (Long, Double))] = {
       assert(et.srcAttr != null && et.dstAttr != null)
       Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr)))
@@ -67,7 +75,7 @@ object Svdpp {
     // make graph
     var g = Graph.fromEdges(edges, defaultF(rank)).cache()
 
-    // calculate initial norm and bias
+    // calculate initial bias and norm
     val t0 = g.mapReduceTriplets(mapF0, reduceF0)
     g.outerJoinVertices(t0) {updateF0}
 
@@ -89,20 +97,17 @@ object Svdpp {
     // phase 2
     def mapF2(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Msg)] = {
       assert(et.srcAttr != null && et.dstAttr != null)
-      val usr = et.srcAttr
-      val itm = et.dstAttr
-      var p = usr.v1
-      var q = itm.v1
-      val itmBias = 0.0
-      val usrBias = 0.0
+      val (usr, itm) = (et.srcAttr, et.dstAttr)
+      val (p, q) = (usr.v1, itm.v1)
       var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
       pred = math.max(pred, minVal)
       pred = math.min(pred, maxVal)
       val err = et.attr - pred
-      val y = (q.mapMultiply(err*usr.norm)).subtract((usr.v2).mapMultiply(gamma7))
-      val newP = (q.mapMultiply(err)).subtract(p.mapMultiply(gamma7)) // for each connected item q
-      val newQ = (usr.v2.mapMultiply(err)).subtract(q.mapMultiply(gamma7))
-      Iterator((et.srcId, new Msg(newP, y, err - gamma6*usr.bias)), (et.dstId, new Msg(newQ, y, err - gamma6*itm.bias)))
+      val updateP = (q.mapMultiply(err)).subtract(p.mapMultiply(gamma7))
+      val updateQ = (usr.v2.mapMultiply(err)).subtract(q.mapMultiply(gamma7))
+      val updateY = (q.mapMultiply(err*usr.norm)).subtract((itm.v2).mapMultiply(gamma7))
+      Iterator((et.srcId, new Msg(updateP, updateY, err - gamma6*usr.bias)),
+        (et.dstId, new Msg(updateQ, updateY, err - gamma6*itm.bias)))
     }
     def reduceF2(g1: Msg, g2: Msg): Msg = {
       g1.v1 = g1.v1.add(g2.v1)
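Here `updateP`, `updateQ`, and `updateY` are the bracketed gradients of the paper's SGD rules on page 6 (e.g. pu ← pu + γ2·(eui·qi − λ7·pu)), with `gamma7` playing λ7; `mapF2` only emits the gradient, and `updateF2` below applies the γ2 (or γ1 for the bias) learning rate. A minimal sketch of the resulting step for the user factors, under those assumptions:

```scala
import org.apache.commons.math.linear.RealVector

// One SGD step for pu: mapF2 computes q*err - gamma7*p,
// then updateF2 scales the message by gamma2 and adds it.
def stepP(p: RealVector, q: RealVector, err: Double,
          gamma2: Double, gamma7: Double): RealVector =
  p.add(q.mapMultiply(err).subtract(p.mapMultiply(gamma7)).mapMultiply(gamma2))
```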
@@ -113,7 +118,7 @@ object Svdpp {
     def updateF2(vid: Vid, vd: VT, msg: Option[Msg]) = {
       if (msg.isDefined) {
         vd.v1 = vd.v1.add(msg.get.v1.mapMultiply(gamma2))
-        if (vid % 2 == 1) { // item node update y
+        if (vid % 2 == 1) { // item nodes update y
           vd.v2 = vd.v2.add(msg.get.v2.mapMultiply(gamma2))
         }
         vd.bias += msg.get.bias*gamma1
@@ -122,10 +127,10 @@ object Svdpp {
     }
 
     for (i <- 0 until maxIters) {
-      // phase 1
+      // phase 1, calculate v2 for user nodes
       val t1: VertexRDD[RealVector] = g.mapReduceTriplets(mapF1, reduceF1)
       g.outerJoinVertices(t1) {updateF1}
-      // phase 2
+      // phase 2, update p for user nodes and q, y for item nodes
       val t2: VertexRDD[Msg] = g.mapReduceTriplets(mapF2, reduceF2)
       g.outerJoinVertices(t2) {updateF2}
     }
@@ -133,12 +138,8 @@ object Svdpp {
     // calculate error on training set
     def mapF3(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Double)] = {
       assert(et.srcAttr != null && et.dstAttr != null)
-      val usr = et.srcAttr
-      val itm = et.dstAttr
-      var p = usr.v1
-      var q = itm.v1
-      val itmBias = 0.0
-      val usrBias = 0.0
+      val (usr, itm) = (et.srcAttr, et.dstAttr)
+      val (p, q) = (usr.v1, itm.v1)
       var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
       pred = math.max(pred, minVal)
       pred = math.min(pred, maxVal)
@@ -146,7 +147,7 @@ object Svdpp {
       Iterator((et.dstId, err))
     }
     def updateF3(vid: Vid, vd: VT, msg: Option[Double]) = {
-      if (msg.isDefined && vid % 2 == 1) { // item sum up the errors
+      if (msg.isDefined && vid % 2 == 1) { // item nodes sum up the errors
        vd.norm = msg.get
       }
       vd
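A hypothetical end-to-end usage sketch; the file name, the space-separated rating format, and the even-user/odd-item id encoding are assumptions (the last matches the `vid % 2` convention used throughout):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.graph._
import org.apache.spark.graph.algorithms.Svdpp

val sc = new SparkContext("local", "SvdppExample")
// Each input line: "userId itemId rating"; users get even vertex ids, items odd ones.
val edges = sc.textFile("ratings.txt").map { line =>
  val fields = line.split(" ")
  Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
}
val model = Svdpp.run(edges) // graph whose vertex attributes hold the trained factors
```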