[SPARK-21729][ML][TEST] Generic test for ProbabilisticClassifier to ensure consistent output columns
## What changes were proposed in this pull request?

Add a test for prediction using the model with all combinations of output columns turned on/off. Make sure the output column values match, presumably by comparing vs. the case with all 3 output columns turned on.

## How was this patch tested?

Test updated.

Author: WeichenXu <weichen.xu@databricks.com>
Author: WeichenXu <WeichenXu123@outlook.com>

Closes #19065 from WeichenXu123/generic_test_for_prob_classifier.
This commit is contained in:
parent
aba9492d25
commit
900f14f6fa
|
@ -262,6 +262,9 @@ class DecisionTreeClassifierSuite
|
|||
assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
|
||||
"probability prediction mismatch")
|
||||
}
|
||||
|
||||
ProbabilisticClassifierSuite.testPredictMethods[
|
||||
Vector, DecisionTreeClassificationModel](newTree, newData)
|
||||
}
|
||||
|
||||
test("training with 1-category categorical feature") {
|
||||
|
|
|
@ -219,6 +219,9 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
|
|||
resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach {
|
||||
case (pred1, pred2) => assert(pred1 === pred2)
|
||||
}
|
||||
|
||||
ProbabilisticClassifierSuite.testPredictMethods[
|
||||
Vector, GBTClassificationModel](gbtModel, validationDataset)
|
||||
}
|
||||
|
||||
test("GBT parameter stepSize should be in interval (0, 1]") {
|
||||
|
|
|
@ -502,6 +502,9 @@ class LogisticRegressionSuite
|
|||
resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
|
||||
case (pred1, pred2) => assert(pred1 === pred2)
|
||||
}
|
||||
|
||||
ProbabilisticClassifierSuite.testPredictMethods[
|
||||
Vector, LogisticRegressionModel](model, smallMultinomialDataset)
|
||||
}
|
||||
|
||||
test("binary logistic regression: Predictor, Classifier methods") {
|
||||
|
@ -556,6 +559,9 @@ class LogisticRegressionSuite
|
|||
resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
|
||||
case (pred1, pred2) => assert(pred1 === pred2)
|
||||
}
|
||||
|
||||
ProbabilisticClassifierSuite.testPredictMethods[
|
||||
Vector, LogisticRegressionModel](model, smallBinaryDataset)
|
||||
}
|
||||
|
||||
test("coefficients and intercept methods") {
|
||||
|
|
|
@ -104,6 +104,8 @@ class MultilayerPerceptronClassifierSuite
|
|||
case Row(p: Vector, e: Vector) =>
|
||||
assert(p ~== e absTol 1e-3)
|
||||
}
|
||||
ProbabilisticClassifierSuite.testPredictMethods[
|
||||
Vector, MultilayerPerceptronClassificationModel](model, strongDataset)
|
||||
}
|
||||
|
||||
test("test model probability") {
|
||||
|
|
|
@ -160,6 +160,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
|||
val featureAndProbabilities = model.transform(validationDataset)
|
||||
.select("features", "probability")
|
||||
validateProbabilities(featureAndProbabilities, model, "multinomial")
|
||||
|
||||
ProbabilisticClassifierSuite.testPredictMethods[
|
||||
Vector, NaiveBayesModel](model, testDataset)
|
||||
}
|
||||
|
||||
test("Naive Bayes with weighted samples") {
|
||||
|
@ -213,6 +216,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
|
|||
val featureAndProbabilities = model.transform(validationDataset)
|
||||
.select("features", "probability")
|
||||
validateProbabilities(featureAndProbabilities, model, "bernoulli")
|
||||
|
||||
ProbabilisticClassifierSuite.testPredictMethods[
|
||||
Vector, NaiveBayesModel](model, testDataset)
|
||||
}
|
||||
|
||||
test("detect negative values") {
|
||||
|
|
|
@ -19,6 +19,9 @@ package org.apache.spark.ml.classification
|
|||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.util.TestingUtils._
|
||||
import org.apache.spark.sql.{Dataset, Row}
|
||||
|
||||
final class TestProbabilisticClassificationModel(
|
||||
override val uid: String,
|
||||
|
@ -91,4 +94,61 @@ object ProbabilisticClassifierSuite {
|
|||
"thresholds" -> Array(0.4, 0.6)
|
||||
)
|
||||
|
||||
/**
 * Helper for testing that a ProbabilisticClassificationModel computes
 * the same predictions across all combinations of output columns
 * (rawPrediction/probability/prediction) turned on/off. Makes sure the
 * output column values match by comparing vs. the case with all 3 output
 * columns turned on.
 */
def testPredictMethods[
    FeaturesType,
    M <: ProbabilisticClassificationModel[FeaturesType, M]](
    model: M, testData: Dataset[_]): Unit = {

  import org.apache.spark.sql.functions.col

  // Reference run: a copy of the model with all three output columns enabled.
  val referenceModel = model.copy(ParamMap.empty)
    .setRawPredictionCol("rawPredictionAll")
    .setProbabilityCol("probabilityAll")
    .setPredictionCol("predictionAll")
  val referenceResult = referenceModel.transform(testData)

  // Exercise every on/off combination of the three output columns
  // (an empty column name disables that output).
  for {
    rawPredictionCol <- Seq("", "rawPredictionSingle")
    probabilityCol <- Seq("", "probabilitySingle")
    predictionCol <- Seq("", "predictionSingle")
  } {
    val trimmedModel = model.copy(ParamMap.empty)
      .setRawPredictionCol(rawPredictionCol)
      .setProbabilityCol(probabilityCol)
      .setPredictionCol(predictionCol)

    // Transform the reference result so both the reference columns and the
    // freshly-computed columns are available in one DataFrame.
    val transformed = trimmedModel.transform(referenceResult)

    // For a disabled column, fall back to comparing the reference column
    // against itself so the row pattern below still matches.
    val rawCol =
      if (rawPredictionCol.isEmpty) col("rawPredictionAll") else col(rawPredictionCol)
    val probCol =
      if (probabilityCol.isEmpty) col("probabilityAll") else col(probabilityCol)
    val predCol =
      if (predictionCol.isEmpty) col("predictionAll") else col(predictionCol)

    transformed.select(
      rawCol, col("rawPredictionAll"),
      probCol, col("probabilityAll"),
      predCol, col("predictionAll")
    ).collect().foreach {
      case Row(
          rawSingle: Vector, rawAll: Vector,
          probSingle: Vector, probAll: Vector,
          predSingle: Double, predAll: Double) =>
        assert(rawSingle ~== rawAll relTol 1E-3)
        assert(probSingle ~== probAll relTol 1E-3)
        assert(predSingle ~== predAll relTol 1E-3)
    }
  }
}
|
||||
|
||||
}
|
||||
|
|
|
@ -155,6 +155,8 @@ class RandomForestClassifierSuite
|
|||
"probability prediction mismatch")
|
||||
assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
|
||||
}
|
||||
ProbabilisticClassifierSuite.testPredictMethods[
|
||||
Vector, RandomForestClassificationModel](model, df)
|
||||
}
|
||||
|
||||
test("Fitting without numClasses in metadata") {
|
||||
|
|
Loading…
Reference in a new issue