Added Java unit test, data, and main method for Naive Bayes
Also fixes the main methods of a few other algorithms to print the final model
This commit is contained in:
parent
4c28a2bad8
commit
f00e949f84
6
mllib/data/sample_naive_bayes_data.txt
Normal file
6
mllib/data/sample_naive_bayes_data.txt
Normal file
|
@ -0,0 +1,6 @@
|
|||
0, 1 0 0
|
||||
0, 2 0 0
|
||||
1, 0 1 0
|
||||
1, 0 2 0
|
||||
2, 0 0 1
|
||||
2, 0 0 2
|
|
@ -97,7 +97,7 @@ object LogisticRegressionWithSGD {
|
|||
* @param numIterations Number of iterations of gradient descent to run.
|
||||
* @param stepSize Step size to be used for each iteration of gradient descent.
|
||||
* @param miniBatchFraction Fraction of data to be used per iteration.
|
||||
* @param initialWeights Initial set of weights to be used. Array should be equal in size to
|
||||
* @param initialWeights Initial set of weights to be used. Array should be equal in size to
|
||||
* the number of features in the data.
|
||||
*/
|
||||
def train(
|
||||
|
@ -183,6 +183,8 @@ object LogisticRegressionWithSGD {
|
|||
val sc = new SparkContext(args(0), "LogisticRegression")
|
||||
val data = MLUtils.loadLabeledData(sc, args(1))
|
||||
val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
|
||||
println("Weights: " + model.weights.mkString("[", ", ", "]"))
|
||||
println("Intercept: " + model.intercept)
|
||||
|
||||
sc.stop()
|
||||
}
|
||||
|
|
|
@ -21,9 +21,10 @@ import scala.collection.mutable
|
|||
|
||||
import org.jblas.DoubleMatrix
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.{SparkContext, Logging}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
|
||||
/**
|
||||
* Model for Naive Bayes Classifiers.
|
||||
|
@ -144,4 +145,22 @@ object NaiveBayes {
|
|||
/**
 * Trains a Naive Bayes model on an RDD of labeled points.
 *
 * @param input RDD of (label, feature array) training examples.
 * @param lambda smoothing parameter passed to the NaiveBayes constructor
 *               (presumably additive/Laplace smoothing — confirm in NaiveBayes).
 * @return the fitted NaiveBayesModel.
 */
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
  val trainer = new NaiveBayes(lambda)
  trainer.run(input)
}
|
||||
|
||||
/**
 * Command-line entry point: NaiveBayes <master> <input_dir> [<lambda>]
 *
 * Loads labeled data from the given directory, trains a Naive Bayes model
 * (with the optional explicit lambda), and prints the model's pi and theta
 * parameters before shutting the context down.
 */
def main(args: Array[String]) {
  // Exactly two or three arguments are accepted; anything else is a usage error.
  if (args.length < 2 || args.length > 3) {
    println("Usage: NaiveBayes <master> <input_dir> [<lambda>]")
    System.exit(1)
  }
  val sc = new SparkContext(args(0), "NaiveBayes")
  val data = MLUtils.loadLabeledData(sc, args(1))
  // Prefer the caller-supplied lambda; otherwise fall back to the default overload.
  val model =
    if (args.length == 3) {
      NaiveBayes.train(data, args(2).toDouble)
    } else {
      NaiveBayes.train(data)
    }
  println("Pi: " + model.pi.mkString("[", ", ", "]"))
  println("Theta:\n" + model.theta.map(_.mkString("[", ", ", "]")).mkString("[", "\n ", "]"))

  sc.stop()
}
|
||||
}
|
||||
|
|
|
@ -183,6 +183,8 @@ object SVMWithSGD {
|
|||
val sc = new SparkContext(args(0), "SVM")
|
||||
val data = MLUtils.loadLabeledData(sc, args(1))
|
||||
val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
|
||||
println("Weights: " + model.weights.mkString("[", ", ", "]"))
|
||||
println("Intercept: " + model.intercept)
|
||||
|
||||
sc.stop()
|
||||
}
|
||||
|
|
|
@ -121,7 +121,7 @@ object LassoWithSGD {
|
|||
* @param stepSize Step size to be used for each iteration of gradient descent.
|
||||
* @param regParam Regularization parameter.
|
||||
* @param miniBatchFraction Fraction of data to be used per iteration.
|
||||
* @param initialWeights Initial set of weights to be used. Array should be equal in size to
|
||||
* @param initialWeights Initial set of weights to be used. Array should be equal in size to
|
||||
* the number of features in the data.
|
||||
*/
|
||||
def train(
|
||||
|
@ -205,6 +205,8 @@ object LassoWithSGD {
|
|||
val sc = new SparkContext(args(0), "Lasso")
|
||||
val data = MLUtils.loadLabeledData(sc, args(1))
|
||||
val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
|
||||
println("Weights: " + model.weights.mkString("[", ", ", "]"))
|
||||
println("Intercept: " + model.intercept)
|
||||
|
||||
sc.stop()
|
||||
}
|
||||
|
|
|
@ -162,6 +162,8 @@ object LinearRegressionWithSGD {
|
|||
val sc = new SparkContext(args(0), "LinearRegression")
|
||||
val data = MLUtils.loadLabeledData(sc, args(1))
|
||||
val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
|
||||
println("Weights: " + model.weights.mkString("[", ", ", "]"))
|
||||
println("Intercept: " + model.intercept)
|
||||
|
||||
sc.stop()
|
||||
}
|
||||
|
|
|
@ -122,7 +122,7 @@ object RidgeRegressionWithSGD {
|
|||
* @param stepSize Step size to be used for each iteration of gradient descent.
|
||||
* @param regParam Regularization parameter.
|
||||
* @param miniBatchFraction Fraction of data to be used per iteration.
|
||||
* @param initialWeights Initial set of weights to be used. Array should be equal in size to
|
||||
* @param initialWeights Initial set of weights to be used. Array should be equal in size to
|
||||
* the number of features in the data.
|
||||
*/
|
||||
def train(
|
||||
|
@ -208,6 +208,8 @@ object RidgeRegressionWithSGD {
|
|||
val data = MLUtils.loadLabeledData(sc, args(1))
|
||||
val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble,
|
||||
args(3).toDouble)
|
||||
println("Weights: " + model.weights.mkString("[", ", ", "]"))
|
||||
println("Intercept: " + model.intercept)
|
||||
|
||||
sc.stop()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
package org.apache.spark.mllib.classification;
|
||||
|
||||
import org.apache.spark.api.java.JavaRDD;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||
import org.junit.After;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class JavaNaiveBayesSuite implements Serializable {
|
||||
private transient JavaSparkContext sc;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
sc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
sc.stop();
|
||||
sc = null;
|
||||
System.clearProperty("spark.driver.port");
|
||||
}
|
||||
|
||||
private static final List<LabeledPoint> POINTS = Arrays.asList(
|
||||
new LabeledPoint(0, new double[] {1.0, 0.0, 0.0}),
|
||||
new LabeledPoint(0, new double[] {2.0, 0.0, 0.0}),
|
||||
new LabeledPoint(1, new double[] {0.0, 1.0, 0.0}),
|
||||
new LabeledPoint(1, new double[] {0.0, 2.0, 0.0}),
|
||||
new LabeledPoint(2, new double[] {0.0, 0.0, 1.0}),
|
||||
new LabeledPoint(2, new double[] {0.0, 0.0, 2.0})
|
||||
);
|
||||
|
||||
private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
|
||||
int correct = 0;
|
||||
for (LabeledPoint p: points) {
|
||||
if (model.predict(p.features()) == p.label()) {
|
||||
correct += 1;
|
||||
}
|
||||
}
|
||||
return correct;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void runUsingConstructor() {
|
||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
|
||||
|
||||
NaiveBayes nb = new NaiveBayes().setLambda(1.0);
|
||||
NaiveBayesModel model = nb.run(testRDD.rdd());
|
||||
|
||||
int numAccurate = validatePrediction(POINTS, model);
|
||||
Assert.assertEquals(POINTS.size(), numAccurate);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void runUsingStaticMethods() {
|
||||
JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
|
||||
|
||||
NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
|
||||
int numAccurate1 = validatePrediction(POINTS, model1);
|
||||
Assert.assertEquals(POINTS.size(), numAccurate1);
|
||||
|
||||
NaiveBayesModel model2 = NaiveBayes.train(testRDD.rdd(), 0.5);
|
||||
int numAccurate2 = validatePrediction(POINTS, model2);
|
||||
Assert.assertEquals(POINTS.size(), numAccurate2);
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue