[SPARK-13444][MLLIB] QuantileDiscretizer chooses bad splits on large DataFrames
## What changes were proposed in this pull request? Change line 113 of QuantileDiscretizer.scala to `val requiredSamples = math.max(numBins * numBins, 10000.0)` so that `requiredSamples` is a `Double`. This will fix the division in line 114, which currently results in zero if `requiredSamples < dataset.count`. ## How was this patch tested? Manual tests. I was having problems using QuantileDiscretizer with my dataset, and after making this change QuantileDiscretizer behaves as expected. Author: Oliver Pierson <ocp@gatech.edu> Author: Oliver Pierson <opierson@umd.edu> Closes #11319 from oliverpierson/SPARK-13444.
This commit is contained in:
parent
3fa6491be6
commit
6f8e835c68
|
@ -103,6 +103,13 @@ final class QuantileDiscretizer(override val uid: String)
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
|
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Minimum number of samples required for finding splits, regardless of number of bins. If
|
||||||
|
* the dataset has fewer rows than this value, the entire dataset will be used.
|
||||||
|
*/
|
||||||
|
private[spark] val minSamplesRequired: Int = 10000
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sampling from the given dataset to collect quantile statistics.
|
* Sampling from the given dataset to collect quantile statistics.
|
||||||
*/
|
*/
|
||||||
|
@ -110,8 +117,8 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi
|
||||||
val totalSamples = dataset.count()
|
val totalSamples = dataset.count()
|
||||||
require(totalSamples > 0,
|
require(totalSamples > 0,
|
||||||
"QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
|
"QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
|
||||||
val requiredSamples = math.max(numBins * numBins, 10000)
|
val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
|
||||||
val fraction = math.min(requiredSamples / dataset.count(), 1.0)
|
val fraction = math.min(requiredSamples.toDouble / dataset.count(), 1.0)
|
||||||
dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
|
dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -71,6 +71,26 @@ class QuantileDiscretizerSuite
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("Test splits on dataset larger than minSamplesRequired") {
|
||||||
|
val sqlCtx = SQLContext.getOrCreate(sc)
|
||||||
|
import sqlCtx.implicits._
|
||||||
|
|
||||||
|
val datasetSize = QuantileDiscretizer.minSamplesRequired + 1
|
||||||
|
val numBuckets = 5
|
||||||
|
val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input")
|
||||||
|
val discretizer = new QuantileDiscretizer()
|
||||||
|
.setInputCol("input")
|
||||||
|
.setOutputCol("result")
|
||||||
|
.setNumBuckets(numBuckets)
|
||||||
|
.setSeed(1)
|
||||||
|
|
||||||
|
val result = discretizer.fit(df).transform(df)
|
||||||
|
val observedNumBuckets = result.select("result").distinct.count
|
||||||
|
|
||||||
|
assert(observedNumBuckets === numBuckets,
|
||||||
|
"Observed number of buckets does not equal expected number of buckets.")
|
||||||
|
}
|
||||||
|
|
||||||
test("read/write") {
|
test("read/write") {
|
||||||
val t = new QuantileDiscretizer()
|
val t = new QuantileDiscretizer()
|
||||||
.setInputCol("myInputCol")
|
.setInputCol("myInputCol")
|
||||||
|
|
Loading…
Reference in a new issue