commit
e68cdb1b82
|
@ -5,18 +5,15 @@ import org.apache.spark.graph._
|
|||
import scala.util.Random
|
||||
import org.apache.commons.math.linear._
|
||||
|
||||
class VT ( // vertex type
|
||||
var v1: RealVector, // v1: p for user node, q for item node
|
||||
var v2: RealVector, // v2: pu + |N(u)|^(-0.5)*sum(y) for user node, y for item node
|
||||
var bias: Double,
|
||||
var norm: Double // |N(u)|^(-0.5) for user node
|
||||
) extends Serializable
|
||||
|
||||
class Msg ( // message
|
||||
var v1: RealVector,
|
||||
var v2: RealVector,
|
||||
var bias: Double
|
||||
) extends Serializable
|
||||
class SvdppConf( // Svdpp parameters
|
||||
var rank: Int,
|
||||
var maxIters: Int,
|
||||
var minVal: Double,
|
||||
var maxVal: Double,
|
||||
var gamma1: Double,
|
||||
var gamma2: Double,
|
||||
var gamma6: Double,
|
||||
var gamma7: Double) extends Serializable
|
||||
|
||||
object Svdpp {
|
||||
/**
|
||||
|
@ -24,136 +21,83 @@ object Svdpp {
|
|||
* paper is available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]].
|
||||
* The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^(-0.5)*sum(y)), see the details on page 6.
|
||||
*
|
||||
* @param edges edges for constructing the graph
|
||||
* @param edges edges for constructing the graph
|
||||
*
|
||||
* @param conf Svdpp parameters
|
||||
*
|
||||
* @return a graph with vertex attributes containing the trained model
|
||||
*/
|
||||
|
||||
def run(edges: RDD[Edge[Double]]): Graph[VT, Double] = {
|
||||
// defalut parameters
|
||||
val rank = 10
|
||||
val maxIters = 20
|
||||
val minVal = 0.0
|
||||
val maxVal = 5.0
|
||||
val gamma1 = 0.007
|
||||
val gamma2 = 0.007
|
||||
val gamma6 = 0.005
|
||||
val gamma7 = 0.015
|
||||
def run(edges: RDD[Edge[Double]], conf: SvdppConf): (Graph[(RealVector, RealVector, Double, Double), Double], Double) = {
|
||||
|
||||
// generate default vertex attribute
|
||||
def defaultF(rank: Int) = {
|
||||
def defaultF(rank: Int): (RealVector, RealVector, Double, Double) = {
|
||||
val v1 = new ArrayRealVector(rank)
|
||||
val v2 = new ArrayRealVector(rank)
|
||||
for (i <- 0 until rank) {
|
||||
v1.setEntry(i, Random.nextDouble)
|
||||
v2.setEntry(i, Random.nextDouble)
|
||||
}
|
||||
var vd = new VT(v1, v2, 0.0, 0.0)
|
||||
vd
|
||||
}
|
||||
|
||||
// calculate initial bias and norm
|
||||
def mapF0(et: EdgeTriplet[VT, Double]): Iterator[(Vid, (Long, Double))] = {
|
||||
assert(et.srcAttr != null && et.dstAttr != null)
|
||||
Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr)))
|
||||
}
|
||||
def reduceF0(g1: (Long, Double), g2: (Long, Double)) = {
|
||||
(g1._1 + g2._1, g1._2 + g2._2)
|
||||
}
|
||||
def updateF0(vid: Vid, vd: VT, msg: Option[(Long, Double)]) = {
|
||||
if (msg.isDefined) {
|
||||
vd.bias = msg.get._2 / msg.get._1
|
||||
vd.norm = 1.0 / scala.math.sqrt(msg.get._1)
|
||||
}
|
||||
vd
|
||||
(v1, v2, 0.0, 0.0)
|
||||
}
|
||||
|
||||
// calculate global rating mean
|
||||
val (rs, rc) = edges.map(e => (e.attr, 1L)).reduce((a, b) => (a._1 + b._1, a._2 + b._2))
|
||||
val u = rs / rc // global rating mean
|
||||
val u = rs / rc
|
||||
|
||||
// make graph
|
||||
var g = Graph.fromEdges(edges, defaultF(rank)).cache()
|
||||
// construct graph
|
||||
var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache()
|
||||
|
||||
// calculate initial bias and norm
|
||||
val t0 = g.mapReduceTriplets(mapF0, reduceF0)
|
||||
g.outerJoinVertices(t0) {updateF0}
|
||||
|
||||
// phase 1
|
||||
def mapF1(et: EdgeTriplet[VT, Double]): Iterator[(Vid, RealVector)] = {
|
||||
assert(et.srcAttr != null && et.dstAttr != null)
|
||||
Iterator((et.srcId, et.dstAttr.v2)) // sum up y of connected item nodes
|
||||
}
|
||||
def reduceF1(g1: RealVector, g2: RealVector) = {
|
||||
g1.add(g2)
|
||||
}
|
||||
def updateF1(vid: Vid, vd: VT, msg: Option[RealVector]) = {
|
||||
if (msg.isDefined) {
|
||||
vd.v2 = vd.v1.add(msg.get.mapMultiply(vd.norm)) // pu + |N(u)|^(-0.5)*sum(y)
|
||||
}
|
||||
vd
|
||||
var t0 = g.mapReduceTriplets(et =>
|
||||
Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))), (g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2))
|
||||
g = g.outerJoinVertices(t0) { (vid: Vid, vd: (RealVector, RealVector, Double, Double), msg: Option[(Long, Double)]) =>
|
||||
(vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
|
||||
}
|
||||
|
||||
// phase 2
|
||||
def mapF2(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Msg)] = {
|
||||
assert(et.srcAttr != null && et.dstAttr != null)
|
||||
def mapTrainF(conf: SvdppConf, u: Double)(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double])
|
||||
: Iterator[(Vid, (RealVector, RealVector, Double))] = {
|
||||
val (usr, itm) = (et.srcAttr, et.dstAttr)
|
||||
val (p, q) = (usr.v1, itm.v1)
|
||||
var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
|
||||
pred = math.max(pred, minVal)
|
||||
pred = math.min(pred, maxVal)
|
||||
val (p, q) = (usr._1, itm._1)
|
||||
var pred = u + usr._3 + itm._3 + q.dotProduct(usr._2)
|
||||
pred = math.max(pred, conf.minVal)
|
||||
pred = math.min(pred, conf.maxVal)
|
||||
val err = et.attr - pred
|
||||
val updateP = (q.mapMultiply(err)).subtract(p.mapMultiply(gamma7))
|
||||
val updateQ = (usr.v2.mapMultiply(err)).subtract(q.mapMultiply(gamma7))
|
||||
val updateY = (q.mapMultiply(err*usr.norm)).subtract((itm.v2).mapMultiply(gamma7))
|
||||
Iterator((et.srcId, new Msg(updateP, updateY, err - gamma6*usr.bias)),
|
||||
(et.dstId, new Msg(updateQ, updateY, err - gamma6*itm.bias)))
|
||||
}
|
||||
def reduceF2(g1: Msg, g2: Msg):Msg = {
|
||||
g1.v1 = g1.v1.add(g2.v1)
|
||||
g1.v2 = g1.v2.add(g2.v2)
|
||||
g1.bias += g2.bias
|
||||
g1
|
||||
}
|
||||
def updateF2(vid: Vid, vd: VT, msg: Option[Msg]) = {
|
||||
if (msg.isDefined) {
|
||||
vd.v1 = vd.v1.add(msg.get.v1.mapMultiply(gamma2))
|
||||
if (vid % 2 == 1) { // item nodes update y
|
||||
vd.v2 = vd.v2.add(msg.get.v2.mapMultiply(gamma2))
|
||||
}
|
||||
vd.bias += msg.get.bias*gamma1
|
||||
}
|
||||
vd
|
||||
val updateP = ((q.mapMultiply(err)).subtract(p.mapMultiply(conf.gamma7))).mapMultiply(conf.gamma2)
|
||||
val updateQ = ((usr._2.mapMultiply(err)).subtract(q.mapMultiply(conf.gamma7))).mapMultiply(conf.gamma2)
|
||||
val updateY = ((q.mapMultiply(err * usr._4)).subtract((itm._2).mapMultiply(conf.gamma7))).mapMultiply(conf.gamma2)
|
||||
Iterator((et.srcId, (updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)),
|
||||
(et.dstId, (updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1)))
|
||||
}
|
||||
|
||||
for (i <- 0 until maxIters) {
|
||||
// phase 1, calculate v2 for user nodes
|
||||
val t1: VertexRDD[RealVector] = g.mapReduceTriplets(mapF1, reduceF1)
|
||||
g.outerJoinVertices(t1) {updateF1}
|
||||
for (i <- 0 until conf.maxIters) {
|
||||
// phase 1, calculate pu + |N(u)|^(-0.5)*sum(y) for user nodes
|
||||
var t1 = g.mapReduceTriplets(et => Iterator((et.srcId, et.dstAttr._2)), (g1: RealVector, g2: RealVector) => g1.add(g2))
|
||||
g = g.outerJoinVertices(t1) { (vid: Vid, vd: (RealVector, RealVector, Double, Double), msg: Option[RealVector]) =>
|
||||
if (msg.isDefined) (vd._1, vd._1.add(msg.get.mapMultiply(vd._4)), vd._3, vd._4) else vd
|
||||
}
|
||||
// phase 2, update p for user nodes and q, y for item nodes
|
||||
val t2: VertexRDD[Msg] = g.mapReduceTriplets(mapF2, reduceF2)
|
||||
g.outerJoinVertices(t2) {updateF2}
|
||||
val t2 = g.mapReduceTriplets(mapTrainF(conf, u), (g1: (RealVector, RealVector, Double), g2: (RealVector, RealVector, Double)) =>
|
||||
(g1._1.add(g2._1), g1._2.add(g2._2), g1._3 + g2._3))
|
||||
g = g.outerJoinVertices(t2) { (vid: Vid, vd: (RealVector, RealVector, Double, Double), msg: Option[(RealVector, RealVector, Double)]) =>
|
||||
(vd._1.add(msg.get._1), vd._2.add(msg.get._2), vd._3 + msg.get._3, vd._4)
|
||||
}
|
||||
}
|
||||
|
||||
// calculate error on training set
|
||||
def mapF3(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Double)] = {
|
||||
assert(et.srcAttr != null && et.dstAttr != null)
|
||||
def mapTestF(conf: SvdppConf, u: Double)(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]): Iterator[(Vid, Double)] = {
|
||||
val (usr, itm) = (et.srcAttr, et.dstAttr)
|
||||
val (p, q) = (usr.v1, itm.v1)
|
||||
var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
|
||||
pred = math.max(pred, minVal)
|
||||
pred = math.min(pred, maxVal)
|
||||
val err = (et.attr - pred)*(et.attr - pred)
|
||||
val (p, q) = (usr._1, itm._1)
|
||||
var pred = u + usr._3 + itm._3 + q.dotProduct(usr._2)
|
||||
pred = math.max(pred, conf.minVal)
|
||||
pred = math.min(pred, conf.maxVal)
|
||||
val err = (et.attr - pred) * (et.attr - pred)
|
||||
Iterator((et.dstId, err))
|
||||
}
|
||||
def updateF3(vid: Vid, vd: VT, msg: Option[Double]) = {
|
||||
if (msg.isDefined && vid % 2 == 1) { // item nodes sum up the errors
|
||||
vd.norm = msg.get
|
||||
}
|
||||
vd
|
||||
val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2)
|
||||
g = g.outerJoinVertices(t3) { (vid: Vid, vd: (RealVector, RealVector, Double, Double), msg: Option[Double]) =>
|
||||
if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
|
||||
}
|
||||
val t3: VertexRDD[Double] = g.mapReduceTriplets(mapF3, _ + _)
|
||||
g.outerJoinVertices(t3) {updateF3}
|
||||
g
|
||||
(g, u)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,16 +13,17 @@ class SvdppSuite extends FunSuite with LocalSparkContext {
|
|||
|
||||
test("Test SVD++ with mean square error on training set") {
|
||||
withSpark { sc =>
|
||||
val SvdppErr = 0.01
|
||||
val SvdppErr = 8.0
|
||||
val edges = sc.textFile("mllib/data/als/test.data").map { line =>
|
||||
val fields = line.split(",")
|
||||
Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
|
||||
}
|
||||
val graph = Svdpp.run(edges)
|
||||
val conf = new SvdppConf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations
|
||||
var (graph, u) = Svdpp.run(edges, conf)
|
||||
val err = graph.vertices.collect.map{ case (vid, vd) =>
|
||||
if (vid % 2 == 1) { vd.norm } else { 0.0 }
|
||||
if (vid % 2 == 1) vd._4 else 0.0
|
||||
}.reduce(_ + _) / graph.triplets.collect.size
|
||||
assert(err < SvdppErr)
|
||||
assert(err <= SvdppErr)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue