[SPARK-23152][ML] - Correctly guard against empty datasets
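## What changes were proposed in this pull request?

Correctly guard against empty datasets in `org.apache.spark.ml.classification.Classifier`. The empty-dataset check in `getNumClasses` now also treats a null maximum label as an empty dataset, instead of only checking whether the result array is empty.

## How was this patch tested?

Existing tests.

Author: Matthew Tovbin <mtovbin@salesforce.com>

Closes #20321 from tovbinm/SPARK-23152.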
This commit is contained in:
parent
bbb87b350d
commit
840dea64ab
|
@@ -109,7 +109,7 @@ abstract class Classifier[
|
|||
case None =>
|
||||
// Get number of classes from dataset itself.
|
||||
val maxLabelRow: Array[Row] = dataset.select(max($(labelCol))).take(1)
|
||||
if (maxLabelRow.isEmpty) {
|
||||
if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) {
|
||||
throw new SparkException("ML algorithm was given empty dataset.")
|
||||
}
|
||||
val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
|
||||
|
|
|
@@ -90,6 +90,13 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
}
|
||||
assert(e.getMessage.contains("requires integers in range"))
|
||||
}
|
||||
val df3 = getTestData(Seq.empty[Double])
|
||||
withClue("getNumClasses should fail if dataset is empty") {
|
||||
val e: SparkException = intercept[SparkException] {
|
||||
c.getNumClasses(df3)
|
||||
}
|
||||
assert(e.getMessage == "ML algorithm was given empty dataset.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue