From 900f14f6fad50369aa849922447f60d7cf06cf2f Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 1 Sep 2017 17:32:33 -0700 Subject: [PATCH] [SPARK-21729][ML][TEST] Generic test for ProbabilisticClassifier to ensure consistent output columns ## What changes were proposed in this pull request? Add test for prediction using the model with all combinations of output columns turned on/off. Make sure the output column values match, presumably by comparing vs. the case with all 3 output columns turned on. ## How was this patch tested? Test updated. Author: WeichenXu Author: WeichenXu Closes #19065 from WeichenXu123/generic_test_for_prob_classifier. --- .../DecisionTreeClassifierSuite.scala | 3 + .../classification/GBTClassifierSuite.scala | 3 + .../LogisticRegressionSuite.scala | 6 ++ .../MultilayerPerceptronClassifierSuite.scala | 2 + .../ml/classification/NaiveBayesSuite.scala | 6 ++ .../ProbabilisticClassifierSuite.scala | 60 +++++++++++++++++++ .../RandomForestClassifierSuite.scala | 2 + 7 files changed, 82 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 918ab27e27..98c879ece6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -262,6 +262,9 @@ class DecisionTreeClassifierSuite assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, "probability prediction mismatch") } + + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, DecisionTreeClassificationModel](newTree, newData) } test("training with 1-category categorical feature") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 1f79e0d4e6..8000143d4d 100644 
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -219,6 +219,9 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach { case (pred1, pred2) => assert(pred1 === pred2) } + + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, GBTClassificationModel](gbtModel, validationDataset) } test("GBT parameter stepSize should be in interval (0, 1]") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 6bf1253b71..d43c7cdbde 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -502,6 +502,9 @@ class LogisticRegressionSuite resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach { case (pred1, pred2) => assert(pred1 === pred2) } + + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, LogisticRegressionModel](model, smallMultinomialDataset) } test("binary logistic regression: Predictor, Classifier methods") { @@ -556,6 +559,9 @@ class LogisticRegressionSuite resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach { case (pred1, pred2) => assert(pred1 === pred2) } + + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, LogisticRegressionModel](model, smallBinaryDataset) } test("coefficients and intercept methods") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index c294e4ad54..d3141ec708 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -104,6 +104,8 @@ class MultilayerPerceptronClassifierSuite case Row(p: Vector, e: Vector) => assert(p ~== e absTol 1e-3) } + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, MultilayerPerceptronClassificationModel](model, strongDataset) } test("test model probability") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 3a2be236f1..9730dd68a3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -160,6 +160,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val featureAndProbabilities = model.transform(validationDataset) .select("features", "probability") validateProbabilities(featureAndProbabilities, model, "multinomial") + + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, NaiveBayesModel](model, testDataset) } test("Naive Bayes with weighted samples") { @@ -213,6 +216,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val featureAndProbabilities = model.transform(validationDataset) .select("features", "probability") validateProbabilities(featureAndProbabilities, model, "bernoulli") + + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, NaiveBayesModel](model, testDataset) } test("detect negative values") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index 172c64aab9..4ecd5a0536 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark.ml.classification
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.sql.{Dataset, Row}
 
 final class TestProbabilisticClassificationModel(
     override val uid: String,
@@ -91,4 +94,61 @@ object ProbabilisticClassifierSuite {
     "thresholds" -> Array(0.4, 0.6)
   )
 
+  /**
+   * Verifies that a ProbabilisticClassificationModel produces consistent output
+   * regardless of which output columns are enabled. The model is first run with
+   * rawPrediction, probability and prediction all enabled to obtain reference
+   * values; every on/off combination of the three columns (an empty column name
+   * stands for a disabled output) is then run on top of the reference result and
+   * its values are compared against the reference columns. A disabled column is
+   * compared against the reference itself, so only enabled outputs are checked.
+   *
+   * @param model model to check; column settings are applied to copies of it
+   * @param testData dataset to run the model on
+   */
+  def testPredictMethods[
+      FeaturesType,
+      M <: ProbabilisticClassificationModel[FeaturesType, M]](
+    model: M, testData: Dataset[_]): Unit = {
+
+    import org.apache.spark.sql.functions.col
+
+    val reference = model.copy(ParamMap.empty)
+      .setRawPredictionCol("rawPredictionAll")
+      .setProbabilityCol("probabilityAll")
+      .setPredictionCol("predictionAll")
+      .transform(testData)
+
+    // Freshly produced column when the output is enabled, else the reference one.
+    def produced(name: String, referenceName: String) =
+      if (name.isEmpty) col(referenceName) else col(name)
+
+    for {
+      rawName <- Seq("", "rawPredictionSingle")
+      probName <- Seq("", "probabilitySingle")
+      predName <- Seq("", "predictionSingle")
+    } {
+      val scored = model.copy(ParamMap.empty)
+        .setRawPredictionCol(rawName)
+        .setProbabilityCol(probName)
+        .setPredictionCol(predName)
+        .transform(reference)
+
+      val rows = scored.select(
+        produced(rawName, "rawPredictionAll"), col("rawPredictionAll"),
+        produced(probName, "probabilityAll"), col("probabilityAll"),
+        produced(predName, "predictionAll"), col("predictionAll")
+      ).collect()
+
+      rows.foreach {
+        case Row(
+            rawOne: Vector, rawAll: Vector, probOne: Vector, probAll: Vector,
+            predOne: Double, predAll: Double) =>
+          assert(rawOne ~== rawAll relTol 1E-3)
+          assert(probOne ~== probAll relTol 1E-3)
+          assert(predOne ~== predAll relTol 1E-3)
+      }
+    }
+  }
+
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index ca2954d2f3..2cca2e6c04 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -155,6 +155,8 @@ class RandomForestClassifierSuite
         "probability prediction mismatch")
       assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
     }
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, RandomForestClassificationModel](model, df)
   }
 
   test("Fitting without numClasses in metadata") {