[SPARK-7752] [MLLIB] Use lowercase letters for NaiveBayes.modelType

This makes the model type names consistent with other string names in MLlib. The PR also updates the implementation to use vals instead of hardcoded strings. cc jkbradley leahmcguire
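For context, a minimal sketch of what the rename means for callers (a sketch, not part of the patch; `training` is a placeholder RDD[LabeledPoint], and the method is the existing NaiveBayes.train(input, lambda, modelType)):

import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

def fitBoth(training: RDD[LabeledPoint]): Unit = {
  // Before this change the capitalized names were used:
  //   NaiveBayes.train(training, 1.0, "Multinomial")
  // After this change the lowercase names are the supported values:
  val multinomialModel = NaiveBayes.train(training, 1.0, "multinomial")
  val bernoulliModel = NaiveBayes.train(training, 1.0, "bernoulli")
}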

Author: Xiangrui Meng <meng@databricks.com>

Closes #6277 from mengxr/SPARK-7752 and squashes the following commits:

f38b662 [Xiangrui Meng] add another case _ back in test
ae5c66a [Xiangrui Meng] model type -> modelType
711d1c6 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7752
40ae53e [Xiangrui Meng] fix Java test suite
264a814 [Xiangrui Meng] add case _ back
3c456a8 [Xiangrui Meng] update NB user guide
17bba53 [Xiangrui Meng] update naive Bayes to use lowercase model type strings
Authored by Xiangrui Meng on 2015-05-21 10:30:08 -07:00; committed by Joseph K. Bradley
parent a25c1ab8f0
commit 13348e21b6
4 changed files with 75 additions and 59 deletions

Changed file: docs/mllib-naive-bayes.md (MLlib naive Bayes user guide)

@@ -21,7 +21,7 @@ Within that context, each observation is a document and each
 feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or
 a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes).
 Feature values must be nonnegative. The model type is selected with an optional parameter
-"Multinomial" or "Bernoulli" with "Multinomial" as the default.
+"multinomial" or "bernoulli" with "multinomial" as the default.
 [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by
 setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature
 vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of
@@ -35,7 +35,7 @@ sparsity. Since the training data is only used once, it is not necessary to cach
 [NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements
 multinomial naive Bayes. It takes an RDD of
 [LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional
-smoothing parameter `lambda` as input, an optional model type parameter (default is Multinomial), and outputs a
+smoothing parameter `lambda` as input, an optional model type parameter (default is "multinomial"), and outputs a
 [NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which
 can be used for evaluation and prediction.
@@ -54,7 +54,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
 val training = splits(0)
 val test = splits(1)
 
-val model = NaiveBayes.train(training, lambda = 1.0, model = "Multinomial")
+val model = NaiveBayes.train(training, lambda = 1.0, model = "multinomial")
 
 val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
 val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
@@ -75,6 +75,8 @@ optionally smoothing parameter `lambda` as input, and output a
 can be used for evaluation and prediction.
 
 {% highlight java %}
+import scala.Tuple2;
+
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
@@ -82,7 +84,6 @@ import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.mllib.classification.NaiveBayes;
 import org.apache.spark.mllib.classification.NaiveBayesModel;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import scala.Tuple2;
 
 JavaRDD<LabeledPoint> training = ... // training set
 JavaRDD<LabeledPoint> test = ... // test set
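To make the two feature encodings described in the guide text above concrete, a small illustrative sketch (the vocabulary and values are made up for illustration; Vectors is the existing MLlib factory):

import org.apache.spark.mllib.linalg.Vectors

// The document "spark fast fast" over a made-up vocabulary [spark, fast, slow]:
val multinomialFeatures = Vectors.dense(1.0, 2.0, 0.0) // term frequencies ("multinomial")
val bernoulliFeatures = Vectors.dense(1.0, 1.0, 0.0)   // 0/1 presence indicators ("bernoulli")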

Changed file: mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

@@ -25,13 +25,12 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.{Logging, SparkContext, SparkException}
-import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SQLContext}
 
 /**
  * Model for Naive Bayes Classifiers.
  *
@@ -39,7 +38,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
  * @param pi log of class priors, whose dimension is C, number of labels
  * @param theta log of class conditional probabilities, whose dimension is C-by-D,
  *              where D is number of features
- * @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli"
+ * @param modelType The type of NB model to fit can be "multinomial" or "bernoulli"
  */
 class NaiveBayesModel private[mllib] (
     val labels: Array[Double],
@@ -48,11 +47,13 @@ class NaiveBayesModel private[mllib] (
     val modelType: String)
   extends ClassificationModel with Serializable with Saveable {
 
+  import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes}
+
   private val piVector = new DenseVector(pi)
-  private val thetaMatrix = new DenseMatrix(labels.size, theta(0).size, theta.flatten, true)
+  private val thetaMatrix = new DenseMatrix(labels.length, theta(0).length, theta.flatten, true)
 
   private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
-    this(labels, pi, theta, "Multinomial")
+    this(labels, pi, theta, NaiveBayes.Multinomial)
 
   /** A Java-friendly constructor that takes three Iterable parameters. */
   private[mllib] def this(
@@ -61,12 +62,15 @@ class NaiveBayesModel private[mllib] (
       theta: JIterable[JIterable[Double]]) =
     this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
 
+  require(supportedModelTypes.contains(modelType),
+    s"Invalid modelType $modelType. Supported modelTypes are $supportedModelTypes.")
+
   // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
   // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
   // application of this condition (in predict function).
   private val (thetaMinusNegTheta, negThetaSum) = modelType match {
-    case "Multinomial" => (None, None)
-    case "Bernoulli" =>
+    case Multinomial => (None, None)
+    case Bernoulli =>
       val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
       val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0})
       val thetaMinusNegTheta = thetaMatrix.map { value =>
@@ -75,7 +79,7 @@ class NaiveBayesModel private[mllib] (
       (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
     case _ =>
       // This should never happen.
-      throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
+      throw new UnknownError(s"Invalid modelType: $modelType.")
   }
 
   override def predict(testData: RDD[Vector]): RDD[Double] = {
@@ -88,15 +92,15 @@ class NaiveBayesModel private[mllib] (
   override def predict(testData: Vector): Double = {
     modelType match {
-      case "Multinomial" =>
+      case Multinomial =>
         val prob = thetaMatrix.multiply(testData)
         BLAS.axpy(1.0, piVector, prob)
         labels(prob.argmax)
-      case "Bernoulli" =>
+      case Bernoulli =>
         testData.foreachActive { (index, value) =>
           if (value != 0.0 && value != 1.0) {
             throw new SparkException(
-              s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
+              s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
           }
         }
         val prob = thetaMinusNegTheta.get.multiply(testData)
@@ -105,7 +109,7 @@ class NaiveBayesModel private[mllib] (
         labels(prob.argmax)
       case _ =>
         // This should never happen.
-        throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
+        throw new UnknownError(s"Invalid modelType: $modelType.")
     }
   }
@@ -230,16 +234,16 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
        s"($loadedClassName, $version). Supported:\n" +
        s"  ($classNameV1_0, 1.0)")
     }
-    assert(model.pi.size == numClasses,
+    assert(model.pi.length == numClasses,
       s"NaiveBayesModel.load expected $numClasses classes," +
-        s" but class priors vector pi had ${model.pi.size} elements")
+        s" but class priors vector pi had ${model.pi.length} elements")
-    assert(model.theta.size == numClasses,
+    assert(model.theta.length == numClasses,
       s"NaiveBayesModel.load expected $numClasses classes," +
-        s" but class conditionals array theta had ${model.theta.size} elements")
+        s" but class conditionals array theta had ${model.theta.length} elements")
-    assert(model.theta.forall(_.size == numFeatures),
+    assert(model.theta.forall(_.length == numFeatures),
       s"NaiveBayesModel.load expected $numFeatures features," +
        s" but class conditionals array theta had elements of size:" +
-        s" ${model.theta.map(_.size).mkString(",")}")
+        s" ${model.theta.map(_.length).mkString(",")}")
     model
   }
 }
@@ -257,9 +261,11 @@ class NaiveBayes private (
     private var lambda: Double,
     private var modelType: String) extends Serializable with Logging {
 
-  def this(lambda: Double) = this(lambda, "Multinomial")
-  def this() = this(1.0, "Multinomial")
+  import NaiveBayes.{Bernoulli, Multinomial}
+
+  def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
+
+  def this() = this(1.0, NaiveBayes.Multinomial)
 
   /** Set the smoothing parameter. Default: 1.0. */
   def setLambda(lambda: Double): NaiveBayes = {
@@ -272,12 +278,11 @@ class NaiveBayes private (
   /**
    * Set the model type using a string (case-sensitive).
-   * Supported options: "Multinomial" and "Bernoulli".
-   * (default: Multinomial)
+   * Supported options: "multinomial" (default) and "bernoulli".
    */
   def setModelType(modelType: String): NaiveBayes = {
     require(NaiveBayes.supportedModelTypes.contains(modelType),
-      s"NaiveBayes was created with an unknown ModelType: $modelType")
+      s"NaiveBayes was created with an unknown modelType: $modelType.")
     this.modelType = modelType
     this
   }
@@ -308,7 +313,7 @@ class NaiveBayes private (
      }
      if (!values.forall(v => v == 0.0 || v == 1.0)) {
        throw new SparkException(
-          s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.")
+          s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
      }
    }
@@ -317,7 +322,7 @@ class NaiveBayes private (
     // TODO: similar to reduceByKeyLocally to save one stage.
     val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
       createCombiner = (v: Vector) => {
-        if (modelType == "Bernoulli") {
+        if (modelType == Bernoulli) {
           requireZeroOneBernoulliValues(v)
         } else {
           requireNonnegativeValues(v)
@@ -352,11 +357,11 @@ class NaiveBayes private (
       labels(i) = label
       pi(i) = math.log(n + lambda) - piLogDenom
       val thetaLogDenom = modelType match {
-        case "Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
-        case "Bernoulli" => math.log(n + 2.0 * lambda)
+        case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
+        case Bernoulli => math.log(n + 2.0 * lambda)
         case _ =>
           // This should never happen.
-          throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
+          throw new UnknownError(s"Invalid modelType: $modelType.")
       }
       var j = 0
       while (j < numFeatures) {
@@ -375,8 +380,14 @@ class NaiveBayes private (
  */
 object NaiveBayes {
 
+  /** String name for multinomial model type. */
+  private[classification] val Multinomial: String = "multinomial"
+
+  /** String name for Bernoulli model type. */
+  private[classification] val Bernoulli: String = "bernoulli"
+
   /* Set of modelTypes that NaiveBayes supports */
-  private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli")
+  private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
 
   /**
    * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
@@ -406,7 +417,7 @@ object NaiveBayes {
   * @param lambda The smoothing parameter
   */
  def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
-    new NaiveBayes(lambda, "Multinomial").run(input)
+    new NaiveBayes(lambda, Multinomial).run(input)
  }
 
  /**
@@ -429,7 +440,7 @@ object NaiveBayes {
   */
  def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
    require(supportedModelTypes.contains(modelType),
-      s"NaiveBayes was created with an unknown ModelType: $modelType")
+      s"NaiveBayes was created with an unknown modelType: $modelType.")
    new NaiveBayes(lambda, modelType).run(input)
  }
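As a reading aid for the thetaLogDenom expressions in the aggregation hunk above (not part of the change itself): with smoothing parameter $\lambda$, $D$ features, $N_{cj}$ the per-class sum of feature $j$, and $N_c$ the number of training examples with label $c$, and assuming the standard additive-smoothing numerator $N_{cj} + \lambda$ (the numerator line is outside this diff), the class-conditional estimates whose logs end up in theta are

$$\hat{\theta}_{cj} = \frac{N_{cj} + \lambda}{\sum_{j'=1}^{D} N_{cj'} + D\lambda} \quad \text{(multinomial)}, \qquad \hat{\theta}_{cj} = \frac{N_{cj} + \lambda}{N_c + 2\lambda} \quad \text{(Bernoulli)},$$

matching the case Multinomial and case Bernoulli denominators shown above.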

Changed file: mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java

@@ -108,7 +108,7 @@ public class JavaNaiveBayesSuite implements Serializable {
   @Test
   public void testModelTypeSetters() {
     NaiveBayes nb = new NaiveBayes()
-      .setModelType("Bernoulli")
-      .setModelType("Multinomial");
+      .setModelType("bernoulli")
+      .setModelType("multinomial");
   }
 }
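One user-visible consequence of the validation exercised by this test: the old capitalized names are rejected by the require in setModelType (shown earlier in NaiveBayes.scala). A small sketch, not part of the patch:

import org.apache.spark.mllib.classification.NaiveBayes

val nb = new NaiveBayes().setLambda(2.0).setModelType("bernoulli") // accepted
// new NaiveBayes().setModelType("Bernoulli") now fails with
// IllegalArgumentException("requirement failed: NaiveBayes was created with an unknown modelType: Bernoulli.")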

Changed file: mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

@@ -19,9 +19,8 @@ package org.apache.spark.mllib.classification
 import scala.util.Random
 
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
 import breeze.stats.distributions.{Multinomial => BrzMultinomial}
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkException
@@ -30,9 +29,10 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.util.Utils
 
 object NaiveBayesSuite {
 
+  import NaiveBayes.{Multinomial, Bernoulli}
+
   private def calcLabel(p: Double, pi: Array[Double]): Int = {
     var sum = 0.0
     for (j <- 0 until pi.length) {
@@ -48,7 +48,7 @@ object NaiveBayesSuite {
     theta: Array[Array[Double]], // CXD
     nPoints: Int,
     seed: Int,
-    modelType: String = "Multinomial",
+    modelType: String = Multinomial,
     sample: Int = 10): Seq[LabeledPoint] = {
     val D = theta(0).length
     val rnd = new Random(seed)
@@ -58,10 +58,10 @@ object NaiveBayesSuite {
     for (i <- 0 until nPoints) yield {
       val y = calcLabel(rnd.nextDouble(), _pi)
       val xi = modelType match {
-        case "Bernoulli" => Array.tabulate[Double] (D) { j =>
+        case Bernoulli => Array.tabulate[Double] (D) { j =>
             if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
         }
-        case "Multinomial" =>
+        case Multinomial =>
           val mult = BrzMultinomial(BDV(_theta(y)))
           val emptyMap = (0 until D).map(x => (x, 0.0)).toMap
           val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
@@ -70,7 +70,7 @@ object NaiveBayesSuite {
           counts.toArray.sortBy(_._1).map(_._2)
         case _ =>
           // This should never happen.
-          throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType")
+          throw new UnknownError(s"Invalid modelType: $modelType.")
       }
 
       LabeledPoint(y, Vectors.dense(xi))
@@ -79,17 +79,17 @@ object NaiveBayesSuite {
   /** Bernoulli NaiveBayes with binary labels, 3 features */
   private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
-    pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
-    "Bernoulli")
+    pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Bernoulli)
 
   /** Multinomial NaiveBayes with binary labels, 3 features */
   private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
-    pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
-    "Multinomial")
+    pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial)
 }
 
 class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
 
+  import NaiveBayes.{Multinomial, Bernoulli}
+
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOfPredictions = predictions.zip(input).count {
       case (prediction, expected) =>
@@ -117,6 +117,11 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("model types") {
+    assert(Multinomial === "multinomial")
+    assert(Bernoulli === "bernoulli")
+  }
+
   test("get, set params") {
     val nb = new NaiveBayes()
     nb.setLambda(2.0)
@@ -134,16 +139,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
       Array(0.10, 0.10, 0.70, 0.10) // label 2
     ).map(_.map(math.log))
 
-    val testData = NaiveBayesSuite.generateNaiveBayesInput(
-      pi, theta, nPoints, 42, "Multinomial")
+    val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42, Multinomial)
     val testRDD = sc.parallelize(testData, 2)
     testRDD.cache()
 
-    val model = NaiveBayes.train(testRDD, 1.0, "Multinomial")
+    val model = NaiveBayes.train(testRDD, 1.0, Multinomial)
     validateModelFit(pi, theta, model)
 
     val validationData = NaiveBayesSuite.generateNaiveBayesInput(
-      pi, theta, nPoints, 17, "Multinomial")
+      pi, theta, nPoints, 17, Multinomial)
     val validationRDD = sc.parallelize(validationData, 2)
 
     // Test prediction on RDD.
@@ -163,15 +167,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
     ).map(_.map(math.log))
 
     val testData = NaiveBayesSuite.generateNaiveBayesInput(
-      pi, theta, nPoints, 45, "Bernoulli")
+      pi, theta, nPoints, 45, Bernoulli)
     val testRDD = sc.parallelize(testData, 2)
     testRDD.cache()
 
-    val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli")
+    val model = NaiveBayes.train(testRDD, 1.0, Bernoulli)
     validateModelFit(pi, theta, model)
 
     val validationData = NaiveBayesSuite.generateNaiveBayesInput(
-      pi, theta, nPoints, 20, "Bernoulli")
+      pi, theta, nPoints, 20, Bernoulli)
     val validationRDD = sc.parallelize(validationData, 2)
 
     // Test prediction on RDD.
@@ -216,7 +220,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
       LabeledPoint(1.0, Vectors.dense(0.0)))
 
     intercept[SparkException] {
-      NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, "Bernoulli")
+      NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, Bernoulli)
     }
 
     val okTrain = Seq(
@@ -235,7 +239,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
       Vectors.dense(1.0),
       Vectors.dense(0.0))
 
-    val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli")
+    val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, Bernoulli)
     intercept[SparkException] {
       model.predict(sc.makeRDD(badPredict, 2)).collect()
     }
@@ -275,7 +279,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
       assert(model.labels === sameModel.labels)
       assert(model.pi === sameModel.pi)
       assert(model.theta === sameModel.theta)
-      assert(model.modelType === "Multinomial")
+      assert(model.modelType === Multinomial)
     } finally {
       Utils.deleteRecursively(tempDir)
     }