[SPARK-17086][ML] Fix InvalidArgumentException issue in QuantileDiscretizer when some quantiles are duplicated
## What changes were proposed in this pull request? In cases when QuantileDiscretizer is called upon a numeric array with duplicated elements, we will take the unique elements generated from approxQuantiles as input for Bucketizer. ## How was this patch tested? A unit test is added in QuantileDiscretizerSuite. QuantileDiscretizer.fit will throw an illegal exception when calling setSplits on a list of splits with duplicated elements. Bucketizer.setSplits should only accept a numeric vector of two or more unique cut points, although that may produce fewer buckets than requested. Signed-off-by: VinceShieh <vincent.xie@intel.com> Author: VinceShieh <vincent.xie@intel.com> Closes #14747 from VinceShieh/SPARK-17086.
This commit is contained in:
parent
673a80d223
commit
92c0eaf348
|
@ -114,7 +114,12 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
|
|||
splits(0) = Double.NegativeInfinity
|
||||
splits(splits.length - 1) = Double.PositiveInfinity
|
||||
|
||||
val bucketizer = new Bucketizer(uid).setSplits(splits)
|
||||
val distinctSplits = splits.distinct
|
||||
if (splits.length != distinctSplits.length) {
|
||||
log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
|
||||
s" buckets as a result.")
|
||||
}
|
||||
val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted)
|
||||
copyValues(bucketizer.setParent(this))
|
||||
}
|
||||
|
||||
|
|
|
@ -52,6 +52,25 @@ class QuantileDiscretizerSuite
|
|||
"Bucket sizes are not within expected relative error tolerance.")
|
||||
}
|
||||
|
||||
test("Test Bucketizer on duplicated splits") {
|
||||
val spark = this.spark
|
||||
import spark.implicits._
|
||||
|
||||
val datasetSize = 12
|
||||
val numBuckets = 5
|
||||
val df = sc.parallelize(Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0))
|
||||
.map(Tuple1.apply).toDF("input")
|
||||
val discretizer = new QuantileDiscretizer()
|
||||
.setInputCol("input")
|
||||
.setOutputCol("result")
|
||||
.setNumBuckets(numBuckets)
|
||||
val result = discretizer.fit(df).transform(df)
|
||||
|
||||
val observedNumBuckets = result.select("result").distinct.count
|
||||
assert(2 <= observedNumBuckets && observedNumBuckets <= numBuckets,
|
||||
"Observed number of buckets are not within expected range.")
|
||||
}
|
||||
|
||||
test("Test transform method on unseen data") {
|
||||
val spark = this.spark
|
||||
import spark.implicits._
|
||||
|
|
Loading…
Reference in a new issue