From 92c0eaf348b42b3479610da0be761013f9d81c54 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Wed, 24 Aug 2016 10:16:58 +0100 Subject: [PATCH] [SPARK-17086][ML] Fix InvalidArgumentException issue in QuantileDiscretizer when some quantiles are duplicated ## What changes were proposed in this pull request? In cases when QuantileDiscretizerSuite is called upon a numeric array with duplicated elements, we will take the unique elements generated from approxQuantiles as input for Bucketizer. ## How was this patch tested? An unit test is added in QuantileDiscretizerSuite QuantileDiscretizer.fit will throw an illegal exception when calling setSplits on a list of splits with duplicated elements. Bucketizer.setSplits should only accept either a numeric vector of two or more unique cut points, although that may produce less number of buckets than requested. Signed-off-by: VinceShieh Author: VinceShieh Closes #14747 from VinceShieh/SPARK-17086. --- .../ml/feature/QuantileDiscretizer.scala | 7 ++++++- .../ml/feature/QuantileDiscretizerSuite.scala | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 558a7bbf0a..e09800877c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -114,7 +114,12 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui splits(0) = Double.NegativeInfinity splits(splits.length - 1) = Double.PositiveInfinity - val bucketizer = new Bucketizer(uid).setSplits(splits) + val distinctSplits = splits.distinct + if (splits.length != distinctSplits.length) { + log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" + + s" buckets as a result.") + } + val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted) copyValues(bucketizer.setParent(this)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index b73dbd6232..18f1e89ee8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -52,6 +52,25 @@ class QuantileDiscretizerSuite "Bucket sizes are not within expected relative error tolerance.") } + test("Test Bucketizer on duplicated splits") { + val spark = this.spark + import spark.implicits._ + + val datasetSize = 12 + val numBuckets = 5 + val df = sc.parallelize(Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0)) + .map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(numBuckets) + val result = discretizer.fit(df).transform(df) + + val observedNumBuckets = result.select("result").distinct.count + assert(2 <= observedNumBuckets && observedNumBuckets <= numBuckets, + "Observed number of buckets are not within expected range.") + } + test("Test transform method on unseen data") { val spark = this.spark import spark.implicits._