[SPARK-13444][MLLIB] QuantileDiscretizer chooses bad splits on large DataFrames

## What changes were proposed in this pull request?

Change line 113 of QuantileDiscretizer.scala to

`val requiredSamples = math.max(numBins * numBins, 10000.0)`

so that `requiredSamples` is a `Double`.  This will fix the division in line 114 which currently results in zero if `requiredSamples < dataset.count`

## How was this patch tested?
Manual tests.  I was having problems using QuantileDiscretizer with my dataset, and after making this change QuantileDiscretizer behaves as expected.

Author: Oliver Pierson <ocp@gatech.edu>
Author: Oliver Pierson <opierson@umd.edu>

Closes #11319 from oliverpierson/SPARK-13444.
This commit is contained in:
Oliver Pierson 2016-02-25 13:24:46 +00:00 committed by Sean Owen
parent 3fa6491be6
commit 6f8e835c68
2 changed files with 29 additions and 2 deletions

View file

@ -103,6 +103,13 @@ final class QuantileDiscretizer(override val uid: String)
@Since("1.6.0") @Since("1.6.0")
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
/**
* Minimum number of samples required for finding splits, regardless of number of bins. If
* the dataset has fewer rows than this value, the entire dataset will be used.
*/
// Visibility is private[spark] (not private) so the test suite can size its dataset
// relative to this threshold (see the "larger than minSamplesRequired" test below).
private[spark] val minSamplesRequired: Int = 10000
/** /**
* Sampling from the given dataset to collect quantile statistics. * Sampling from the given dataset to collect quantile statistics.
*/ */
@ -110,8 +117,8 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi
val totalSamples = dataset.count() val totalSamples = dataset.count()
require(totalSamples > 0, require(totalSamples > 0,
"QuantileDiscretizer requires non-empty input dataset but was given an empty input.") "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
val requiredSamples = math.max(numBins * numBins, 10000) val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
val fraction = math.min(requiredSamples / dataset.count(), 1.0) val fraction = math.min(requiredSamples.toDouble / dataset.count(), 1.0)
dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
} }

View file

@ -71,6 +71,26 @@ class QuantileDiscretizerSuite
} }
} }
// Regression test for SPARK-13444: per the commit description, the integer division
// `requiredSamples / dataset.count()` yielded a fraction of 0 whenever the dataset
// was larger than the sampling threshold, producing bad splits.
test("Test splits on dataset larger than minSamplesRequired") {
val sqlCtx = SQLContext.getOrCreate(sc)
import sqlCtx.implicits._
// One row more than the threshold, so the sampling path (fraction < 1.0) is exercised.
val datasetSize = QuantileDiscretizer.minSamplesRequired + 1
val numBuckets = 5
val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input")
val discretizer = new QuantileDiscretizer()
.setInputCol("input")
.setOutputCol("result")
.setNumBuckets(numBuckets)
.setSeed(1)
val result = discretizer.fit(df).transform(df)
// With evenly spread input values, every requested bucket should be populated.
val observedNumBuckets = result.select("result").distinct.count
assert(observedNumBuckets === numBuckets,
"Observed number of buckets does not equal expected number of buckets.")
}
test("read/write") { test("read/write") {
val t = new QuantileDiscretizer() val t = new QuantileDiscretizer()
.setInputCol("myInputCol") .setInputCol("myInputCol")