[SPARK-21729][ML][TEST] Generic test for ProbabilisticClassifier to ensure consistent output columns

## What changes were proposed in this pull request?

Add a generic test that runs model prediction with all combinations of the output columns (rawPrediction / probability / prediction) turned on/off.
Verify that the output column values match by comparing against the case with all 3 output columns turned on.
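
For context, an output column in Spark ML is turned off by setting its column-name param to the empty string. A minimal sketch of that mechanism (not part of this patch; `model` is assumed to be a fitted `LogisticRegressionModel` and `df` a DataFrame with a `features` column):

```scala
import org.apache.spark.ml.param.ParamMap

// Assumes `model` is a fitted LogisticRegressionModel and `df` is a
// DataFrame with a "features" column. An empty output-column name tells
// transform() to skip producing that column.
val noRawModel = model.copy(ParamMap.empty)
  .setRawPredictionCol("")          // rawPrediction turned off
  .setProbabilityCol("probability") // probability kept
  .setPredictionCol("prediction")   // prediction kept
val out = noRawModel.transform(df)
assert(!out.columns.contains("rawPrediction"))
assert(out.columns.contains("probability") && out.columns.contains("prediction"))
```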

## How was this patch tested?

Existing classifier test suites updated to invoke the new generic check.

Author: WeichenXu <weichen.xu@databricks.com>
Author: WeichenXu <WeichenXu123@outlook.com>

Closes #19065 from WeichenXu123/generic_test_for_prob_classifier.
Authored by WeichenXu on 2017-09-01 17:32:33 -07:00; committed by Joseph K. Bradley
parent aba9492d25
commit 900f14f6fa
7 changed files with 82 additions and 0 deletions


@@ -262,6 +262,9 @@ class DecisionTreeClassifierSuite
      assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
        "probability prediction mismatch")
    }

    ProbabilisticClassifierSuite.testPredictMethods[
      Vector, DecisionTreeClassificationModel](newTree, newData)
  }

  test("training with 1-category categorical feature") {


@@ -219,6 +219,9 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
    resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach {
      case (pred1, pred2) => assert(pred1 === pred2)
    }

    ProbabilisticClassifierSuite.testPredictMethods[
      Vector, GBTClassificationModel](gbtModel, validationDataset)
  }

  test("GBT parameter stepSize should be in interval (0, 1]") {


@@ -502,6 +502,9 @@ class LogisticRegressionSuite
    resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
      case (pred1, pred2) => assert(pred1 === pred2)
    }

    ProbabilisticClassifierSuite.testPredictMethods[
      Vector, LogisticRegressionModel](model, smallMultinomialDataset)
  }

  test("binary logistic regression: Predictor, Classifier methods") {
@@ -556,6 +559,9 @@ class LogisticRegressionSuite
    resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
      case (pred1, pred2) => assert(pred1 === pred2)
    }

    ProbabilisticClassifierSuite.testPredictMethods[
      Vector, LogisticRegressionModel](model, smallBinaryDataset)
  }

  test("coefficients and intercept methods") {


@@ -104,6 +104,8 @@ class MultilayerPerceptronClassifierSuite
      case Row(p: Vector, e: Vector) =>
        assert(p ~== e absTol 1e-3)
    }
    ProbabilisticClassifierSuite.testPredictMethods[
      Vector, MultilayerPerceptronClassificationModel](model, strongDataset)
  }

  test("test model probability") {


@@ -160,6 +160,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
    val featureAndProbabilities = model.transform(validationDataset)
      .select("features", "probability")
    validateProbabilities(featureAndProbabilities, model, "multinomial")

    ProbabilisticClassifierSuite.testPredictMethods[
      Vector, NaiveBayesModel](model, testDataset)
  }

  test("Naive Bayes with weighted samples") {
@@ -213,6 +216,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
    val featureAndProbabilities = model.transform(validationDataset)
      .select("features", "probability")
    validateProbabilities(featureAndProbabilities, model, "bernoulli")

    ProbabilisticClassifierSuite.testPredictMethods[
      Vector, NaiveBayesModel](model, testDataset)
  }

  test("detect negative values") {


@@ -19,6 +19,9 @@ package org.apache.spark.ml.classification

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.sql.{Dataset, Row}

final class TestProbabilisticClassificationModel(
    override val uid: String,
@@ -91,4 +94,61 @@ object ProbabilisticClassifierSuite {
    "thresholds" -> Array(0.4, 0.6)
  )

  /**
   * Helper for testing that a ProbabilisticClassificationModel computes
   * the same predictions across all combinations of output columns
   * (rawPrediction/probability/prediction) turned on/off. Makes sure the
   * output column values match by comparing vs. the case with all 3 output
   * columns turned on.
   */
  def testPredictMethods[
      FeaturesType,
      M <: ProbabilisticClassificationModel[FeaturesType, M]](
    model: M, testData: Dataset[_]): Unit = {

    val allColModel = model.copy(ParamMap.empty)
      .setRawPredictionCol("rawPredictionAll")
      .setProbabilityCol("probabilityAll")
      .setPredictionCol("predictionAll")
    val allColResult = allColModel.transform(testData)

    for (rawPredictionCol <- Seq("", "rawPredictionSingle")) {
      for (probabilityCol <- Seq("", "probabilitySingle")) {
        for (predictionCol <- Seq("", "predictionSingle")) {
          val newModel = model.copy(ParamMap.empty)
            .setRawPredictionCol(rawPredictionCol)
            .setProbabilityCol(probabilityCol)
            .setPredictionCol(predictionCol)

          val result = newModel.transform(allColResult)

          import org.apache.spark.sql.functions._

          val resultRawPredictionCol =
            if (rawPredictionCol.isEmpty) col("rawPredictionAll") else col(rawPredictionCol)
          val resultProbabilityCol =
            if (probabilityCol.isEmpty) col("probabilityAll") else col(probabilityCol)
          val resultPredictionCol =
            if (predictionCol.isEmpty) col("predictionAll") else col(predictionCol)

          result.select(
            resultRawPredictionCol, col("rawPredictionAll"),
            resultProbabilityCol, col("probabilityAll"),
            resultPredictionCol, col("predictionAll")
          ).collect().foreach {
            case Row(
              rawPredictionSingle: Vector, rawPredictionAll: Vector,
              probabilitySingle: Vector, probabilityAll: Vector,
              predictionSingle: Double, predictionAll: Double
            ) => {
              assert(rawPredictionSingle ~== rawPredictionAll relTol 1E-3)
              assert(probabilitySingle ~== probabilityAll relTol 1E-3)
              assert(predictionSingle ~== predictionAll relTol 1E-3)
            }
          }
        }
      }
    }
  }
}
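
Note that the helper transforms `allColResult` rather than the original `testData`: each reduced-column model appends its output columns to the DataFrame that already carries the all-columns reference, so single and reference values can be compared row by row in one `select`. When a column is turned off (empty name), the comparison falls back to the reference column itself, which trivially matches.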


@@ -155,6 +155,8 @@ class RandomForestClassifierSuite
        "probability prediction mismatch")
      assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
    }
    ProbabilisticClassifierSuite.testPredictMethods[
      Vector, RandomForestClassificationModel](model, df)
  }

  test("Fitting without numClasses in metadata") {