[SPARK-21779][PYTHON] Simpler DataFrame.sample API in Python
## What changes were proposed in this pull request?

This PR allows `DataFrame.sample(...)` to omit `withReplacement`, which defaults to `False`, consistent with the equivalent Scala / Java API.

In short, the following examples are allowed:

```python
>>> df = spark.range(10)
>>> df.sample(0.5).count()
7
>>> df.sample(fraction=0.5).count()
3
>>> df.sample(0.5, seed=42).count()
5
>>> df.sample(fraction=0.5, seed=42).count()
5
```

In addition, this PR also adds some type-checking logic, as shown below:

```python
>>> df = spark.range(10)
>>> df.sample().count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [].
>>> df.sample(True).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'bool'>].
>>> df.sample(42).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'int'>].
>>> df.sample(fraction=False, seed="a").count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'bool'>, <type 'str'>].
>>> df.sample(seed=[1]).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'list'>].
>>> df.sample(withReplacement="a", fraction=0.5, seed=1)
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'str'>, <type 'float'>, <type 'int'>].
```

## How was this patch tested?

Manually tested, added unit tests and doc tests, and manually checked the built documentation for Python.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #18999 from HyukjinKwon/SPARK-21779.
This commit is contained in:
parent
f5e10a34e6
commit
5cd8ea99f0
|
@ -659,19 +659,69 @@ class DataFrame(object):
|
|||
return DataFrame(self._jdf.distinct(), self.sql_ctx)
|
||||
|
||||
@since(1.3)
|
||||
def sample(self, withReplacement, fraction, seed=None):
|
||||
def sample(self, withReplacement=None, fraction=None, seed=None):
|
||||
"""Returns a sampled subset of this :class:`DataFrame`.
|
||||
|
||||
:param withReplacement: Sample with replacement or not (default False).
|
||||
:param fraction: Fraction of rows to generate, range [0.0, 1.0].
|
||||
:param seed: Seed for sampling (default a random seed).
|
||||
|
||||
.. note:: This is not guaranteed to provide exactly the fraction specified of the total
|
||||
count of the given :class:`DataFrame`.
|
||||
|
||||
>>> df.sample(False, 0.5, 42).count()
|
||||
2
|
||||
.. note:: `fraction` is required, and `withReplacement` and `seed` are optional.
|
||||
|
||||
>>> df = spark.range(10)
|
||||
>>> df.sample(0.5, 3).count()
|
||||
4
|
||||
>>> df.sample(fraction=0.5, seed=3).count()
|
||||
4
|
||||
>>> df.sample(withReplacement=True, fraction=0.5, seed=3).count()
|
||||
1
|
||||
>>> df.sample(1.0).count()
|
||||
10
|
||||
>>> df.sample(fraction=1.0).count()
|
||||
10
|
||||
>>> df.sample(False, fraction=1.0).count()
|
||||
10
|
||||
"""
|
||||
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
|
||||
seed = seed if seed is not None else random.randint(0, sys.maxsize)
|
||||
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
|
||||
return DataFrame(rdd, self.sql_ctx)
|
||||
|
||||
# For the cases below:
|
||||
# sample(True, 0.5 [, seed])
|
||||
# sample(True, fraction=0.5 [, seed])
|
||||
# sample(withReplacement=False, fraction=0.5 [, seed])
|
||||
is_withReplacement_set = \
|
||||
type(withReplacement) == bool and isinstance(fraction, float)
|
||||
|
||||
# For the case below:
|
||||
# sample(fraction=0.5 [, seed])
|
||||
is_withReplacement_omitted_kwargs = \
|
||||
withReplacement is None and isinstance(fraction, float)
|
||||
|
||||
# For the case below:
|
||||
# sample(0.5 [, seed])
|
||||
is_withReplacement_omitted_args = isinstance(withReplacement, float)
|
||||
|
||||
if not (is_withReplacement_set
|
||||
or is_withReplacement_omitted_kwargs
|
||||
or is_withReplacement_omitted_args):
|
||||
argtypes = [
|
||||
str(type(arg)) for arg in [withReplacement, fraction, seed] if arg is not None]
|
||||
raise TypeError(
|
||||
"withReplacement (optional), fraction (required) and seed (optional)"
|
||||
" should be a bool, float and number; however, "
|
||||
"got [%s]." % ", ".join(argtypes))
|
||||
|
||||
if is_withReplacement_omitted_args:
|
||||
if fraction is not None:
|
||||
seed = fraction
|
||||
fraction = withReplacement
|
||||
withReplacement = None
|
||||
|
||||
seed = long(seed) if seed is not None else None
|
||||
args = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
|
||||
jdf = self._jdf.sample(*args)
|
||||
return DataFrame(jdf, self.sql_ctx)
|
||||
|
||||
@since(1.5)
|
||||
def sampleBy(self, col, fractions, seed=None):
|
||||
|
|
|
@ -2108,6 +2108,24 @@ class SQLTests(ReusedPySparkTestCase):
|
|||
plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
|
||||
self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))
|
||||
|
||||
def test_sample(self):
|
||||
self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
"should be a bool, float and number",
|
||||
lambda: self.spark.range(1).sample())
|
||||
|
||||
self.assertRaises(
|
||||
TypeError,
|
||||
lambda: self.spark.range(1).sample("a"))
|
||||
|
||||
self.assertRaises(
|
||||
TypeError,
|
||||
lambda: self.spark.range(1).sample(seed="abc"))
|
||||
|
||||
self.assertRaises(
|
||||
IllegalArgumentException,
|
||||
lambda: self.spark.range(1).sample(-1.0))
|
||||
|
||||
def test_toDF_with_schema_string(self):
|
||||
data = [Row(key=i, value=str(i)) for i in range(100)]
|
||||
rdd = self.sc.parallelize(data, 5)
|
||||
|
|
|
@ -1867,7 +1867,8 @@ class Dataset[T] private[sql](
|
|||
}
|
||||
|
||||
/**
|
||||
* Returns a new [[Dataset]] by sampling a fraction of rows (without replacement).
|
||||
* Returns a new [[Dataset]] by sampling a fraction of rows (without replacement),
|
||||
* using a random seed.
|
||||
*
|
||||
* @param fraction Fraction of rows to generate, range [0.0, 1.0].
|
||||
*
|
||||
|
|
Loading…
Reference in a new issue