diff --git a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
index 4ddf0b1fd5..26b999f4cf 100644
--- a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
@@ -1,6 +1,5 @@
 package org.apache.spark.graph.algorithms
 
-import org.apache.spark._
 import org.apache.spark.rdd._
 import org.apache.spark.graph._
 import scala.util.Random
@@ -10,7 +9,7 @@ class VT ( // vertex type
   var v1: RealVector, // v1: p for user node, q for item node
   var v2: RealVector, // v2: pu + |N(u)|^(-0.5)*sum(y) for user node, y for item node
   var bias: Double,
-  var norm: Double // only for user node
+  var norm: Double // |N(u)|^(-0.5) for user node
 ) extends Serializable
 
 class Msg ( // message
@@ -20,7 +19,15 @@ class Msg ( // message
 ) extends Serializable
 
 object Svdpp {
-  // implement SVD++ based on http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf
+  /**
+   * Implement SVD++ based on "Factorization Meets the Neighborhood: a Multifaceted Collaborative Filtering Model",
+   * paper is available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]].
+   * The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^(-0.5)*sum(y)), see the details on page 6.
+   *
+   * @param edges edges for constructing the graph
+   *
+   * @return a graph with vertex attributes containing the trained model
+   */
   def run(edges: RDD[Edge[Double]]): Graph[VT, Double] = {
 
     // defalut parameters
@@ -32,7 +39,8 @@ object Svdpp {
     val gamma2 = 0.007
     val gamma6 = 0.005
     val gamma7 = 0.015
-
+
+    // generate default vertex attribute
     def defaultF(rank: Int) = {
       val v1 = new ArrayRealVector(rank)
       val v2 = new ArrayRealVector(rank)
@@ -44,7 +52,7 @@ object Svdpp {
       vd
     }
 
-    // calculate initial norm and bias
+    // calculate initial bias and norm
     def mapF0(et: EdgeTriplet[VT, Double]): Iterator[(Vid, (Long, Double))] = {
       assert(et.srcAttr != null && et.dstAttr != null)
       Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr)))
@@ -67,10 +75,10 @@ object Svdpp {
     // make graph
     var g = Graph.fromEdges(edges, defaultF(rank)).cache()
 
-    // calculate initial norm and bias
+    // calculate initial bias and norm
     val t0 = g.mapReduceTriplets(mapF0, reduceF0)
-    g.outerJoinVertices(t0) {updateF0}
-
+    g.outerJoinVertices(t0) {updateF0} 
+
     // phase 1
     def mapF1(et: EdgeTriplet[VT, Double]): Iterator[(Vid, RealVector)] = {
       assert(et.srcAttr != null && et.dstAttr != null)
@@ -89,21 +97,18 @@ object Svdpp {
     // phase 2
     def mapF2(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Msg)] = {
       assert(et.srcAttr != null && et.dstAttr != null)
-      val usr = et.srcAttr
-      val itm = et.dstAttr
-      var p = usr.v1
-      var q = itm.v1
-      val itmBias = 0.0
-      val usrBias = 0.0
+      val (usr, itm) = (et.srcAttr, et.dstAttr)
+      val (p, q) = (usr.v1, itm.v1)
       var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
       pred = math.max(pred, minVal)
       pred = math.min(pred, maxVal)
       val err = et.attr - pred
-      val y = (q.mapMultiply(err*usr.norm)).subtract((usr.v2).mapMultiply(gamma7))
-      val newP = (q.mapMultiply(err)).subtract(p.mapMultiply(gamma7)) // for each connected item q
-      val newQ = (usr.v2.mapMultiply(err)).subtract(q.mapMultiply(gamma7))
-      Iterator((et.srcId, new Msg(newP, y, err - gamma6*usr.bias)), (et.dstId, new Msg(newQ, y, err - gamma6*itm.bias)))
-    }
+      val updateP = (q.mapMultiply(err)).subtract(p.mapMultiply(gamma7))
+      val updateQ = (usr.v2.mapMultiply(err)).subtract(q.mapMultiply(gamma7))
+      val updateY = (q.mapMultiply(err*usr.norm)).subtract((itm.v2).mapMultiply(gamma7))
+      Iterator((et.srcId, new Msg(updateP, updateY, err - gamma6*usr.bias)),
+        (et.dstId, new Msg(updateQ, updateY, err - gamma6*itm.bias)))
+    }
     def reduceF2(g1: Msg, g2: Msg):Msg = {
       g1.v1 = g1.v1.add(g2.v1)
       g1.v2 = g1.v2.add(g2.v2)
@@ -113,7 +118,7 @@ object Svdpp {
     def updateF2(vid: Vid, vd: VT, msg: Option[Msg]) = {
       if (msg.isDefined) {
         vd.v1 = vd.v1.add(msg.get.v1.mapMultiply(gamma2))
-        if (vid % 2 == 1) { // item node update y
+        if (vid % 2 == 1) { // item nodes update y
           vd.v2 = vd.v2.add(msg.get.v2.mapMultiply(gamma2))
         }
         vd.bias += msg.get.bias*gamma1
@@ -122,23 +127,19 @@ object Svdpp {
     }
 
     for (i <- 0 until maxIters) {
-      // phase 1
+      // phase 1, calculate v2 for user nodes
       val t1: VertexRDD[RealVector] = g.mapReduceTriplets(mapF1, reduceF1)
-      g.outerJoinVertices(t1) {updateF1}
-      // phase 2
+      g.outerJoinVertices(t1) {updateF1} 
+      // phase 2, update p for user nodes and q, y for item nodes
       val t2: VertexRDD[Msg] = g.mapReduceTriplets(mapF2, reduceF2)
       g.outerJoinVertices(t2) {updateF2}
     }
-
+
     // calculate error on training set
     def mapF3(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Double)] = {
       assert(et.srcAttr != null && et.dstAttr != null)
-      val usr = et.srcAttr
-      val itm = et.dstAttr
-      var p = usr.v1
-      var q = itm.v1
-      val itmBias = 0.0
-      val usrBias = 0.0
+      val (usr, itm) = (et.srcAttr, et.dstAttr)
+      val (p, q) = (usr.v1, itm.v1)
       var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
       pred = math.max(pred, minVal)
       pred = math.min(pred, maxVal)
@@ -146,7 +147,7 @@ object Svdpp {
       Iterator((et.dstId, err))
     }
     def updateF3(vid: Vid, vd: VT, msg: Option[Double]) = {
-      if (msg.isDefined && vid % 2 == 1) { // item sum up the errors
+      if (msg.isDefined && vid % 2 == 1) { // item nodes sum up the errors
         vd.norm = msg.get
       }
       vd
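For review context, the prediction rule documented in the new scaladoc can be read as a standalone function. The sketch below is illustrative only and not part of the patch: the object and method names are made up, and it assumes the commons-math RealVector type the file already uses.

// Minimal sketch of the documented prediction rule
//   rui = u + bu + bi + qi*(pu + |N(u)|^(-0.5)*sum(y))
// The names SvdppPredictSketch and predict are hypothetical.
import org.apache.commons.math.linear.RealVector

object SvdppPredictSketch {
  // For a user node, VT.v2 already accumulates pu + |N(u)|^(-0.5)*sum(y)
  // during phase 1, so prediction reduces to the biases plus a single
  // dot product, clamped to the rating range just as mapF2 and mapF3 do.
  def predict(u: Double, usrBias: Double, itmBias: Double,
      q: RealVector, usrV2: RealVector,
      minVal: Double, maxVal: Double): Double = {
    val raw = u + usrBias + itmBias + q.dotProduct(usrV2)
    math.min(math.max(raw, minVal), maxVal)  // clamp into [minVal, maxVal]
  }
}

Because the neighborhood sum is folded into the user vertex attribute ahead of time, neither mapF2 nor mapF3 has to materialize sum(y) at prediction time; this is the design the phase 1 comment ("calculate v2 for user nodes") refers to.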