[SPARK-19165][PYTHON][SQL] PySpark APIs using columns as arguments should validate input types for column
## What changes were proposed in this pull request?
While preparing to take over https://github.com/apache/spark/pull/16537, I realised a better approach that puts the exception handling in a single place.
This PR proposes to fix `_to_java_column` in `pyspark.sql.column`, which most of the functions in `functions.py` and some other APIs use. `_to_java_column` does not work with types other than `pyspark.sql.column.Column` and strings (`str` and `unicode`).
If the argument is not a `Column`, it calls `_create_column_from_name`, which calls `functions.col` in the JVM:
42b9eda80e/sql/core/src/main/scala/org/apache/spark/sql/functions.scala (L76)
and `col` only has a `String` overload.
So these work:
```python
>>> from pyspark.sql.column import _to_java_column, Column
>>> _to_java_column("a")
JavaObject id=o28
>>> _to_java_column(u"a")
JavaObject id=o29
>>> _to_java_column(spark.range(1).id)
JavaObject id=o33
```
whereas these do not:
```python
>>> _to_java_column(1)
```
```
...
py4j.protocol.Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.col. Trace:
py4j.Py4JException: Method col([class java.lang.Integer]) does not exist
...
```
```python
>>> _to_java_column([])
```
```
...
py4j.protocol.Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.col. Trace:
py4j.Py4JException: Method col([class java.util.ArrayList]) does not exist
...
```
```python
>>> class A(): pass
>>> _to_java_column(A())
```
```
...
AttributeError: 'A' object has no attribute '_get_object_id'
```
This means that most functions using `_to_java_column`, such as `udf`, `to_json` and some other APIs, throw an exception like the ones below.

**Before this PR**:
```python
>>> from pyspark.sql.functions import udf
>>> udf(lambda x: x)(None)
```
```
...
py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.sql.functions.col.
: java.lang.NullPointerException
...
```
```python
>>> from pyspark.sql.functions import to_json
>>> to_json(None)
```
```
...
py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.sql.functions.col.
: java.lang.NullPointerException
...
```
**After this PR**:
```python
>>> from pyspark.sql.functions import udf
>>> udf(lambda x: x)(None)
...
```
```
TypeError: Invalid argument, not a string or column: None of type <type 'NoneType'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' functions.
```
```python
>>> from pyspark.sql.functions import to_json
>>> to_json(None)
```
```
...
TypeError: Invalid argument, not a string or column: None of type <type 'NoneType'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' functions.
```
## How was this patch tested?
Unit tests added in `python/pyspark/sql/tests.py` and manual tests.
Author: hyukjinkwon <gurwls223@gmail.com>
Author: zero323 <zero323@users.noreply.github.com>
Closes #19027 from HyukjinKwon/SPARK-19165.
Parent: 95713eb4f2
Commit: dc5d34d8dc
`python/pyspark/sql/column.py`:
```diff
@@ -44,8 +44,14 @@ def _create_column_from_name(name):
 
 
 def _to_java_column(col):
     if isinstance(col, Column):
         jcol = col._jc
-    else:
+    elif isinstance(col, basestring):
         jcol = _create_column_from_name(col)
+    else:
+        raise TypeError(
+            "Invalid argument, not a string or column: "
+            "{0} of type {1}. "
+            "For column literals, use 'lit', 'array', 'struct' or 'create_map' "
+            "function.".format(col, type(col)))
     return jcol
```
`python/pyspark/sql/tests.py`:
```diff
@@ -704,6 +704,31 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(f, f_.func)
         self.assertEqual(return_type, f_.returnType)
 
+    def test_validate_column_types(self):
+        from pyspark.sql.functions import udf, to_json
+        from pyspark.sql.column import _to_java_column
+
+        self.assertTrue("Column" in _to_java_column("a").getClass().toString())
+        self.assertTrue("Column" in _to_java_column(u"a").getClass().toString())
+        self.assertTrue("Column" in _to_java_column(self.spark.range(1).id).getClass().toString())
+
+        self.assertRaisesRegexp(
+            TypeError,
+            "Invalid argument, not a string or column",
+            lambda: _to_java_column(1))
+
+        class A():
+            pass
+
+        self.assertRaises(TypeError, lambda: _to_java_column(A()))
+        self.assertRaises(TypeError, lambda: _to_java_column([]))
+
+        self.assertRaisesRegexp(
+            TypeError,
+            "Invalid argument, not a string or column",
+            lambda: udf(lambda x: x)(None))
+        self.assertRaises(TypeError, lambda: to_json(1))
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
```