[SPARK-9691] [SQL] PySpark SQL rand function treats seed 0 as no seed
https://issues.apache.org/jira/browse/SPARK-9691 jkbradley rxin Author: Yin Huai <yhuai@databricks.com> Closes #7999 from yhuai/pythonRand and squashes the following commits: 4187e0c [Yin Huai] Regression test. a985ef9 [Yin Huai] Use "if seed is not None" instead of "if seed" because "if seed" evaluates to false when seed is 0.
This commit is contained in:
parent
681e3024b6
commit
baf4587a56
|
@ -268,7 +268,7 @@ def rand(seed=None):
|
|||
"""Generates a random column with i.i.d. samples from U[0.0, 1.0].
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
if seed:
|
||||
if seed is not None:
|
||||
jc = sc._jvm.functions.rand(seed)
|
||||
else:
|
||||
jc = sc._jvm.functions.rand()
|
||||
|
@ -280,7 +280,7 @@ def randn(seed=None):
|
|||
"""Generates a column with i.i.d. samples from the standard normal distribution.
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
if seed:
|
||||
if seed is not None:
|
||||
jc = sc._jvm.functions.randn(seed)
|
||||
else:
|
||||
jc = sc._jvm.functions.randn()
|
||||
|
|
|
@ -629,6 +629,16 @@ class SQLTests(ReusedPySparkTestCase):
|
|||
for row in rndn:
|
||||
assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
|
||||
|
||||
# If the specified seed is 0, we should use it.
|
||||
# https://issues.apache.org/jira/browse/SPARK-9691
|
||||
rnd1 = df.select('key', functions.rand(0)).collect()
|
||||
rnd2 = df.select('key', functions.rand(0)).collect()
|
||||
self.assertEqual(sorted(rnd1), sorted(rnd2))
|
||||
|
||||
rndn1 = df.select('key', functions.randn(0)).collect()
|
||||
rndn2 = df.select('key', functions.randn(0)).collect()
|
||||
self.assertEqual(sorted(rndn1), sorted(rndn2))
|
||||
|
||||
def test_between_function(self):
|
||||
df = self.sc.parallelize([
|
||||
Row(a=1, b=2, c=3),
|
||||
|
|
Loading…
Reference in a new issue