[SPARK-7752] [MLLIB] Use lowercase letters for NaiveBayes.modelType
to be consistent with other string names in MLlib. This PR also updates the implementation to use vals instead of hardcoded strings. jkbradley leahmcguire Author: Xiangrui Meng <meng@databricks.com> Closes #6277 from mengxr/SPARK-7752 and squashes the following commits: f38b662 [Xiangrui Meng] add another case _ back in test ae5c66a [Xiangrui Meng] model type -> modelType 711d1c6 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7752 40ae53e [Xiangrui Meng] fix Java test suite 264a814 [Xiangrui Meng] add case _ back 3c456a8 [Xiangrui Meng] update NB user guide 17bba53 [Xiangrui Meng] update naive Bayes to use lowercase model type strings
This commit is contained in:
parent
a25c1ab8f0
commit
13348e21b6
|
@ -21,7 +21,7 @@ Within that context, each observation is a document and each
|
||||||
feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or
|
feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or
|
||||||
a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes).
|
a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes).
|
||||||
Feature values must be nonnegative. The model type is selected with an optional parameter
|
Feature values must be nonnegative. The model type is selected with an optional parameter
|
||||||
"Multinomial" or "Bernoulli" with "Multinomial" as the default.
|
"multinomial" or "bernoulli" with "multinomial" as the default.
|
||||||
[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by
|
[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by
|
||||||
setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature
|
setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature
|
||||||
vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of
|
vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of
|
||||||
|
@ -35,7 +35,7 @@ sparsity. Since the training data is only used once, it is not necessary to cach
|
||||||
[NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements
|
[NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements
|
||||||
multinomial naive Bayes. It takes an RDD of
|
multinomial naive Bayes. It takes an RDD of
|
||||||
[LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional
|
[LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional
|
||||||
smoothing parameter `lambda` as input, an optional model type parameter (default is Multinomial), and outputs a
|
smoothing parameter `lambda` as input, an optional model type parameter (default is "multinomial"), and outputs a
|
||||||
[NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which
|
[NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which
|
||||||
can be used for evaluation and prediction.
|
can be used for evaluation and prediction.
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
|
||||||
val training = splits(0)
|
val training = splits(0)
|
||||||
val test = splits(1)
|
val test = splits(1)
|
||||||
|
|
||||||
val model = NaiveBayes.train(training, lambda = 1.0, model = "Multinomial")
|
val model = NaiveBayes.train(training, lambda = 1.0, model = "multinomial")
|
||||||
|
|
||||||
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
|
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
|
||||||
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
|
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
|
||||||
|
@ -75,6 +75,8 @@ optionally smoothing parameter `lambda` as input, and output a
|
||||||
can be used for evaluation and prediction.
|
can be used for evaluation and prediction.
|
||||||
|
|
||||||
{% highlight java %}
|
{% highlight java %}
|
||||||
|
import scala.Tuple2;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaPairRDD;
|
import org.apache.spark.api.java.JavaPairRDD;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.function.Function;
|
import org.apache.spark.api.java.function.Function;
|
||||||
|
@ -82,7 +84,6 @@ import org.apache.spark.api.java.function.PairFunction;
|
||||||
import org.apache.spark.mllib.classification.NaiveBayes;
|
import org.apache.spark.mllib.classification.NaiveBayes;
|
||||||
import org.apache.spark.mllib.classification.NaiveBayesModel;
|
import org.apache.spark.mllib.classification.NaiveBayesModel;
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
import scala.Tuple2;
|
|
||||||
|
|
||||||
JavaRDD<LabeledPoint> training = ... // training set
|
JavaRDD<LabeledPoint> training = ... // training set
|
||||||
JavaRDD<LabeledPoint> test = ... // test set
|
JavaRDD<LabeledPoint> test = ... // test set
|
||||||
|
|
|
@ -25,13 +25,12 @@ import org.json4s.JsonDSL._
|
||||||
import org.json4s.jackson.JsonMethods._
|
import org.json4s.jackson.JsonMethods._
|
||||||
|
|
||||||
import org.apache.spark.{Logging, SparkContext, SparkException}
|
import org.apache.spark.{Logging, SparkContext, SparkException}
|
||||||
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector}
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.util.{Loader, Saveable}
|
import org.apache.spark.mllib.util.{Loader, Saveable}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{DataFrame, SQLContext}
|
import org.apache.spark.sql.{DataFrame, SQLContext}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Model for Naive Bayes Classifiers.
|
* Model for Naive Bayes Classifiers.
|
||||||
*
|
*
|
||||||
|
@ -39,7 +38,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
|
||||||
* @param pi log of class priors, whose dimension is C, number of labels
|
* @param pi log of class priors, whose dimension is C, number of labels
|
||||||
* @param theta log of class conditional probabilities, whose dimension is C-by-D,
|
* @param theta log of class conditional probabilities, whose dimension is C-by-D,
|
||||||
* where D is number of features
|
* where D is number of features
|
||||||
* @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli"
|
* @param modelType The type of NB model to fit can be "multinomial" or "bernoulli"
|
||||||
*/
|
*/
|
||||||
class NaiveBayesModel private[mllib] (
|
class NaiveBayesModel private[mllib] (
|
||||||
val labels: Array[Double],
|
val labels: Array[Double],
|
||||||
|
@ -48,11 +47,13 @@ class NaiveBayesModel private[mllib] (
|
||||||
val modelType: String)
|
val modelType: String)
|
||||||
extends ClassificationModel with Serializable with Saveable {
|
extends ClassificationModel with Serializable with Saveable {
|
||||||
|
|
||||||
|
import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes}
|
||||||
|
|
||||||
private val piVector = new DenseVector(pi)
|
private val piVector = new DenseVector(pi)
|
||||||
private val thetaMatrix = new DenseMatrix(labels.size, theta(0).size, theta.flatten, true)
|
private val thetaMatrix = new DenseMatrix(labels.length, theta(0).length, theta.flatten, true)
|
||||||
|
|
||||||
private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
|
private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
|
||||||
this(labels, pi, theta, "Multinomial")
|
this(labels, pi, theta, NaiveBayes.Multinomial)
|
||||||
|
|
||||||
/** A Java-friendly constructor that takes three Iterable parameters. */
|
/** A Java-friendly constructor that takes three Iterable parameters. */
|
||||||
private[mllib] def this(
|
private[mllib] def this(
|
||||||
|
@ -61,12 +62,15 @@ class NaiveBayesModel private[mllib] (
|
||||||
theta: JIterable[JIterable[Double]]) =
|
theta: JIterable[JIterable[Double]]) =
|
||||||
this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
|
this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
|
||||||
|
|
||||||
|
require(supportedModelTypes.contains(modelType),
|
||||||
|
s"Invalid modelType $modelType. Supported modelTypes are $supportedModelTypes.")
|
||||||
|
|
||||||
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
|
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
|
||||||
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
|
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
|
||||||
// application of this condition (in predict function).
|
// application of this condition (in predict function).
|
||||||
private val (thetaMinusNegTheta, negThetaSum) = modelType match {
|
private val (thetaMinusNegTheta, negThetaSum) = modelType match {
|
||||||
case "Multinomial" => (None, None)
|
case Multinomial => (None, None)
|
||||||
case "Bernoulli" =>
|
case Bernoulli =>
|
||||||
val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
|
val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
|
||||||
val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0})
|
val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0})
|
||||||
val thetaMinusNegTheta = thetaMatrix.map { value =>
|
val thetaMinusNegTheta = thetaMatrix.map { value =>
|
||||||
|
@ -75,7 +79,7 @@ class NaiveBayesModel private[mllib] (
|
||||||
(Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
|
(Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
|
||||||
case _ =>
|
case _ =>
|
||||||
// This should never happen.
|
// This should never happen.
|
||||||
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
|
throw new UnknownError(s"Invalid modelType: $modelType.")
|
||||||
}
|
}
|
||||||
|
|
||||||
override def predict(testData: RDD[Vector]): RDD[Double] = {
|
override def predict(testData: RDD[Vector]): RDD[Double] = {
|
||||||
|
@ -88,15 +92,15 @@ class NaiveBayesModel private[mllib] (
|
||||||
|
|
||||||
override def predict(testData: Vector): Double = {
|
override def predict(testData: Vector): Double = {
|
||||||
modelType match {
|
modelType match {
|
||||||
case "Multinomial" =>
|
case Multinomial =>
|
||||||
val prob = thetaMatrix.multiply(testData)
|
val prob = thetaMatrix.multiply(testData)
|
||||||
BLAS.axpy(1.0, piVector, prob)
|
BLAS.axpy(1.0, piVector, prob)
|
||||||
labels(prob.argmax)
|
labels(prob.argmax)
|
||||||
case "Bernoulli" =>
|
case Bernoulli =>
|
||||||
testData.foreachActive { (index, value) =>
|
testData.foreachActive { (index, value) =>
|
||||||
if (value != 0.0 && value != 1.0) {
|
if (value != 0.0 && value != 1.0) {
|
||||||
throw new SparkException(
|
throw new SparkException(
|
||||||
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
|
s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
val prob = thetaMinusNegTheta.get.multiply(testData)
|
val prob = thetaMinusNegTheta.get.multiply(testData)
|
||||||
|
@ -105,7 +109,7 @@ class NaiveBayesModel private[mllib] (
|
||||||
labels(prob.argmax)
|
labels(prob.argmax)
|
||||||
case _ =>
|
case _ =>
|
||||||
// This should never happen.
|
// This should never happen.
|
||||||
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
|
throw new UnknownError(s"Invalid modelType: $modelType.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -230,16 +234,16 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
|
||||||
s"($loadedClassName, $version). Supported:\n" +
|
s"($loadedClassName, $version). Supported:\n" +
|
||||||
s" ($classNameV1_0, 1.0)")
|
s" ($classNameV1_0, 1.0)")
|
||||||
}
|
}
|
||||||
assert(model.pi.size == numClasses,
|
assert(model.pi.length == numClasses,
|
||||||
s"NaiveBayesModel.load expected $numClasses classes," +
|
s"NaiveBayesModel.load expected $numClasses classes," +
|
||||||
s" but class priors vector pi had ${model.pi.size} elements")
|
s" but class priors vector pi had ${model.pi.length} elements")
|
||||||
assert(model.theta.size == numClasses,
|
assert(model.theta.length == numClasses,
|
||||||
s"NaiveBayesModel.load expected $numClasses classes," +
|
s"NaiveBayesModel.load expected $numClasses classes," +
|
||||||
s" but class conditionals array theta had ${model.theta.size} elements")
|
s" but class conditionals array theta had ${model.theta.length} elements")
|
||||||
assert(model.theta.forall(_.size == numFeatures),
|
assert(model.theta.forall(_.length == numFeatures),
|
||||||
s"NaiveBayesModel.load expected $numFeatures features," +
|
s"NaiveBayesModel.load expected $numFeatures features," +
|
||||||
s" but class conditionals array theta had elements of size:" +
|
s" but class conditionals array theta had elements of size:" +
|
||||||
s" ${model.theta.map(_.size).mkString(",")}")
|
s" ${model.theta.map(_.length).mkString(",")}")
|
||||||
model
|
model
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -257,9 +261,11 @@ class NaiveBayes private (
|
||||||
private var lambda: Double,
|
private var lambda: Double,
|
||||||
private var modelType: String) extends Serializable with Logging {
|
private var modelType: String) extends Serializable with Logging {
|
||||||
|
|
||||||
def this(lambda: Double) = this(lambda, "Multinomial")
|
import NaiveBayes.{Bernoulli, Multinomial}
|
||||||
|
|
||||||
def this() = this(1.0, "Multinomial")
|
def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
|
||||||
|
|
||||||
|
def this() = this(1.0, NaiveBayes.Multinomial)
|
||||||
|
|
||||||
/** Set the smoothing parameter. Default: 1.0. */
|
/** Set the smoothing parameter. Default: 1.0. */
|
||||||
def setLambda(lambda: Double): NaiveBayes = {
|
def setLambda(lambda: Double): NaiveBayes = {
|
||||||
|
@ -272,12 +278,11 @@ class NaiveBayes private (
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set the model type using a string (case-sensitive).
|
* Set the model type using a string (case-sensitive).
|
||||||
* Supported options: "Multinomial" and "Bernoulli".
|
* Supported options: "multinomial" (default) and "bernoulli".
|
||||||
* (default: Multinomial)
|
|
||||||
*/
|
*/
|
||||||
def setModelType(modelType: String): NaiveBayes = {
|
def setModelType(modelType: String): NaiveBayes = {
|
||||||
require(NaiveBayes.supportedModelTypes.contains(modelType),
|
require(NaiveBayes.supportedModelTypes.contains(modelType),
|
||||||
s"NaiveBayes was created with an unknown ModelType: $modelType")
|
s"NaiveBayes was created with an unknown modelType: $modelType.")
|
||||||
this.modelType = modelType
|
this.modelType = modelType
|
||||||
this
|
this
|
||||||
}
|
}
|
||||||
|
@ -308,7 +313,7 @@ class NaiveBayes private (
|
||||||
}
|
}
|
||||||
if (!values.forall(v => v == 0.0 || v == 1.0)) {
|
if (!values.forall(v => v == 0.0 || v == 1.0)) {
|
||||||
throw new SparkException(
|
throw new SparkException(
|
||||||
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.")
|
s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -317,7 +322,7 @@ class NaiveBayes private (
|
||||||
// TODO: similar to reduceByKeyLocally to save one stage.
|
// TODO: similar to reduceByKeyLocally to save one stage.
|
||||||
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
|
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
|
||||||
createCombiner = (v: Vector) => {
|
createCombiner = (v: Vector) => {
|
||||||
if (modelType == "Bernoulli") {
|
if (modelType == Bernoulli) {
|
||||||
requireZeroOneBernoulliValues(v)
|
requireZeroOneBernoulliValues(v)
|
||||||
} else {
|
} else {
|
||||||
requireNonnegativeValues(v)
|
requireNonnegativeValues(v)
|
||||||
|
@ -352,11 +357,11 @@ class NaiveBayes private (
|
||||||
labels(i) = label
|
labels(i) = label
|
||||||
pi(i) = math.log(n + lambda) - piLogDenom
|
pi(i) = math.log(n + lambda) - piLogDenom
|
||||||
val thetaLogDenom = modelType match {
|
val thetaLogDenom = modelType match {
|
||||||
case "Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
|
case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
|
||||||
case "Bernoulli" => math.log(n + 2.0 * lambda)
|
case Bernoulli => math.log(n + 2.0 * lambda)
|
||||||
case _ =>
|
case _ =>
|
||||||
// This should never happen.
|
// This should never happen.
|
||||||
throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
|
throw new UnknownError(s"Invalid modelType: $modelType.")
|
||||||
}
|
}
|
||||||
var j = 0
|
var j = 0
|
||||||
while (j < numFeatures) {
|
while (j < numFeatures) {
|
||||||
|
@ -375,8 +380,14 @@ class NaiveBayes private (
|
||||||
*/
|
*/
|
||||||
object NaiveBayes {
|
object NaiveBayes {
|
||||||
|
|
||||||
|
/** String name for multinomial model type. */
|
||||||
|
private[classification] val Multinomial: String = "multinomial"
|
||||||
|
|
||||||
|
/** String name for Bernoulli model type. */
|
||||||
|
private[classification] val Bernoulli: String = "bernoulli"
|
||||||
|
|
||||||
/* Set of modelTypes that NaiveBayes supports */
|
/* Set of modelTypes that NaiveBayes supports */
|
||||||
private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli")
|
private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
|
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
|
||||||
|
@ -406,7 +417,7 @@ object NaiveBayes {
|
||||||
* @param lambda The smoothing parameter
|
* @param lambda The smoothing parameter
|
||||||
*/
|
*/
|
||||||
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
|
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
|
||||||
new NaiveBayes(lambda, "Multinomial").run(input)
|
new NaiveBayes(lambda, Multinomial).run(input)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -429,7 +440,7 @@ object NaiveBayes {
|
||||||
*/
|
*/
|
||||||
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
|
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
|
||||||
require(supportedModelTypes.contains(modelType),
|
require(supportedModelTypes.contains(modelType),
|
||||||
s"NaiveBayes was created with an unknown ModelType: $modelType")
|
s"NaiveBayes was created with an unknown modelType: $modelType.")
|
||||||
new NaiveBayes(lambda, modelType).run(input)
|
new NaiveBayes(lambda, modelType).run(input)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -108,7 +108,7 @@ public class JavaNaiveBayesSuite implements Serializable {
|
||||||
@Test
|
@Test
|
||||||
public void testModelTypeSetters() {
|
public void testModelTypeSetters() {
|
||||||
NaiveBayes nb = new NaiveBayes()
|
NaiveBayes nb = new NaiveBayes()
|
||||||
.setModelType("Bernoulli")
|
.setModelType("bernoulli")
|
||||||
.setModelType("Multinomial");
|
.setModelType("multinomial");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,9 +19,8 @@ package org.apache.spark.mllib.classification
|
||||||
|
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
|
|
||||||
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
|
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
|
||||||
import breeze.stats.distributions.{Multinomial => BrzMultinomial}
|
import breeze.stats.distributions.{Multinomial => BrzMultinomial}
|
||||||
|
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
import org.apache.spark.SparkException
|
import org.apache.spark.SparkException
|
||||||
|
@ -30,9 +29,10 @@ import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
|
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
|
||||||
import org.apache.spark.util.Utils
|
import org.apache.spark.util.Utils
|
||||||
|
|
||||||
|
|
||||||
object NaiveBayesSuite {
|
object NaiveBayesSuite {
|
||||||
|
|
||||||
|
import NaiveBayes.{Multinomial, Bernoulli}
|
||||||
|
|
||||||
private def calcLabel(p: Double, pi: Array[Double]): Int = {
|
private def calcLabel(p: Double, pi: Array[Double]): Int = {
|
||||||
var sum = 0.0
|
var sum = 0.0
|
||||||
for (j <- 0 until pi.length) {
|
for (j <- 0 until pi.length) {
|
||||||
|
@ -48,7 +48,7 @@ object NaiveBayesSuite {
|
||||||
theta: Array[Array[Double]], // CXD
|
theta: Array[Array[Double]], // CXD
|
||||||
nPoints: Int,
|
nPoints: Int,
|
||||||
seed: Int,
|
seed: Int,
|
||||||
modelType: String = "Multinomial",
|
modelType: String = Multinomial,
|
||||||
sample: Int = 10): Seq[LabeledPoint] = {
|
sample: Int = 10): Seq[LabeledPoint] = {
|
||||||
val D = theta(0).length
|
val D = theta(0).length
|
||||||
val rnd = new Random(seed)
|
val rnd = new Random(seed)
|
||||||
|
@ -58,10 +58,10 @@ object NaiveBayesSuite {
|
||||||
for (i <- 0 until nPoints) yield {
|
for (i <- 0 until nPoints) yield {
|
||||||
val y = calcLabel(rnd.nextDouble(), _pi)
|
val y = calcLabel(rnd.nextDouble(), _pi)
|
||||||
val xi = modelType match {
|
val xi = modelType match {
|
||||||
case "Bernoulli" => Array.tabulate[Double] (D) { j =>
|
case Bernoulli => Array.tabulate[Double] (D) { j =>
|
||||||
if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
|
if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
|
||||||
}
|
}
|
||||||
case "Multinomial" =>
|
case Multinomial =>
|
||||||
val mult = BrzMultinomial(BDV(_theta(y)))
|
val mult = BrzMultinomial(BDV(_theta(y)))
|
||||||
val emptyMap = (0 until D).map(x => (x, 0.0)).toMap
|
val emptyMap = (0 until D).map(x => (x, 0.0)).toMap
|
||||||
val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
|
val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
|
||||||
|
@ -70,7 +70,7 @@ object NaiveBayesSuite {
|
||||||
counts.toArray.sortBy(_._1).map(_._2)
|
counts.toArray.sortBy(_._1).map(_._2)
|
||||||
case _ =>
|
case _ =>
|
||||||
// This should never happen.
|
// This should never happen.
|
||||||
throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType")
|
throw new UnknownError(s"Invalid modelType: $modelType.")
|
||||||
}
|
}
|
||||||
|
|
||||||
LabeledPoint(y, Vectors.dense(xi))
|
LabeledPoint(y, Vectors.dense(xi))
|
||||||
|
@ -79,17 +79,17 @@ object NaiveBayesSuite {
|
||||||
|
|
||||||
/** Bernoulli NaiveBayes with binary labels, 3 features */
|
/** Bernoulli NaiveBayes with binary labels, 3 features */
|
||||||
private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
|
private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
|
||||||
pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
|
pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Bernoulli)
|
||||||
"Bernoulli")
|
|
||||||
|
|
||||||
/** Multinomial NaiveBayes with binary labels, 3 features */
|
/** Multinomial NaiveBayes with binary labels, 3 features */
|
||||||
private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
|
private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
|
||||||
pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
|
pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial)
|
||||||
"Multinomial")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
|
class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
|
import NaiveBayes.{Multinomial, Bernoulli}
|
||||||
|
|
||||||
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
|
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
|
||||||
val numOfPredictions = predictions.zip(input).count {
|
val numOfPredictions = predictions.zip(input).count {
|
||||||
case (prediction, expected) =>
|
case (prediction, expected) =>
|
||||||
|
@ -117,6 +117,11 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("model types") {
|
||||||
|
assert(Multinomial === "multinomial")
|
||||||
|
assert(Bernoulli === "bernoulli")
|
||||||
|
}
|
||||||
|
|
||||||
test("get, set params") {
|
test("get, set params") {
|
||||||
val nb = new NaiveBayes()
|
val nb = new NaiveBayes()
|
||||||
nb.setLambda(2.0)
|
nb.setLambda(2.0)
|
||||||
|
@ -134,16 +139,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
|
||||||
Array(0.10, 0.10, 0.70, 0.10) // label 2
|
Array(0.10, 0.10, 0.70, 0.10) // label 2
|
||||||
).map(_.map(math.log))
|
).map(_.map(math.log))
|
||||||
|
|
||||||
val testData = NaiveBayesSuite.generateNaiveBayesInput(
|
val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42, Multinomial)
|
||||||
pi, theta, nPoints, 42, "Multinomial")
|
|
||||||
val testRDD = sc.parallelize(testData, 2)
|
val testRDD = sc.parallelize(testData, 2)
|
||||||
testRDD.cache()
|
testRDD.cache()
|
||||||
|
|
||||||
val model = NaiveBayes.train(testRDD, 1.0, "Multinomial")
|
val model = NaiveBayes.train(testRDD, 1.0, Multinomial)
|
||||||
validateModelFit(pi, theta, model)
|
validateModelFit(pi, theta, model)
|
||||||
|
|
||||||
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
|
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
|
||||||
pi, theta, nPoints, 17, "Multinomial")
|
pi, theta, nPoints, 17, Multinomial)
|
||||||
val validationRDD = sc.parallelize(validationData, 2)
|
val validationRDD = sc.parallelize(validationData, 2)
|
||||||
|
|
||||||
// Test prediction on RDD.
|
// Test prediction on RDD.
|
||||||
|
@ -163,15 +167,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
|
||||||
).map(_.map(math.log))
|
).map(_.map(math.log))
|
||||||
|
|
||||||
val testData = NaiveBayesSuite.generateNaiveBayesInput(
|
val testData = NaiveBayesSuite.generateNaiveBayesInput(
|
||||||
pi, theta, nPoints, 45, "Bernoulli")
|
pi, theta, nPoints, 45, Bernoulli)
|
||||||
val testRDD = sc.parallelize(testData, 2)
|
val testRDD = sc.parallelize(testData, 2)
|
||||||
testRDD.cache()
|
testRDD.cache()
|
||||||
|
|
||||||
val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli")
|
val model = NaiveBayes.train(testRDD, 1.0, Bernoulli)
|
||||||
validateModelFit(pi, theta, model)
|
validateModelFit(pi, theta, model)
|
||||||
|
|
||||||
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
|
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
|
||||||
pi, theta, nPoints, 20, "Bernoulli")
|
pi, theta, nPoints, 20, Bernoulli)
|
||||||
val validationRDD = sc.parallelize(validationData, 2)
|
val validationRDD = sc.parallelize(validationData, 2)
|
||||||
|
|
||||||
// Test prediction on RDD.
|
// Test prediction on RDD.
|
||||||
|
@ -216,7 +220,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
|
||||||
LabeledPoint(1.0, Vectors.dense(0.0)))
|
LabeledPoint(1.0, Vectors.dense(0.0)))
|
||||||
|
|
||||||
intercept[SparkException] {
|
intercept[SparkException] {
|
||||||
NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, "Bernoulli")
|
NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, Bernoulli)
|
||||||
}
|
}
|
||||||
|
|
||||||
val okTrain = Seq(
|
val okTrain = Seq(
|
||||||
|
@ -235,7 +239,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
|
||||||
Vectors.dense(1.0),
|
Vectors.dense(1.0),
|
||||||
Vectors.dense(0.0))
|
Vectors.dense(0.0))
|
||||||
|
|
||||||
val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli")
|
val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, Bernoulli)
|
||||||
intercept[SparkException] {
|
intercept[SparkException] {
|
||||||
model.predict(sc.makeRDD(badPredict, 2)).collect()
|
model.predict(sc.makeRDD(badPredict, 2)).collect()
|
||||||
}
|
}
|
||||||
|
@ -275,7 +279,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
|
||||||
assert(model.labels === sameModel.labels)
|
assert(model.labels === sameModel.labels)
|
||||||
assert(model.pi === sameModel.pi)
|
assert(model.pi === sameModel.pi)
|
||||||
assert(model.theta === sameModel.theta)
|
assert(model.theta === sameModel.theta)
|
||||||
assert(model.modelType === "Multinomial")
|
assert(model.modelType === Multinomial)
|
||||||
} finally {
|
} finally {
|
||||||
Utils.deleteRecursively(tempDir)
|
Utils.deleteRecursively(tempDir)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue