---
layout: global
title: Naive Bayes - MLlib
displayTitle: <a href="mllib-guide.html">MLlib</a> - Naive Bayes
---

Naive Bayes is a simple multiclass classification algorithm that assumes independence between every
pair of features. Naive Bayes can be trained very efficiently. In a single pass over the training
data, it computes the conditional probability distribution of each feature given a label, and then
applies Bayes' theorem to compute the conditional probability distribution of a label given an
observation, which is used for prediction. For more details, please visit the Wikipedia page
[Naive Bayes classifier](http://en.wikipedia.org/wiki/Naive_Bayes_classifier).

In MLlib, we implemented multinomial naive Bayes, which is typically used for document
classification. Within that context, each observation is a document and each feature represents a
term whose value is the frequency of the term. For its formulation, please visit the Wikipedia page
[Multinomial Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes)
or the section
[Naive Bayes text classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html)
from the book *Introduction to Information Retrieval*.
[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the
parameter $\lambda$ (default: $1.0$). For document classification, the input feature vectors are
usually sparse. Please supply sparse vectors as input to take advantage of sparsity. Since the
training data is only used once, it is not necessary to cache it.
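
As a brief sketch of the formulation above (the notation here is ours, for illustration only): given
a document represented by a term-frequency vector $x = (x_1, \ldots, x_n)$, multinomial naive Bayes
scores each label $c$ by

`\[
\log p(c \mid x) \propto \log p(c) + \sum_{i=1}^{n} x_i \log p(t_i \mid c),
\]`

where the conditional term probabilities are estimated with additive smoothing,

`\[
p(t_i \mid c) = \frac{N_{ci} + \lambda}{N_c + \lambda n},
\]`

with $N_{ci}$ the total frequency of term $t_i$ in documents labeled $c$, $N_c = \sum_{i=1}^{n} N_{ci}$,
and $n$ the number of features. Setting $\lambda = 1$ (the default) gives Laplace smoothing.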

## Examples

<div class="codetabs">

<div data-lang="scala" markdown="1">

[NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements
multinomial naive Bayes. It takes an RDD of
[LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional
smoothing parameter `lambda` as input, and outputs a
[NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which
can be used for evaluation and prediction.

{% highlight scala %}
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val data = sc.textFile("mllib/data/sample_naive_bayes_data.txt")
val parsedData = data.map { line =>
  val parts = line.split(',')
  LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
}
// Split data into training (60%) and test (40%).
val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0)
val test = splits(1)

val model = NaiveBayes.train(training, lambda = 1.0)
val prediction = model.predict(test.map(_.features))

val predictionAndLabel = prediction.zip(test.map(_.label))
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
{% endhighlight %}

</div>

<div data-lang="java" markdown="1">

[NaiveBayes](api/java/org/apache/spark/mllib/classification/NaiveBayes.html) implements
multinomial naive Bayes. It takes a Scala RDD of
[LabeledPoint](api/java/org/apache/spark/mllib/regression/LabeledPoint.html) and an
optional smoothing parameter `lambda` as input, and outputs a
[NaiveBayesModel](api/java/org/apache/spark/mllib/classification/NaiveBayesModel.html), which
can be used for evaluation and prediction.

{% highlight java %}
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Tuple2;

JavaRDD<LabeledPoint> training = ... // training set
JavaRDD<LabeledPoint> test = ... // test set

final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);

JavaRDD<Double> prediction =
  test.map(new Function<LabeledPoint, Double>() {
    @Override public Double call(LabeledPoint p) {
      return model.predict(p.features());
    }
  });
JavaPairRDD<Double, Double> predictionAndLabel =
  prediction.zip(test.map(new Function<LabeledPoint, Double>() {
    @Override public Double call(LabeledPoint p) {
      return p.label();
    }
  }));
double accuracy = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
  @Override public Boolean call(Tuple2<Double, Double> pl) {
    // Use equals() rather than == to compare boxed Doubles by value.
    return pl._1().equals(pl._2());
  }
}).count() / test.count();
{% endhighlight %}

</div>

<div data-lang="python" markdown="1">

[NaiveBayes](api/python/pyspark.mllib.classification.NaiveBayes-class.html) implements multinomial
naive Bayes. It takes an RDD of
[LabeledPoint](api/python/pyspark.mllib.regression.LabeledPoint-class.html) and an optional
smoothing parameter `lambda` as input, and outputs a
[NaiveBayesModel](api/python/pyspark.mllib.classification.NaiveBayesModel-class.html), which can be
used for evaluation and prediction.

<!-- TODO: Make Python's example consistent with Scala's and Java's. -->
{% highlight python %}
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.classification import NaiveBayes

# an RDD of LabeledPoint
data = sc.parallelize([
  LabeledPoint(0.0, [0.0, 0.0]),
  ... # more labeled points
])

# Train a naive Bayes model.
model = NaiveBayes.train(data, 1.0)

# Make prediction.
prediction = model.predict([0.0, 0.0])
{% endhighlight %}

</div>

</div>