Refactored NaiveBayes

* Minimized shuffle output with mapPartitions (a generic sketch of the per-partition aggregation pattern follows the commit metadata below).
* Reduced RDD actions from 3 to 1 (see the note after the NaiveBayes.scala diff).
Lian, Cheng 2013-12-25 17:15:38 +08:00
parent 3dc655aa19
commit 3bb714eaa3
2 changed files with 41 additions and 28 deletions
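
The shuffle saving works by pre-aggregating inside each partition before anything crosses the network: groupByKey ships every (label, features) record to a reducer, whereas mapPartitions followed by reduceByKey ships at most one (count, summed features) pair per label per partition. Below is a generic, self-contained sketch of that pattern against the Spark API of this era; the object name and toy data are illustrative only, not part of the commit.

import scala.collection.mutable

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._  // brings in the pair-RDD operations (pre-1.0 style)

object LocalAggregationSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "local-aggregation-sketch")
    val data = sc.parallelize(Seq(0 -> 1.0, 0 -> 2.0, 1 -> 3.0), 2)

    // Pre-aggregate per partition: at most one (count, sum) pair per key
    // leaves each partition, instead of every individual record.
    val perPartition = data.mapPartitions { iterator =>
      val acc = mutable.Map.empty[Int, (Int, Double)].withDefaultValue((0, 0.0))
      for ((key, value) <- iterator) {
        val (count, sum) = acc(key)
        acc(key) = (count + 1, sum + value)
      }
      acc.iterator
    }

    // The shuffle merges only the small per-partition summaries.
    val totals = perPartition.reduceByKey { (lhs, rhs) =>
      (lhs._1 + rhs._1, lhs._2 + rhs._2)
    }
    totals.collect().foreach(println)
    sc.stop()
  }
}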

NaiveBayes.scala

@@ -48,11 +48,12 @@ class NaiveBayesModel(val weightPerLabel: Array[Double],
   }
 }
 
 class NaiveBayes private (val lambda: Double = 1.0)  // smoothing parameter
   extends Serializable with Logging {
+  private[this] def vectorAdd(v1: Array[Double], v2: Array[Double]) =
+    v1.zip(v2).map(pair => pair._1 + pair._2)
+
   /**
    * Run the algorithm with the configured parameters on an input
    * RDD of LabeledPoint entries.
@@ -61,29 +62,42 @@ class NaiveBayes private (val lambda: Double = 1.0)  // smoothing parameter
    * @param D dimension of feature vectors
    * @param data RDD of (label, array of features) pairs.
    */
-  def run(C: Int, D: Int, data: RDD[LabeledPoint]): NaiveBayesModel = {
-    val groupedData = data.map(p => p.label.toInt -> p.features).groupByKey()
-
-    val countPerLabel = groupedData.mapValues(_.size)
-    val logDenominator = math.log(data.count() + C * lambda)
-    val weightPerLabel = countPerLabel.mapValues {
-      count => math.log(count + lambda) - logDenominator
+  def run(C: Int, D: Int, data: RDD[LabeledPoint]) = {
+    val locallyReduced = data.mapPartitions { iterator =>
+      val localLabelCounts = mutable.Map.empty[Int, Int].withDefaultValue(0)
+      val localSummedObservations =
+        mutable.Map.empty[Int, Array[Double]].withDefaultValue(Array.fill(D)(0.0))
+
+      for (LabeledPoint(label, features) <- iterator; i = label.toInt) {
+        localLabelCounts(i) += 1
+        localSummedObservations(i) = vectorAdd(localSummedObservations(i), features)
+      }
+
+      for ((label, count) <- localLabelCounts.toIterator) yield {
+        label -> (count, localSummedObservations(label))
+      }
     }
 
-    val summedObservations = groupedData.mapValues(_.reduce {
-      (lhs, rhs) => lhs.zip(rhs).map(pair => pair._1 + pair._2)
-    })
-
-    val weightsMatrix = summedObservations.mapValues { weights =>
-      val sum = weights.sum
-      val logDenom = math.log(sum + D * lambda)
-      weights.map(w => math.log(w + lambda) - logDenom)
+    val reduced = locallyReduced.reduceByKey { (lhs, rhs) =>
+      (lhs._1 + rhs._1, vectorAdd(lhs._2, rhs._2))
     }
 
-    val labelWeights = weightPerLabel.collect().sorted.map(_._2)
-    val weightsMat = weightsMatrix.collect().sortBy(_._1).map(_._2)
-
-    new NaiveBayesModel(labelWeights, weightsMat)
+    val collected = reduced.mapValues { case (count, summed) =>
+      val labelWeight = math.log(count + lambda)
+      val logDenom = math.log(summed.sum + D * lambda)
+      val weights = summed.map(w => math.log(w + lambda) - logDenom)
+      (count, labelWeight, weights)
+    }.collectAsMap()
+
+    val weightPerLabel = {
+      val N = collected.values.map(_._1).sum
+      val logDenom = math.log(N + C * lambda)
+      collected.mapValues(_._2 - logDenom).toArray.sortBy(_._1).map(_._2)
+    }
+
+    val weightMatrix = collected.mapValues(_._3).toArray.sortBy(_._1).map(_._2)
+
+    new NaiveBayesModel(weightPerLabel, weightMatrix)
   }
 }
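
The action count from the commit message can be read off the hunk above: the old run body launched three jobs (data.count() plus collect() on both weightPerLabel and weightsMatrix), while the new body launches exactly one, collectAsMap(), and recovers the total observation count N on the driver by summing the per-label counts it already holds. A minimal standalone sketch of that trick follows; the names are hypothetical and this is not the commit's code.

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._

object OneActionSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "one-action-sketch")
    val labeled = sc.parallelize(Seq(0 -> 1.0, 1 -> 2.0, 1 -> 4.0))

    val perLabel = labeled
      .mapValues(v => (1, v))
      .reduceByKey((a, b) => (a._1 + b._1, a._2 + b._2))
      .collectAsMap()  // the only RDD action in the whole computation

    // The global total needs no second job such as labeled.count().
    val n = perLabel.values.map(_._1).sum
    println("N = " + n + ", per label = " + perLabel)
    sc.stop()
  }
}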

NaiveBayesSuite.scala

@@ -1,6 +1,5 @@
 package org.apache.spark.mllib.classification
 
-import scala.collection.JavaConversions._
 import scala.util.Random
 
 import org.scalatest.BeforeAndAfterAll
@@ -56,12 +55,12 @@ class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
   }
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
-    val numOffPredictions = predictions.zip(input).count {
+    val numOfPredictions = predictions.zip(input).count {
       case (prediction, expected) =>
        prediction != expected.label
     }
     // At least 80% of the predictions should be on.
-    assert(numOffPredictions < input.length / 5)
+    assert(numOfPredictions < input.length / 5)
   }
 
   test("Naive Bayes") {
@@ -71,8 +70,8 @@ class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
     val weightsMatrix = Array(
       Array(math.log(0.91), math.log(0.03), math.log(0.03), math.log(0.03)), // label 0
       Array(math.log(0.03), math.log(0.91), math.log(0.03), math.log(0.03)), // label 1
-      Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03)) // label 2
-    )
+      Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03))  // label 2
+    )
 
     val testData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 42)
     val testRDD = sc.parallelize(testData, 2)
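
The hunk ends here, but the test presumably continues by training on testRDD and scoring the model with validatePrediction. A hedged sketch of that continuation follows; the train and predict signatures are assumptions inferred from run(C, D, data) above, not something shown in this diff.

// Assumed continuation of the test body. NaiveBayes.train is taken to be a
// companion helper around new NaiveBayes().run(C, D, data), and
// model.predict(features) is taken to return the most likely label.
val model = NaiveBayes.train(3, 4, testRDD)

val validationData =
  NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 17)

// validatePrediction above tolerates at most 20% mispredicted labels.
validatePrediction(validationData.map(p => model.predict(p.features)), validationData)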