[SPARK-13089][ML] [Doc] spark.ml Naive Bayes user guide and examples
jira: https://issues.apache.org/jira/browse/SPARK-13089 Add section in ml-classification.md for NaiveBayes DataFrame-based API, plus example code (using include_example to clip code from examples/ folder files). Author: Yuhao Yang <hhbyyh@gmail.com> Closes #11015 from hhbyyh/naiveBayesDoc.
This commit is contained in:
parent
fcdd69260e
commit
781df49983
|
@ -302,6 +302,40 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRe
|
|||
</div>
|
||||
</div>
|
||||
|
||||
## Naive Bayes
|
||||
|
||||
[Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple
|
||||
probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence
|
||||
assumptions between the features. The spark.ml implementation currently supports both [multinomial
|
||||
naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html)
|
||||
and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html).
|
||||
More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib).
|
||||
|
||||
**Example**
|
||||
|
||||
<div class="codetabs">
|
||||
<div data-lang="scala" markdown="1">
|
||||
|
||||
Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.NaiveBayes) for more details.
|
||||
|
||||
{% include_example scala/org/apache/spark/examples/ml/NaiveBayesExample.scala %}
|
||||
</div>
|
||||
|
||||
<div data-lang="java" markdown="1">
|
||||
|
||||
Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/NaiveBayes.html) for more details.
|
||||
|
||||
{% include_example java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java %}
|
||||
</div>
|
||||
|
||||
<div data-lang="python" markdown="1">
|
||||
|
||||
Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.NaiveBayes) for more details.
|
||||
|
||||
{% include_example python/ml/naive_bayes_example.py %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
# Regression
|
||||
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.examples.ml;
|
||||
|
||||
|
||||
import org.apache.spark.SparkConf;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
// $example on$
|
||||
import org.apache.spark.ml.classification.NaiveBayes;
|
||||
import org.apache.spark.ml.classification.NaiveBayesModel;
|
||||
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
|
||||
import org.apache.spark.sql.Dataset;
|
||||
import org.apache.spark.sql.Row;
|
||||
import org.apache.spark.sql.SQLContext;
|
||||
// $example off$
|
||||
|
||||
/**
|
||||
* An example for Naive Bayes Classification.
|
||||
*/
|
||||
public class JavaNaiveBayesExample {
|
||||
|
||||
public static void main(String[] args) {
|
||||
SparkConf conf = new SparkConf().setAppName("JavaNaiveBayesExample");
|
||||
JavaSparkContext jsc = new JavaSparkContext(conf);
|
||||
SQLContext jsql = new SQLContext(jsc);
|
||||
|
||||
// $example on$
|
||||
// Load training data
|
||||
Dataset<Row> dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
|
||||
// Split the data into train and test
|
||||
Dataset<Row>[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
|
||||
Dataset<Row> train = splits[0];
|
||||
Dataset<Row> test = splits[1];
|
||||
|
||||
// create the trainer and set its parameters
|
||||
NaiveBayes nb = new NaiveBayes();
|
||||
// train the model
|
||||
NaiveBayesModel model = nb.fit(train);
|
||||
// compute precision on the test set
|
||||
Dataset<Row> result = model.transform(test);
|
||||
Dataset<Row> predictionAndLabels = result.select("prediction", "label");
|
||||
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
|
||||
.setMetricName("precision");
|
||||
System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels));
|
||||
// $example off$
|
||||
|
||||
jsc.stop();
|
||||
}
|
||||
}
|
53
examples/src/main/python/ml/naive_bayes_example.py
Normal file
53
examples/src/main/python/ml/naive_bayes_example.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
from pyspark import SparkContext
|
||||
from pyspark.sql import SQLContext
|
||||
# $example on$
|
||||
from pyspark.ml.classification import NaiveBayes
|
||||
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
|
||||
# $example off$
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
sc = SparkContext(appName="naive_bayes_example")
|
||||
sqlContext = SQLContext(sc)
|
||||
|
||||
# $example on$
|
||||
# Load training data
|
||||
data = sqlContext.read.format("libsvm") \
|
||||
.load("data/mllib/sample_libsvm_data.txt")
|
||||
# Split the data into train and test
|
||||
splits = data.randomSplit([0.6, 0.4], 1234)
|
||||
train = splits[0]
|
||||
test = splits[1]
|
||||
|
||||
# create the trainer and set its parameters
|
||||
nb = NaiveBayes(smoothing=1.0, modelType="multinomial")
|
||||
|
||||
# train the model
|
||||
model = nb.fit(train)
|
||||
# compute precision on the test set
|
||||
result = model.transform(test)
|
||||
predictionAndLabels = result.select("prediction", "label")
|
||||
evaluator = MulticlassClassificationEvaluator(metricName="precision")
|
||||
print("Precision:" + str(evaluator.evaluate(predictionAndLabels)))
|
||||
# $example off$
|
||||
|
||||
sc.stop()
|
|
@ -0,0 +1,58 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// scalastyle:off println
|
||||
package org.apache.spark.examples.ml
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
// $example on$
|
||||
import org.apache.spark.ml.classification.{NaiveBayes}
|
||||
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
|
||||
// $example off$
|
||||
import org.apache.spark.sql.SQLContext
|
||||
|
||||
object NaiveBayesExample {
|
||||
def main(args: Array[String]): Unit = {
|
||||
val conf = new SparkConf().setAppName("NaiveBayesExample")
|
||||
val sc = new SparkContext(conf)
|
||||
val sqlContext = new SQLContext(sc)
|
||||
// $example on$
|
||||
// Load the data stored in LIBSVM format as a DataFrame.
|
||||
val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
|
||||
|
||||
// Split the data into training and test sets (30% held out for testing)
|
||||
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
|
||||
|
||||
// Train a NaiveBayes model.
|
||||
val model = new NaiveBayes()
|
||||
.fit(trainingData)
|
||||
|
||||
// Select example rows to display.
|
||||
val predictions = model.transform(testData)
|
||||
predictions.show()
|
||||
|
||||
// Select (prediction, true label) and compute test error
|
||||
val evaluator = new MulticlassClassificationEvaluator()
|
||||
.setLabelCol("label")
|
||||
.setPredictionCol("prediction")
|
||||
.setMetricName("precision")
|
||||
val precision = evaluator.evaluate(predictions)
|
||||
println("Precision:" + precision)
|
||||
// $example off$
|
||||
}
|
||||
}
|
||||
// scalastyle:on println
|
Loading…
Reference in a new issue