[SPARK-21779][PYTHON] Simpler DataFrame.sample API in Python

## What changes were proposed in this pull request?

This PR make `DataFrame.sample(...)` can omit `withReplacement` defaulting `False`, consistently with equivalent Scala / Java API.

In short, the following examples are allowed:

```python
>>> df = spark.range(10)
>>> df.sample(0.5).count()
7
>>> df.sample(fraction=0.5).count()
3
>>> df.sample(0.5, seed=42).count()
5
>>> df.sample(fraction=0.5, seed=42).count()
5
```

In addition, this PR also adds some type checking logics as below:

```python
>>> df = spark.range(10)
>>> df.sample().count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [].
>>> df.sample(True).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'bool'>].
>>> df.sample(42).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'int'>].
>>> df.sample(fraction=False, seed="a").count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'bool'>, <type 'str'>].
>>> df.sample(seed=[1]).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'list'>].
>>> df.sample(withReplacement="a", fraction=0.5, seed=1)
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'str'>, <type 'float'>, <type 'int'>].
```

## How was this patch tested?

Manually tested, unit tests added in doc tests and manually checked the built documentation for Python.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #18999 from HyukjinKwon/SPARK-21779.
This commit is contained in:
hyukjinkwon 2017-09-01 13:01:23 +09:00
parent f5e10a34e6
commit 5cd8ea99f0
3 changed files with 77 additions and 8 deletions

View file

@ -659,19 +659,69 @@ class DataFrame(object):
return DataFrame(self._jdf.distinct(), self.sql_ctx)
@since(1.3)
def sample(self, withReplacement, fraction, seed=None):
def sample(self, withReplacement=None, fraction=None, seed=None):
"""Returns a sampled subset of this :class:`DataFrame`.
:param withReplacement: Sample with replacement or not (default False).
:param fraction: Fraction of rows to generate, range [0.0, 1.0].
:param seed: Seed for sampling (default a random seed).
.. note:: This is not guaranteed to provide exactly the fraction specified of the total
count of the given :class:`DataFrame`.
>>> df.sample(False, 0.5, 42).count()
2
.. note:: `fraction` is required and, `withReplacement` and `seed` are optional.
>>> df = spark.range(10)
>>> df.sample(0.5, 3).count()
4
>>> df.sample(fraction=0.5, seed=3).count()
4
>>> df.sample(withReplacement=True, fraction=0.5, seed=3).count()
1
>>> df.sample(1.0).count()
10
>>> df.sample(fraction=1.0).count()
10
>>> df.sample(False, fraction=1.0).count()
10
"""
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
seed = seed if seed is not None else random.randint(0, sys.maxsize)
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)
# For the cases below:
# sample(True, 0.5 [, seed])
# sample(True, fraction=0.5 [, seed])
# sample(withReplacement=False, fraction=0.5 [, seed])
is_withReplacement_set = \
type(withReplacement) == bool and isinstance(fraction, float)
# For the case below:
# sample(faction=0.5 [, seed])
is_withReplacement_omitted_kwargs = \
withReplacement is None and isinstance(fraction, float)
# For the case below:
# sample(0.5 [, seed])
is_withReplacement_omitted_args = isinstance(withReplacement, float)
if not (is_withReplacement_set
or is_withReplacement_omitted_kwargs
or is_withReplacement_omitted_args):
argtypes = [
str(type(arg)) for arg in [withReplacement, fraction, seed] if arg is not None]
raise TypeError(
"withReplacement (optional), fraction (required) and seed (optional)"
" should be a bool, float and number; however, "
"got [%s]." % ", ".join(argtypes))
if is_withReplacement_omitted_args:
if fraction is not None:
seed = fraction
fraction = withReplacement
withReplacement = None
seed = long(seed) if seed is not None else None
args = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
jdf = self._jdf.sample(*args)
return DataFrame(jdf, self.sql_ctx)
@since(1.5)
def sampleBy(self, col, fractions, seed=None):

View file

@ -2108,6 +2108,24 @@ class SQLTests(ReusedPySparkTestCase):
plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))
def test_sample(self):
self.assertRaisesRegexp(
TypeError,
"should be a bool, float and number",
lambda: self.spark.range(1).sample())
self.assertRaises(
TypeError,
lambda: self.spark.range(1).sample("a"))
self.assertRaises(
TypeError,
lambda: self.spark.range(1).sample(seed="abc"))
self.assertRaises(
IllegalArgumentException,
lambda: self.spark.range(1).sample(-1.0))
def test_toDF_with_schema_string(self):
data = [Row(key=i, value=str(i)) for i in range(100)]
rdd = self.sc.parallelize(data, 5)

View file

@ -1867,7 +1867,8 @@ class Dataset[T] private[sql](
}
/**
* Returns a new [[Dataset]] by sampling a fraction of rows (without replacement).
* Returns a new [[Dataset]] by sampling a fraction of rows (without replacement),
* using a random seed.
*
* @param fraction Fraction of rows to generate, range [0.0, 1.0].
*